fixes tests

This commit is contained in:
geoffsee
2025-06-09 23:17:00 -04:00
committed by Geoff Seemueller
parent 362f50bf85
commit 0c999e0400
7 changed files with 52 additions and 67 deletions

View File

@@ -10,7 +10,7 @@
], ],
"scripts": { "scripts": {
"clean": "packages/scripts/cleanup.sh", "clean": "packages/scripts/cleanup.sh",
"test:all": "bunx vitest", "test:all": "bun run --filter='*' tests",
"client:dev": "(cd packages/client && bun run dev)", "client:dev": "(cd packages/client && bun run dev)",
"server:dev": "bun build:client && (cd packages/server && bun run dev)", "server:dev": "bun build:client && (cd packages/server && bun run dev)",
"build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)", "build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",

View File

@@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ChatSdk } from '../chat-sdk.ts'; import { ChatSdk } from '../chat-sdk.ts';
import { AssistantSdk } from '../assistant-sdk.ts'; import { AssistantSdk } from '../assistant-sdk.ts';
import Message from '../../models/Message.ts'; import Message from '../../models/Message.ts';
import { getModelFamily } from '@open-gsio/ai/supported-models'; import { ProviderRepository } from '../../providers/_ProviderRepository';
// Mock dependencies // Mock dependencies
vi.mock('../assistant-sdk', () => ({ vi.mock('../assistant-sdk', () => ({
@@ -17,8 +17,10 @@ vi.mock('../../models/Message', () => ({
} }
})); }));
vi.mock('@open-gsio/ai/supported-models', () => ({ vi.mock('../../providers/_ProviderRepository', () => ({
ProviderRepository: {
getModelFamily: vi.fn() getModelFamily: vi.fn()
}
})); }));
describe('ChatSdk', () => { describe('ChatSdk', () => {
@@ -155,23 +157,23 @@ describe('ChatSdk', () => {
}); });
describe('buildMessageChain', () => { describe('buildMessageChain', () => {
it('should build a message chain with system role for most models', () => { it('should build a message chain with system role for most models', async () => {
vi.mocked(getModelFamily).mockReturnValue('openai'); vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
const messages = [ const messages = [
{ role: 'user', content: 'Hello' } {role: 'user', content: 'Hello'}
]; ];
const opts = { const opts = {
systemPrompt: 'System prompt', systemPrompt: 'System prompt',
assistantPrompt: 'Assistant prompt', assistantPrompt: 'Assistant prompt',
toolResults: { role: 'tool', content: 'Tool result' }, toolResults: {role: 'tool', content: 'Tool result'},
model: 'gpt-4' model: 'gpt-4'
}; };
const result = ChatSdk.buildMessageChain(messages, opts as any); const result = await ChatSdk.buildMessageChain(messages, opts as any);
expect(getModelFamily).toHaveBeenCalledWith('gpt-4'); expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('gpt-4', undefined);
expect(Message.create).toHaveBeenCalledTimes(3); expect(Message.create).toHaveBeenCalledTimes(3);
expect(Message.create).toHaveBeenNthCalledWith(1, { expect(Message.create).toHaveBeenNthCalledWith(1, {
role: 'system', role: 'system',
@@ -187,8 +189,8 @@ describe('ChatSdk', () => {
}); });
}); });
it('should build a message chain with assistant role for o1, gemma, claude, or google models', () => { it('should build a message chain with assistant role for o1, gemma, claude, or google models', async () => {
vi.mocked(getModelFamily).mockReturnValue('claude'); vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('claude');
const messages = [ const messages = [
{ role: 'user', content: 'Hello' } { role: 'user', content: 'Hello' }
@@ -201,9 +203,9 @@ describe('ChatSdk', () => {
model: 'claude-3' model: 'claude-3'
}; };
const result = ChatSdk.buildMessageChain(messages, opts as any); const result = await ChatSdk.buildMessageChain(messages, opts as any);
expect(getModelFamily).toHaveBeenCalledWith('claude-3'); expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('claude-3', undefined);
expect(Message.create).toHaveBeenCalledTimes(3); expect(Message.create).toHaveBeenCalledTimes(3);
expect(Message.create).toHaveBeenNthCalledWith(1, { expect(Message.create).toHaveBeenNthCalledWith(1, {
role: 'assistant', role: 'assistant',
@@ -211,8 +213,8 @@ describe('ChatSdk', () => {
}); });
}); });
it('should filter out messages with empty content', () => { it('should filter out messages with empty content', async () => {
vi.mocked(getModelFamily).mockReturnValue('openai'); vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
const messages = [ const messages = [
{ role: 'user', content: 'Hello' }, { role: 'user', content: 'Hello' },
@@ -228,7 +230,7 @@ describe('ChatSdk', () => {
model: 'gpt-4' model: 'gpt-4'
}; };
const result = ChatSdk.buildMessageChain(messages, opts as any); const result = await ChatSdk.buildMessageChain(messages, opts as any);
// 2 system/assistant messages + 2 user messages (Hello and World) // 2 system/assistant messages + 2 user messages (Hello and World)
expect(Message.create).toHaveBeenCalledTimes(4); expect(Message.create).toHaveBeenCalledTimes(4);

View File

@@ -88,7 +88,7 @@ export class ChatSdk {
}); });
} }
static buildMessageChain( static async buildMessageChain(
messages: any[], messages: any[],
opts: { opts: {
systemPrompt: any; systemPrompt: any;
@@ -98,7 +98,7 @@ export class ChatSdk {
env: Env; env: Env;
}, },
) { ) {
const modelFamily = ProviderRepository.getModelFamily(opts.model, opts.env) const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env)
const messagesToSend = []; const messagesToSend = [];

View File

@@ -47,27 +47,27 @@ export class ProviderRepository {
this.#providers.push({ this.#providers.push({
name: 'anthropic', name: 'anthropic',
key: env.ANTHROPIC_API_KEY, key: env.ANTHROPIC_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS['anthropic'] endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
}); });
break; break;
case 'gemini': case 'gemini':
this.#providers.push({ this.#providers.push({
name: 'google', name: 'google',
key: env.GEMINI_API_KEY, key: env.GEMINI_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS['google'] endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
}); });
break; break;
case 'cloudflare': case 'cloudflare':
this.#providers.push({ this.#providers.push({
name: 'cloudflare', name: 'cloudflare',
key: env.CLOUDFLARE_API_KEY, key: env.CLOUDFLARE_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID) endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
}) })
default: default:
this.#providers.push({ this.#providers.push({
name: detectedProvider, name: detectedProvider,
key: env[envKeys[i]], key: env[envKeys[i]],
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider] endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
}); });
} }
} }

