fixes tests

This commit is contained in:
geoffsee
2025-06-09 23:17:00 -04:00
committed by Geoff Seemueller
parent 362f50bf85
commit 0c999e0400
7 changed files with 52 additions and 67 deletions

View File

@@ -10,7 +10,7 @@
],
"scripts": {
"clean": "packages/scripts/cleanup.sh",
"test:all": "bunx vitest",
"test:all": "bun run --filter='*' tests",
"client:dev": "(cd packages/client && bun run dev)",
"server:dev": "bun build:client && (cd packages/server && bun run dev)",
"build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",

View File

@@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ChatSdk } from '../chat-sdk.ts';
import { AssistantSdk } from '../assistant-sdk.ts';
import Message from '../../models/Message.ts';
import { getModelFamily } from '@open-gsio/ai/supported-models';
import { ProviderRepository } from '../../providers/_ProviderRepository';
// Mock dependencies
vi.mock('../assistant-sdk', () => ({
@@ -17,8 +17,10 @@ vi.mock('../../models/Message', () => ({
}
}));
vi.mock('@open-gsio/ai/supported-models', () => ({
vi.mock('../../providers/_ProviderRepository', () => ({
ProviderRepository: {
getModelFamily: vi.fn()
}
}));
describe('ChatSdk', () => {
@@ -155,8 +157,8 @@ describe('ChatSdk', () => {
});
describe('buildMessageChain', () => {
it('should build a message chain with system role for most models', () => {
vi.mocked(getModelFamily).mockReturnValue('openai');
it('should build a message chain with system role for most models', async () => {
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
const messages = [
{role: 'user', content: 'Hello'}
@@ -169,9 +171,9 @@ describe('ChatSdk', () => {
model: 'gpt-4'
};
const result = ChatSdk.buildMessageChain(messages, opts as any);
const result = await ChatSdk.buildMessageChain(messages, opts as any);
expect(getModelFamily).toHaveBeenCalledWith('gpt-4');
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('gpt-4', undefined);
expect(Message.create).toHaveBeenCalledTimes(3);
expect(Message.create).toHaveBeenNthCalledWith(1, {
role: 'system',
@@ -187,8 +189,8 @@ describe('ChatSdk', () => {
});
});
it('should build a message chain with assistant role for o1, gemma, claude, or google models', () => {
vi.mocked(getModelFamily).mockReturnValue('claude');
it('should build a message chain with assistant role for o1, gemma, claude, or google models', async () => {
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('claude');
const messages = [
{ role: 'user', content: 'Hello' }
@@ -201,9 +203,9 @@ describe('ChatSdk', () => {
model: 'claude-3'
};
const result = ChatSdk.buildMessageChain(messages, opts as any);
const result = await ChatSdk.buildMessageChain(messages, opts as any);
expect(getModelFamily).toHaveBeenCalledWith('claude-3');
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('claude-3', undefined);
expect(Message.create).toHaveBeenCalledTimes(3);
expect(Message.create).toHaveBeenNthCalledWith(1, {
role: 'assistant',
@@ -211,8 +213,8 @@ describe('ChatSdk', () => {
});
});
it('should filter out messages with empty content', () => {
vi.mocked(getModelFamily).mockReturnValue('openai');
it('should filter out messages with empty content', async () => {
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
const messages = [
{ role: 'user', content: 'Hello' },
@@ -228,7 +230,7 @@ describe('ChatSdk', () => {
model: 'gpt-4'
};
const result = ChatSdk.buildMessageChain(messages, opts as any);
const result = await ChatSdk.buildMessageChain(messages, opts as any);
// 2 system/assistant messages + 2 user messages (Hello and World)
expect(Message.create).toHaveBeenCalledTimes(4);

View File

@@ -88,7 +88,7 @@ export class ChatSdk {
});
}
static buildMessageChain(
static async buildMessageChain(
messages: any[],
opts: {
systemPrompt: any;
@@ -98,7 +98,7 @@ export class ChatSdk {
env: Env;
},
) {
const modelFamily = ProviderRepository.getModelFamily(opts.model, opts.env)
const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env)
const messagesToSend = [];

View File

@@ -47,27 +47,27 @@ export class ProviderRepository {
this.#providers.push({
name: 'anthropic',
key: env.ANTHROPIC_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS['anthropic']
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
});
break;
case 'gemini':
this.#providers.push({
name: 'google',
key: env.GEMINI_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS['google']
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
});
break;
case 'cloudflare':
this.#providers.push({
name: 'cloudflare',
key: env.CLOUDFLARE_API_KEY,
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
})
default:
this.#providers.push({
name: detectedProvider,
key: env[envKeys[i]],
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider]
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
});
}
}

View File

@@ -30,7 +30,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
dataCallback: (data: any) => void,
) {
const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
const safeMessages = ChatSdk.buildMessageChain(param.messages, {
const safeMessages = await ChatSdk.buildMessageChain(param.messages, {
systemPrompt: param.systemPrompt,
model: param.model,
assistantPrompt,

View File

@@ -1,10 +1,11 @@
import { OpenAI } from "openai";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
import {ProviderRepository} from "./_ProviderRepository";
export class OllamaChatProvider extends BaseChatProvider {
getOpenAIClient(param: CommonProviderParams): OpenAI {
return new OpenAI({
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ,
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ProviderRepository.OPENAI_COMPAT_ENDPOINTS.ollama ,
apiKey: param.env.OLLAMA_API_KEY,
});
}

View File

@@ -227,24 +227,6 @@ describe('ChatService', () => {
Response.json = originalResponseJson;
localService.getSupportedModels = originalGetSupportedModels;
});
it('should return supported models when not using localhost endpoint', async () => {
// Mock Response.json
const originalResponseJson = Response.json;
Response.json = vi.fn().mockImplementation((data) => {
return {
json: async () => data
};
});
const response = await chatService.getSupportedModels();
const data = await response.json();
expect(data).toEqual(SUPPORTED_MODELS);
// Restore Response.json
Response.json = originalResponseJson;
});
});
describe('handleChatRequest', () => {