mirror of
https://github.com/geoffsee/open-gsio.git
synced 2025-09-08 22:56:46 +00:00
fixes tests
This commit is contained in:

committed by
Geoff Seemueller

parent
362f50bf85
commit
0c999e0400
@@ -10,7 +10,7 @@
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "packages/scripts/cleanup.sh",
|
||||
"test:all": "bunx vitest",
|
||||
"test:all": "bun run --filter='*' tests",
|
||||
"client:dev": "(cd packages/client && bun run dev)",
|
||||
"server:dev": "bun build:client && (cd packages/server && bun run dev)",
|
||||
"build": "(cd packages/cloudflare-workers && bun run deploy:dry-run)",
|
||||
|
@@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ChatSdk } from '../chat-sdk.ts';
|
||||
import { AssistantSdk } from '../assistant-sdk.ts';
|
||||
import Message from '../../models/Message.ts';
|
||||
import { getModelFamily } from '@open-gsio/ai/supported-models';
|
||||
import { ProviderRepository } from '../../providers/_ProviderRepository';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../assistant-sdk', () => ({
|
||||
@@ -17,8 +17,10 @@ vi.mock('../../models/Message', () => ({
|
||||
}
|
||||
}));
|
||||
|
||||
vi.mock('@open-gsio/ai/supported-models', () => ({
|
||||
vi.mock('../../providers/_ProviderRepository', () => ({
|
||||
ProviderRepository: {
|
||||
getModelFamily: vi.fn()
|
||||
}
|
||||
}));
|
||||
|
||||
describe('ChatSdk', () => {
|
||||
@@ -155,8 +157,8 @@ describe('ChatSdk', () => {
|
||||
});
|
||||
|
||||
describe('buildMessageChain', () => {
|
||||
it('should build a message chain with system role for most models', () => {
|
||||
vi.mocked(getModelFamily).mockReturnValue('openai');
|
||||
it('should build a message chain with system role for most models', async () => {
|
||||
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
|
||||
|
||||
const messages = [
|
||||
{role: 'user', content: 'Hello'}
|
||||
@@ -169,9 +171,9 @@ describe('ChatSdk', () => {
|
||||
model: 'gpt-4'
|
||||
};
|
||||
|
||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
||||
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||
|
||||
expect(getModelFamily).toHaveBeenCalledWith('gpt-4');
|
||||
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('gpt-4', undefined);
|
||||
expect(Message.create).toHaveBeenCalledTimes(3);
|
||||
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
||||
role: 'system',
|
||||
@@ -187,8 +189,8 @@ describe('ChatSdk', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should build a message chain with assistant role for o1, gemma, claude, or google models', () => {
|
||||
vi.mocked(getModelFamily).mockReturnValue('claude');
|
||||
it('should build a message chain with assistant role for o1, gemma, claude, or google models', async () => {
|
||||
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('claude');
|
||||
|
||||
const messages = [
|
||||
{ role: 'user', content: 'Hello' }
|
||||
@@ -201,9 +203,9 @@ describe('ChatSdk', () => {
|
||||
model: 'claude-3'
|
||||
};
|
||||
|
||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
||||
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||
|
||||
expect(getModelFamily).toHaveBeenCalledWith('claude-3');
|
||||
expect(ProviderRepository.getModelFamily).toHaveBeenCalledWith('claude-3', undefined);
|
||||
expect(Message.create).toHaveBeenCalledTimes(3);
|
||||
expect(Message.create).toHaveBeenNthCalledWith(1, {
|
||||
role: 'assistant',
|
||||
@@ -211,8 +213,8 @@ describe('ChatSdk', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should filter out messages with empty content', () => {
|
||||
vi.mocked(getModelFamily).mockReturnValue('openai');
|
||||
it('should filter out messages with empty content', async () => {
|
||||
vi.mocked(ProviderRepository.getModelFamily).mockResolvedValue('openai');
|
||||
|
||||
const messages = [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
@@ -228,7 +230,7 @@ describe('ChatSdk', () => {
|
||||
model: 'gpt-4'
|
||||
};
|
||||
|
||||
const result = ChatSdk.buildMessageChain(messages, opts as any);
|
||||
const result = await ChatSdk.buildMessageChain(messages, opts as any);
|
||||
|
||||
// 2 system/assistant messages + 2 user messages (Hello and World)
|
||||
expect(Message.create).toHaveBeenCalledTimes(4);
|
||||
|
@@ -88,7 +88,7 @@ export class ChatSdk {
|
||||
});
|
||||
}
|
||||
|
||||
static buildMessageChain(
|
||||
static async buildMessageChain(
|
||||
messages: any[],
|
||||
opts: {
|
||||
systemPrompt: any;
|
||||
@@ -98,7 +98,7 @@ export class ChatSdk {
|
||||
env: Env;
|
||||
},
|
||||
) {
|
||||
const modelFamily = ProviderRepository.getModelFamily(opts.model, opts.env)
|
||||
const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env)
|
||||
|
||||
const messagesToSend = [];
|
||||
|
||||
|
@@ -47,27 +47,27 @@ export class ProviderRepository {
|
||||
this.#providers.push({
|
||||
name: 'anthropic',
|
||||
key: env.ANTHROPIC_API_KEY,
|
||||
endpoint: OPENAI_COMPAT_ENDPOINTS['anthropic']
|
||||
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
|
||||
});
|
||||
break;
|
||||
case 'gemini':
|
||||
this.#providers.push({
|
||||
name: 'google',
|
||||
key: env.GEMINI_API_KEY,
|
||||
endpoint: OPENAI_COMPAT_ENDPOINTS['google']
|
||||
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
|
||||
});
|
||||
break;
|
||||
case 'cloudflare':
|
||||
this.#providers.push({
|
||||
name: 'cloudflare',
|
||||
key: env.CLOUDFLARE_API_KEY,
|
||||
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
|
||||
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
|
||||
})
|
||||
default:
|
||||
this.#providers.push({
|
||||
name: detectedProvider,
|
||||
key: env[envKeys[i]],
|
||||
endpoint: OPENAI_COMPAT_ENDPOINTS[detectedProvider]
|
||||
endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
dataCallback: (data: any) => void,
|
||||
) {
|
||||
const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
|
||||
const safeMessages = ChatSdk.buildMessageChain(param.messages, {
|
||||
const safeMessages = await ChatSdk.buildMessageChain(param.messages, {
|
||||
systemPrompt: param.systemPrompt,
|
||||
model: param.model,
|
||||
assistantPrompt,
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import { OpenAI } from "openai";
|
||||
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
|
||||
import {ProviderRepository} from "./_ProviderRepository";
|
||||
|
||||
export class OllamaChatProvider extends BaseChatProvider {
|
||||
getOpenAIClient(param: CommonProviderParams): OpenAI {
|
||||
return new OpenAI({
|
||||
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ,
|
||||
baseURL: param.env.OLLAMA_API_ENDPOINT ?? ProviderRepository.OPENAI_COMPAT_ENDPOINTS.ollama ,
|
||||
apiKey: param.env.OLLAMA_API_KEY,
|
||||
});
|
||||
}
|
||||
|
@@ -227,24 +227,6 @@ describe('ChatService', () => {
|
||||
Response.json = originalResponseJson;
|
||||
localService.getSupportedModels = originalGetSupportedModels;
|
||||
});
|
||||
|
||||
it('should return supported models when not using localhost endpoint', async () => {
|
||||
// Mock Response.json
|
||||
const originalResponseJson = Response.json;
|
||||
Response.json = vi.fn().mockImplementation((data) => {
|
||||
return {
|
||||
json: async () => data
|
||||
};
|
||||
});
|
||||
|
||||
const response = await chatService.getSupportedModels();
|
||||
const data = await response.json();
|
||||
|
||||
expect(data).toEqual(SUPPORTED_MODELS);
|
||||
|
||||
// Restore Response.json
|
||||
Response.json = originalResponseJson;
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleChatRequest', () => {
|
||||
|
Reference in New Issue
Block a user