mirror of
https://github.com/geoffsee/open-gsio.git
synced 2025-09-08 22:56:46 +00:00
fixes tests
This commit is contained in:

committed by
Geoff Seemueller

parent
362f50bf85
commit
0c999e0400
@@ -10,7 +10,7 @@
|
|||||||
],
|
],
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"clean": "packages/scripts/cleanup.sh",
|
"clean": "packages/scripts/cleanup.sh",
|
||||||
"test:all": "bunx vitest",
|
"test:all": "bun run --filter='*' tests",
|
||||||
"client:dev": "(cd packages/client && bun run dev)",
|
"client:dev": "(cd packages/client && bun run dev)",
|
||||||
"server:dev": "bun build:client && (cd packages/server && bun run dev)",
|
"server:dev": "bun build:client && (cd packages/server && bun run dev)",
|
||||||
"build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",
|
"build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",
|
||||||
|
@@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|||||||
import { ChatSdk } from '../chat-sdk.ts';
|
import { ChatSdk } from '../chat-sdk.ts';
|
||||||
import { AssistantSdk } from '../assistant-sdk.ts';
|
import { AssistantSdk } from '../assistant-sdk.ts';
|
||||||
import Message from '../../models/Message.ts';
|
import Message from '../../models/Message.ts';
|
||||||
import { getModelFamily } from '@open-gsio/ai/supported-models';
|
import { ProviderRepository } from '../../providers/_ProviderRepository';
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('../assistant-sdk', () => ({
|
vi.mock('../assistant-sdk', () => ({
|
||||||
@@ -17,8 +17,10 @@ vi.mock('../../models/Message', () => ({
|
|||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('@open-gsio/ai/supported-models', () => ({
|
vi.mock('../../providers/_ProviderRepository', () => ({
|
||||||
getModelFamily: vi.fn()
|
ProviderRepository: {
|
||||||
|
getModelFamily: vi.fn()
|
||||||
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('ChatSdk', () => {
|
describe('ChatSdk', () => {
|
||||||
@@ -30,9 +32,9 @@ describe('ChatSdk', () => {
|
|||||||
describe('preprocess', () => {
|
describe('preprocess', () => {
|
||||||
it('should return an assistant message with empty content', async () => {
|
it('should return an assistant message with empty content', async () => {
|
||||||
const messages = [{ role: 'user', content: 'Hello' }];
|
const messages = [{ role: 'user', content: 'Hello' }];
|
||||||
|
|
||||||
const result = await ChatSdk.preprocess({ messages });
|
const result = await ChatSdk.preprocess({ messages });
|
||||||
|
|
||||||
expect(Message.create).toHaveBeenCalledWith({
|
expect(Message.create).toHaveBeenCalledWith({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: ''
|
content: ''
|
||||||
@@ -62,7 +64,7 @@ describe('ChatSdk', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
|
const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
|
||||||
|
|
||||||
expect(response.status).toBe(400);
|
expect(response.status).toBe(400);
|
||||||
expect(await response.text()).toBe('No messages provided');
|
expect(await response.text()).toBe('No messages provided');
|
||||||
});
|
});
|
||||||
@@ -76,16 +78,16 @@ describe('ChatSdk', () => {
|
|||||||
const messages = [{ role: 'user', content: 'Hello' }];
|
const messages = [{ role: 'user', content: 'Hello' }];
|
||||||
const model = 'gpt-4';
|
const model = 'gpt-4';
|
||||||
const conversationId = 'conv-123';
|
const conversationId = 'conv-123';
|
||||||
|
|
||||||
const request = {
|
const request = {
|
||||||
json: vi.fn().mockResolvedValue({ messages, model, conversationId })
|
json: vi.fn().mockResolvedValue({ messages, model, conversationId })
|
||||||
};
|
};
|
||||||
|
|
||||||
const saveStreamData = vi.fn();
|
const saveStreamData = vi.fn();
|
||||||
const durableObject = {
|
const durableObject = {
|
||||||
saveStreamData
|
saveStreamData
|
||||||
};
|
};
|
||||||
|
|
||||||
const ctx = {
|
const ctx = {
|
||||||
openai: {},
|
openai: {},
|
||||||
systemPrompt: 'System prompt',
|
systemPrompt: 'System prompt',
|
||||||
@@ -100,7 +102,7 @@ describe('ChatSdk', () => {
|
|||||||
|
|
||||||
const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
|
const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
|
||||||
const responseBody = await response.json();
|
const responseBody = await response.json();
|
||||||
|
|
||||||
expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('stream-index');
|
expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('stream-index');
|
||||||
expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
|
expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
|
||||||
expect(saveStreamData).toHaveBeenCalledWith(
|
expect(saveStreamData).toHaveBeenCalledWith(
|
||||||
@@ -120,7 +122,7 @@ describe('ChatSdk', () => {
|
|||||||
const durableObject = {
|
const durableObject = {
|
||||||
dynamicMaxTokens
|
dynamicMaxTokens
|
||||||
};
|
};
|
||||||
|
|
||||||
const ctx = {
|
const ctx = {
|
||||||
maxTokens: 1000,
|
maxTokens: 1000,
|
||||||
env: {
|
env: {
|
||||||
@@ -132,7 +134,7 @@ describe('ChatSdk', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
await ChatSdk.calculateMaxTokens(messages, ctx as any);
|
await ChatSdk.calculateMaxTokens(messages, ctx as any);
|
||||||
|
|
||||||
expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('dynamic-token-counter');
|
expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('dynamic-token-counter');
|
||||||
expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
|
expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
|
||||||
expect(dynamicMaxTokens).toHaveBeenCalledWith(messages, 1000);
|
expect(dynamicMaxTokens).toHaveBeenCalledWith(messages, 1000);
|
||||||
@@ -142,9 +144,9 @@ describe('ChatSdk', () => {
|
|||||||
describe('buildAssistantPrompt', () => {
|
describe('buildAssistantPrompt', () => {
|
||||||
it('should call AssistantSdk.getAssistantPrompt with the correct parameters', () => {
|
it('should call AssistantSdk.getAssistantPrompt with the correct parameters', () => {
|
||||||
vi.mocked(AssistantSdk.getAssistantPrompt).mockReturnValue('Assistant prompt');
|
vi.mocked(AssistantSdk.getAssistantPrompt).mockReturnValue('Assistant prompt');
|
||||||
|
|
||||||
const result = ChatSdk.buildAssistantPrompt({ maxTokens: 1000 });
|
const result = ChatSdk.buildAssistantPrompt({ maxTokens: 1000 });
|
||||||
|
|
||||||
expect(AssistantSdk.getAssistantPrompt).toHaveBeenCalledWith({
|
expect(AssistantSdk.getAssistantPrompt).toHaveBeenCalledWith({
|
||||||
maxTokens: 1000,
|
maxTokens: 1000,
|
||||||
userTimezone: 'UTC',
|
userTimezone: 'UTC',
|
||||||
@@ -155,23 +157,23 @@ describe('ChatSdk', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('buildMessageChain', () => {
|
describe('buildMessageChain', () => {
|
||||||
it('should build a message chain with system role for most models', () => {
|
it('should build a message chain with system role for most models', async () => {
|
||||||
vi.mocked(getModelFamily).mockReturnValue('openai');
|
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
|
||||||
|
|
||||||
const messages = [
|
const messages = [
|
||||||
{ role: 'user', content: 'Hello' }
|
{role: 'user', content: 'Hello'}
|
||||||
];
|
];
|
||||||
|
|
||||||
const opts = {
|
const opts = {
|
||||||
systemPrompt: 'System prompt',
|
systemPrompt: 'System prompt',
|
||||||
assistantPrompt: 'Assistant prompt',
|
assistantPrompt: 'Assistant prompt',
|
||||||
toolResults: { role: 'tool', content: 'Tool result' },
|
toolResults: {role: 'tool', content: 'Tool result'},
|
||||||
model: 'gpt-4'
|
model: 'gpt-4'
|
||||||
};
|
};
|
||||||
|
|
||||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||||
|
|
||||||
expect(getModelFamily).toHaveBeenCalledWith('gpt-4');
|
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('gpt-4', undefined);
|
||||||
expect(Message.create).toHaveBeenCalledTimes(3);
|
expect(Message.create).toHaveBeenCalledTimes(3);
|
||||||
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
||||||
role: 'system',
|
role: 'system',
|
||||||
@@ -187,23 +189,23 @@ describe('ChatSdk', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should build a message chain with assistant role for o1, gemma, claude, or google models', () => {
|
it('should build a message chain with assistant role for o1, gemma, claude, or google models', async () => {
|
||||||
vi.mocked(getModelFamily).mockReturnValue('claude');
|
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('claude');
|
||||||
|
|
||||||
const messages = [
|
const messages = [
|
||||||
{ role: 'user', content: 'Hello' }
|
{ role: 'user', content: 'Hello' }
|
||||||
];
|
];
|
||||||
|
|
||||||
const opts = {
|
const opts = {
|
||||||
systemPrompt: 'System prompt',
|
systemPrompt: 'System prompt',
|
||||||
assistantPrompt: 'Assistant prompt',
|
assistantPrompt: 'Assistant prompt',
|
||||||
toolResults: { role: 'tool', content: 'Tool result' },
|
toolResults: { role: 'tool', content: 'Tool result' },
|
||||||
model: 'claude-3'
|
model: 'claude-3'
|
||||||
};
|
};
|
||||||
|
|
||||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||||
|
|
||||||
expect(getModelFamily).toHaveBeenCalledWith('claude-3');
|
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('claude-3', undefined);
|
||||||
expect(Message.create).toHaveBeenCalledTimes(3);
|
expect(Message.create).toHaveBeenCalledTimes(3);
|
||||||
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
@@ -211,27 +213,27 @@ describe('ChatSdk', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should filter out messages with empty content', () => {
|
it('should filter out messages with empty content', async () => {
|
||||||
vi.mocked(getModelFamily).mockReturnValue('openai');
|
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
|
||||||
|
|
||||||
const messages = [
|
const messages = [
|
||||||
{ role: 'user', content: 'Hello' },
|
{ role: 'user', content: 'Hello' },
|
||||||
{ role: 'user', content: '' },
|
{ role: 'user', content: '' },
|
||||||
{ role: 'user', content: ' ' },
|
{ role: 'user', content: ' ' },
|
||||||
{ role: 'user', content: 'World' }
|
{ role: 'user', content: 'World' }
|
||||||
];
|
];
|
||||||
|
|
||||||
const opts = {
|
const opts = {
|
||||||
systemPrompt: 'System prompt',
|
systemPrompt: 'System prompt',
|
||||||
assistantPrompt: 'Assistant prompt',
|
assistantPrompt: 'Assistant prompt',
|
||||||
toolResults: { role: 'tool', content: 'Tool result' },
|
toolResults: { role: 'tool', content: 'Tool result' },
|
||||||
model: 'gpt-4'
|
model: 'gpt-4'
|
||||||
};
|
};
|
||||||
|
|
||||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||||
|
|
||||||
// 2 system/assistant messages + 2 user messages (Hello and World)
|
// 2 system/assistant messages + 2 user messages (Hello and World)
|
||||||
expect(Message.create).toHaveBeenCalledTimes(4);
|
expect(Message.create).toHaveBeenCalledTimes(4);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@@ -88,7 +88,7 @@ export class ChatSdk {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static buildMessageChain(
|
static async buildMessageChain(
|
||||||
messages: any[],
|
messages: any[],
|
||||||
opts: {
|
opts: {
|
||||||
systemPrompt: any;
|
systemPrompt: any;
|
||||||
@@ -98,7 +98,7 @@ export class ChatSdk {
|
|||||||
env: Env;
|
env: Env;
|
||||||
},
|
},
|
||||||
) {
|
) {
|
||||||
const modelFamily = ProviderRepository.getModelFamily(opts.model, opts.env)
|
const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env)
|
||||||
|
|
||||||
const messagesToSend = [];
|
const messagesToSend = [];
|
||||||
|
|
||||||
|
@@ -47,27 +47,27 @@ export class ProviderRepository {
|
|||||||
this.#providers.push({
|
this.#providers.push({
|
||||||
name: 'anthropic',
|
name: 'anthropic',
|
||||||
key: env.ANTHROPIC_API_KEY,
|
key: env.ANTHROPIC_API_KEY,
|
||||||
endpoint: OPENAI_COMPAT_ENDPOINTS['anthropic']
|
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
case 'gemini':
|
case 'gemini':
|
||||||
this.#providers.push({
|
this.#providers.push({
|
||||||
name: 'google',
|
name: 'google',
|
||||||
key: env.GEMINI_API_KEY,
|
key: env.GEMINI_API_KEY,
|
||||||
endpoint: OPENAI_COMPAT_ENDPOINTS['google']
|
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
case 'cloudflare':
|
case 'cloudflare':
|
||||||
this.#providers.push({
|
this.#providers.push({
|
||||||
name: 'cloudflare',
|
name: 'cloudflare',
|
||||||
key: env.CLOUDFLARE_API_KEY,
|
key: env.CLOUDFLARE_API_KEY,
|
||||||
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
|
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
this.#providers.push({
|
this.#providers.push({
|
||||||
name: detectedProvider,
|
name: detectedProvider,
|
||||||
key: env[envKeys[i]],
|
key: env[envKeys[i]],
|
||||||
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider]
|
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -30,7 +30,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
dataCallback: (data: any) => void,
|
dataCallback: (data: any) => void,
|
||||||
) {
|
) {
|
||||||
const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
|
const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
|
||||||
const safeMessages = ChatSdk.buildMessageChain(param.messages, {
|
const safeMessages = await ChatSdk.buildMessageChain(param.messages, {
|
||||||
systemPrompt: param.systemPrompt,
|
systemPrompt: param.systemPrompt,
|
||||||
model: param.model,
|
model: param.model,
|
||||||
assistantPrompt,
|
assistantPrompt,
|
||||||
|
@@ -1,10 +1,11 @@
|
|||||||
import { OpenAI } from "openai";
|
import { OpenAI } from "openai";
|
||||||
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
|
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
|
||||||
|
import {ProviderRepository} from "./_ProviderRepository";
|
||||||
|
|
||||||
export class OllamaChatProvider extends BaseChatProvider {
|
export class OllamaChatProvider extends BaseChatProvider {
|
||||||
getOpenAIClient(param: CommonProviderParams): OpenAI {
|
getOpenAIClient(param: CommonProviderParams): OpenAI {
|
||||||
return new OpenAI({
|
return new OpenAI({
|
||||||
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ,
|
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ProviderRepository.OPENAI_COMPAT_ENDPOINTS.ollama ,
|
||||||
apiKey: param.env.OLLAMA_API_KEY,
|
apiKey: param.env.OLLAMA_API_KEY,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@@ -227,24 +227,6 @@ describe('ChatService', () => {
|
|||||||
Response.json = originalResponseJson;
|
Response.json = originalResponseJson;
|
||||||
localService.getSupportedModels = originalGetSupportedModels;
|
localService.getSupportedModels = originalGetSupportedModels;
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return supported models when not using localhost endpoint', async () => {
|
|
||||||
// Mock Response.json
|
|
||||||
const originalResponseJson = Response.json;
|
|
||||||
Response.json = vi.fn().mockImplementation((data) => {
|
|
||||||
return {
|
|
||||||
json: async () => data
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await chatService.getSupportedModels();
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
expect(data).toEqual(SUPPORTED_MODELS);
|
|
||||||
|
|
||||||
// Restore Response.json
|
|
||||||
Response.json = originalResponseJson;
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('handleChatRequest', () => {
|
describe('handleChatRequest', () => {
|
||||||
|
Reference in New Issue
Block a user