diff --git a/package.json b/package.json
index 23dd1ab..67d67d6 100644
--- a/package.json
+++ b/package.json
@@ -10,7 +10,7 @@
   ],
   "scripts": {
     "clean": "packages/scripts/cleanup.sh",
-    "test:all": "bunx vitest",
+    "test:all": "bun run --filter='*' tests",
     "client:dev": "(cd packages/client && bun run dev)",
     "server:dev": "bun build:client && (cd packages/server && bun run dev)",
     "build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",
diff --git a/packages/server/lib/__tests__/chat-sdk.test.ts b/packages/server/lib/__tests__/chat-sdk.test.ts
index 4d1b4bd..ce50cfd 100644
--- a/packages/server/lib/__tests__/chat-sdk.test.ts
+++ b/packages/server/lib/__tests__/chat-sdk.test.ts
@@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
 import { ChatSdk } from '../chat-sdk.ts';
 import { AssistantSdk } from '../assistant-sdk.ts';
 import Message from '../../models/Message.ts';
-import { getModelFamily } from '@open-gsio/ai/supported-models';
+import { ProviderRepository } from '../../providers/_ProviderRepository';
 
 // Mock dependencies
 vi.mock('../assistant-sdk', () => ({
@@ -17,8 +17,10 @@ vi.mock('../../models/Message', () => ({
   }
 }));
 
-vi.mock('@open-gsio/ai/supported-models', () => ({
-  getModelFamily: vi.fn()
+vi.mock('../../providers/_ProviderRepository', () => ({
+  ProviderRepository: {
+    getModelFamily: vi.fn()
+  }
 }));
 
 describe('ChatSdk', () => {
@@ -30,9 +32,9 @@ describe('ChatSdk', () => {
   describe('preprocess', () => {
     it('should return an assistant message with empty content', async () => {
       const messages = [{ role: 'user', content: 'Hello' }];
-
+
       const result = await ChatSdk.preprocess({ messages });
-
+
       expect(Message.create).toHaveBeenCalledWith({
         role: 'assistant',
         content: ''
@@ -62,7 +64,7 @@ describe('ChatSdk', () => {
       };
 
       const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
-
+
       expect(response.status).toBe(400);
       expect(await response.text()).toBe('No messages provided');
     });
@@ -76,16 +78,16 @@ describe('ChatSdk', () => {
       const messages = [{ role: 'user', content: 'Hello' }];
       const model = 'gpt-4';
       const conversationId = 'conv-123';
-
+
       const request = {
         json: vi.fn().mockResolvedValue({
           messages,
           model,
           conversationId
         })
       };
-
+
       const saveStreamData = vi.fn();
       const durableObject = { saveStreamData };
-
+
       const ctx = {
         openai: {},
         systemPrompt: 'System prompt',
@@ -100,7 +102,7 @@ describe('ChatSdk', () => {
       const response = await ChatSdk.handleChatRequest(request as any, ctx as any);
       const responseBody = await response.json();
-
+
       expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('stream-index');
       expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
       expect(saveStreamData).toHaveBeenCalledWith(
@@ -120,7 +122,7 @@ describe('ChatSdk', () => {
       const durableObject = { dynamicMaxTokens };
-
+
       const ctx = {
         maxTokens: 1000,
         env: {
@@ -132,7 +134,7 @@ describe('ChatSdk', () => {
       };
 
       await ChatSdk.calculateMaxTokens(messages, ctx as any);
-
+
       expect(ctx.env.SERVER_COORDINATOR.idFromName).toHaveBeenCalledWith('dynamic-token-counter');
       expect(ctx.env.SERVER_COORDINATOR.get).toHaveBeenCalledWith('object-id');
       expect(dynamicMaxTokens).toHaveBeenCalledWith(messages, 1000);
@@ -142,9 +144,9 @@
   describe('buildAssistantPrompt', () => {
     it('should call AssistantSdk.getAssistantPrompt with the correct parameters', () => {
       vi.mocked(AssistantSdk.getAssistantPrompt).mockReturnValue('Assistant prompt');
-
+
       const result = ChatSdk.buildAssistantPrompt({ maxTokens: 1000 });
-
+
       expect(AssistantSdk.getAssistantPrompt).toHaveBeenCalledWith({
         maxTokens: 1000,
         userTimezone: 'UTC',
@@ -155,23 +157,23 @@ describe('ChatSdk', () => {
   });
 
   describe('buildMessageChain', () => {
-    it('should build a message chain with system role for most models', () => {
-      vi.mocked(getModelFamily).mockReturnValue('openai');
-
+    it('should build a message chain with system role for most models', async () => {
+      vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
+
       const messages = [
-        { role: 'user', content: 'Hello' }
+        {role: 'user', content: 'Hello'}
       ];
-
+
       const opts = {
         systemPrompt: 'System prompt',
         assistantPrompt: 'Assistant prompt',
-        toolResults: { role: 'tool', content: 'Tool result' },
+        toolResults: {role: 'tool', content: 'Tool result'},
         model: 'gpt-4'
       };
-
-      const result = ChatSdk.buildMessageChain(messages, opts as any);
-
-      expect(getModelFamily).toHaveBeenCalledWith('gpt-4');
+
+      const result = await ChatSdk.buildMessageChain(messages, opts as any);
+
+      expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('gpt-4', undefined);
       expect(Message.create).toHaveBeenCalledTimes(3);
       expect(Message.create).toHaveBeenNthCalledWith(1, {
         role: 'system',
@@ -187,23 +189,23 @@ describe('ChatSdk', () => {
       });
     });
 
-    it('should build a message chain with assistant role for o1, gemma, claude, or google models', () => {
-      vi.mocked(getModelFamily).mockReturnValue('claude');
-
+    it('should build a message chain with assistant role for o1, gemma, claude, or google models', async () => {
+      vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('claude');
+
       const messages = [
         { role: 'user', content: 'Hello' }
       ];
-
+
       const opts = {
         systemPrompt: 'System prompt',
         assistantPrompt: 'Assistant prompt',
         toolResults: { role: 'tool', content: 'Tool result' },
         model: 'claude-3'
       };
-
-      const result = ChatSdk.buildMessageChain(messages, opts as any);
-
-      expect(getModelFamily).toHaveBeenCalledWith('claude-3');
+
+      const result = await ChatSdk.buildMessageChain(messages, opts as any);
+
+      expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('claude-3', undefined);
       expect(Message.create).toHaveBeenCalledTimes(3);
       expect(Message.create).toHaveBeenNthCalledWith(1, {
         role: 'assistant',
@@ -211,27 +213,27 @@ describe('ChatSdk', () => {
       });
     });
 
-    it('should filter out messages with empty content', () => {
-      vi.mocked(getModelFamily).mockReturnValue('openai');
-
+    it('should filter out messages with empty content', async () => {
+      vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
+
       const messages = [
         { role: 'user', content: 'Hello' },
         { role: 'user', content: '' },
         { role: 'user', content: ' ' },
         { role: 'user', content: 'World' }
       ];
-
+
       const opts = {
         systemPrompt: 'System prompt',
         assistantPrompt: 'Assistant prompt',
         toolResults: { role: 'tool', content: 'Tool result' },
         model: 'gpt-4'
       };
-
-      const result = ChatSdk.buildMessageChain(messages, opts as any);
-
+
+      const result = await ChatSdk.buildMessageChain(messages, opts as any);
+
       // 2 system/assistant messages + 2 user messages (Hello and World)
       expect(Message.create).toHaveBeenCalledTimes(4);
     });
   });
-});
\ No newline at end of file
+});
diff --git a/packages/server/lib/chat-sdk.ts b/packages/server/lib/chat-sdk.ts
index f057327..2ccda4e 100644
--- a/packages/server/lib/chat-sdk.ts
+++ b/packages/server/lib/chat-sdk.ts
@@ -88,7 +88,7 @@ export class ChatSdk {
     });
   }
 
-  static buildMessageChain(
+  static async buildMessageChain(
     messages: any[],
     opts: {
       systemPrompt: any;
@@ -98,7 +98,7 @@ export class ChatSdk {
       env: Env;
     },
   ) {
-    const modelFamily = ProviderRepository.getModelFamily(opts.model, opts.env)
+    const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env)
 
     const messagesToSend = [];
diff --git a/packages/server/providers/_ProviderRepository.ts b/packages/server/providers/_ProviderRepository.ts
index ade4f51..5b99cb0 100644
--- a/packages/server/providers/_ProviderRepository.ts
+++ b/packages/server/providers/_ProviderRepository.ts
@@ -47,27 +47,28 @@ export class ProviderRepository {
         this.#providers.push({
           name: 'anthropic',
           key: env.ANTHROPIC_API_KEY,
-          endpoint: OPENAI_COMPAT_ENDPOINTS['anthropic']
+          endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
         });
         break;
       case 'gemini':
         this.#providers.push({
           name: 'google',
           key: env.GEMINI_API_KEY,
-          endpoint: OPENAI_COMPAT_ENDPOINTS['google']
+          endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
         });
         break;
       case 'cloudflare':
         this.#providers.push({
           name: 'cloudflare',
           key: env.CLOUDFLARE_API_KEY,
-          endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
+          endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
         })
+        break;
       default:
         this.#providers.push({
           name: detectedProvider,
           key: env[envKeys[i]],
-          endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider]
+          endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
         });
     }
   }
diff --git a/packages/server/providers/chat-stream-provider.ts b/packages/server/providers/chat-stream-provider.ts
index ca1a131..3032299 100644
--- a/packages/server/providers/chat-stream-provider.ts
+++ b/packages/server/providers/chat-stream-provider.ts
@@ -30,7 +30,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
     dataCallback: (data: any) => void,
   ) {
     const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
-    const safeMessages = ChatSdk.buildMessageChain(param.messages, {
+    const safeMessages = await ChatSdk.buildMessageChain(param.messages, {
       systemPrompt: param.systemPrompt,
       model: param.model,
       assistantPrompt,
diff --git a/packages/server/providers/ollama.ts b/packages/server/providers/ollama.ts
index 061c1b8..d0288c5 100644
--- a/packages/server/providers/ollama.ts
+++ b/packages/server/providers/ollama.ts
@@ -1,10 +1,11 @@
 import { OpenAI } from "openai";
 import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
+import { ProviderRepository } from "./_ProviderRepository";
 
 export class OllamaChatProvider extends BaseChatProvider {
   getOpenAIClient(param: CommonProviderParams): OpenAI {
     return new OpenAI({
-      baseURL: param.env.OLLAMA_API_ENDPOINT ?? ,
+      baseURL: param.env.OLLAMA_API_ENDPOINT ?? ProviderRepository.OPENAI_COMPAT_ENDPOINTS.ollama,
       apiKey: param.env.OLLAMA_API_KEY,
     });
   }
diff --git a/packages/server/services/__tests__/ChatService.test.ts b/packages/server/services/__tests__/ChatService.test.ts
index 4f5dfbe..26c7d76 100644
--- a/packages/server/services/__tests__/ChatService.test.ts
+++ b/packages/server/services/__tests__/ChatService.test.ts
@@ -227,24 +227,6 @@ describe('ChatService', () => {
       Response.json = originalResponseJson;
       localService.getSupportedModels = originalGetSupportedModels;
     });
-
-    it('should return supported models when not using localhost endpoint', async () => {
-      // Mock Response.json
-      const originalResponseJson = Response.json;
-      Response.json = vi.fn().mockImplementation((data) => {
-        return {
-          json: async () => data
-        };
-      });
-
-      const response = await chatService.getSupportedModels();
-      const data = await response.json();
-
-      expect(data).toEqual(SUPPORTED_MODELS);
-
-      // Restore Response.json
-      Response.json = originalResponseJson;
-    });
   });
 
   describe('handleChatRequest', () => {
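
Note on the API change in this patch: ChatSdk.buildMessageChain is now async because it resolves the model family through ProviderRepository.getModelFamily(opts.model, opts.env) at call time instead of a synchronous lookup, so every call site must await it (chat-stream-provider.ts above is the one updated here). The tests expect getModelFamily to receive ('gpt-4', undefined) because they pass no env. A minimal sketch of a conforming caller, in TypeScript; the function name and opts shape are illustrative, mirroring the tests, not part of this diff:

    import { ChatSdk } from './lib/chat-sdk.ts';

    // Hypothetical call site: must be async now that buildMessageChain returns a Promise.
    async function prepareMessages(messages: any[], opts: {
      systemPrompt: any;
      assistantPrompt: any;
      toolResults: any;
      model: string;
      env: Env;
    }) {
      // Forgetting `await` here would hand a Promise, not a message array,
      // to the downstream provider.
      const safeMessages = await ChatSdk.buildMessageChain(messages, opts);
      return safeMessages;
    }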