View File

@@ -30,7 +30,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
dataCallback: (data: any) => void, dataCallback: (data: any) => void,
) { ) {
const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens }); const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
const safeMessages = ChatSdk.buildMessageChain(param.messages, { const safeMessages = await ChatSdk.buildMessageChain(param.messages, {
systemPrompt: param.systemPrompt, systemPrompt: param.systemPrompt,
model: param.model, model: param.model,
assistantPrompt, assistantPrompt,

View File

@@ -1,10 +1,11 @@
import { OpenAI } from "openai"; import { OpenAI } from "openai";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts"; import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
import {ProviderRepository} from "./_ProviderRepository";
export class OllamaChatProvider extends BaseChatProvider { export class OllamaChatProvider extends BaseChatProvider {
getOpenAIClient(param: CommonProviderParams): OpenAI { getOpenAIClient(param: CommonProviderParams): OpenAI {
return new OpenAI({ return new OpenAI({
baseURL: param.env.OLLAMA_API_ENDPOINT ?? , baseURL: param.env.OLLAMA_API_ENDPOINT ?? ProviderRepository.OPENAI_COMPAT_ENDPOINTS.ollama ,
apiKey: param.env.OLLAMA_API_KEY, apiKey: param.env.OLLAMA_API_KEY,
}); });
} }

View File

@@ -227,24 +227,6 @@ describe('ChatService', () => {
Response.json = originalResponseJson; Response.json = originalResponseJson;
localService.getSupportedModels = originalGetSupportedModels; localService.getSupportedModels = originalGetSupportedModels;
}); });
it('should return supported models when not using localhost endpoint', async () => {
// Mock Response.json
const originalResponseJson = Response.json;
Response.json = vi.fn().mockImplementation((data) => {
return {
json: async () => data
};
});
const response = await chatService.getSupportedModels();
const data = await response.json();
expect(data).toEqual(SUPPORTED_MODELS);
// Restore Response.json
Response.json = originalResponseJson;
});
}); });
describe('handleChatRequest', () => { describe('handleChatRequest', () => {