diff --git a/workers/site/providers/cerebras.ts b/workers/site/providers/cerebras.ts index 782c95a..2938f19 100644 --- a/workers/site/providers/cerebras.ts +++ b/workers/site/providers/cerebras.ts @@ -7,8 +7,47 @@ import { UnionStringArray, } from "mobx-state-tree"; import ChatSdk from "../lib/chat-sdk"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; + +export class CerebrasChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + baseURL: "https://api.cerebras.ai/v1", + apiKey: param.env.CEREBRAS_API_KEY, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + const llamaTuningParams = { + temperature: 0.86, + top_p: 0.98, + presence_penalty: 0.1, + frequency_penalty: 0.3, + max_tokens: param.maxTokens as number, + }; + + return { + model: param.model, + messages: safeMessages, + stream: true, + }; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + // Check if this is the final chunk + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; // Break the stream + } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} export class CerebrasSdk { + private static provider = new CerebrasChatProvider(); + static async handleCerebrasStream( param: { openai: OpenAI; @@ -28,73 +67,17 @@ export class CerebrasSdk { }, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const openai = new OpenAI({ - baseURL: "https://api.cerebras.ai/v1", - apiKey: param.env.CEREBRAS_API_KEY, - }); - - return CerebrasSdk.streamCerebrasResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - openai: openai, + messages: param.messages, + model: param.model, + env: param.env, + disableWebhookGeneration: param.disableWebhookGeneration, }, dataCallback, ); } - private static async streamCerebrasResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | unknown | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => void, - ) { - const tuningParams: Record = {}; - - const llamaTuningParams = { - temperature: 0.86, - top_p: 0.98, - presence_penalty: 0.1, - frequency_penalty: 0.3, - max_tokens: opts.maxTokens, - }; - - const getLlamaTuningParams = () => { - return llamaTuningParams; - }; - - const groqStream = await opts.openai.chat.completions.create({ - model: opts.model, - messages: messages, - - stream: true, - }); - - for await (const chunk of groqStream) { - dataCallback({ type: "chat", data: chunk }); - } - } } diff --git a/workers/site/providers/chat-stream-provider.ts b/workers/site/providers/chat-stream-provider.ts new file mode 100644 index 0000000..1d5b741 --- /dev/null +++ b/workers/site/providers/chat-stream-provider.ts @@ -0,0 +1,49 @@ +import { OpenAI } from "openai"; +import ChatSdk from "../lib/chat-sdk"; + +export interface CommonProviderParams { + openai?: OpenAI; // Optional for providers that use a custom client. 
+  systemPrompt: any;
+  preprocessedContext: any;
+  maxTokens: number | unknown | undefined;
+  messages: any;
+  model: string;
+  env: Env;
+  disableWebhookGeneration?: boolean;
+  // Additional fields can be added as needed
+}
+
+export interface ChatStreamProvider {
+  handleStream(
+    param: CommonProviderParams,
+    dataCallback: (data: any) => void,
+  ): Promise<void>;
+}
+
+export abstract class BaseChatProvider implements ChatStreamProvider {
+  abstract getOpenAIClient(param: CommonProviderParams): OpenAI;
+  abstract getStreamParams(param: CommonProviderParams, safeMessages: any[]): any;
+  abstract processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean>;
+
+  async handleStream(
+    param: CommonProviderParams,
+    dataCallback: (data: any) => void,
+  ) {
+    const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens });
+    const safeMessages = ChatSdk.buildMessageChain(param.messages, {
+      systemPrompt: param.systemPrompt,
+      model: param.model,
+      assistantPrompt,
+      toolResults: param.preprocessedContext,
+    });
+
+    const client = this.getOpenAIClient(param);
+    const streamParams = this.getStreamParams(param, safeMessages);
+    const stream = await client.chat.completions.create(streamParams);
+
+    for await (const chunk of stream) {
+      const shouldBreak = await this.processChunk(chunk, dataCallback);
+      if (shouldBreak) break;
+    }
+  }
+}
\ No newline at end of file
diff --git a/workers/site/providers/claude.ts b/workers/site/providers/claude.ts
index 1af06e0..7063723 100644
--- a/workers/site/providers/claude.ts
+++ b/workers/site/providers/claude.ts
@@ -8,43 +8,88 @@ import {
   UnionStringArray,
 } from "mobx-state-tree";
 import ChatSdk from "../lib/chat-sdk";
+import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";

-export class ClaudeChatSdk {
-  private static async streamClaudeResponse(
-    messages: any[],
-    param: {
-      model: string;
-      maxTokens: number | unknown | undefined;
-      anthropic: Anthropic;
-    },
-    dataCallback: (data: any) => void,
-  ) {
-    const claudeStream = await param.anthropic.messages.create({
-      stream: true,
-      model: param.model,
-      max_tokens: param.maxTokens,
-      messages: messages,
+export class ClaudeChatProvider extends BaseChatProvider {
+  private anthropic: Anthropic | null = null;
+
+  getOpenAIClient(param: CommonProviderParams): OpenAI {
+    // Claude doesn't use the OpenAI client directly, but we need to return something
+    // to satisfy the interface. The actual Anthropic client is created in getStreamParams.
+ return param.openai as OpenAI; + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + this.anthropic = new Anthropic({ + apiKey: param.env.ANTHROPIC_API_KEY, }); - for await (const chunk of claudeStream) { - if (chunk.type === "message_stop") { - dataCallback({ - type: "chat", - data: { - choices: [ - { - delta: { content: "" }, - logprobs: null, - finish_reason: "stop", - }, - ], - }, - }); - break; - } - dataCallback({ type: "chat", data: chunk }); + const claudeTuningParams = { + temperature: 0.7, + max_tokens: param.maxTokens as number, + }; + + return { + stream: true, + model: param.model, + messages: safeMessages, + ...claudeTuningParams + }; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + if (chunk.type === "message_stop") { + dataCallback({ + type: "chat", + data: { + choices: [ + { + delta: { content: "" }, + logprobs: null, + finish_reason: "stop", + }, + ], + }, + }); + return true; // Break the stream + } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } + + // Override the base handleStream method to use Anthropic client instead of OpenAI + async handleStream( + param: CommonProviderParams, + dataCallback: (data: any) => void, + ) { + const assistantPrompt = ChatSdk.buildAssistantPrompt({ maxTokens: param.maxTokens }); + const safeMessages = ChatSdk.buildMessageChain(param.messages, { + systemPrompt: param.systemPrompt, + model: param.model, + assistantPrompt, + toolResults: param.preprocessedContext, + }); + + const streamParams = this.getStreamParams(param, safeMessages); + + if (!this.anthropic) { + throw new Error("Anthropic client not initialized"); + } + + const stream = await this.anthropic.messages.create(streamParams); + + for await (const chunk of stream) { + const shouldBreak = await this.processChunk(chunk, dataCallback); + if (shouldBreak) break; } } +} + +// Legacy class for backward compatibility +export class ClaudeChatSdk { + private static provider = new ClaudeChatProvider(); + static async handleClaudeStream( param: { openai: OpenAI; @@ -63,36 +108,15 @@ export class ClaudeChatSdk { }, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const anthropic = new Anthropic({ - apiKey: env.ANTHROPIC_API_KEY, - }); - - return ClaudeChatSdk.streamClaudeResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + openai: param.openai, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - anthropic: anthropic, + messages: param.messages, + model: param.model, + env: param.env, }, dataCallback, ); diff --git a/workers/site/providers/cloudflareAi.ts b/workers/site/providers/cloudflareAi.ts index b19c8e2..a02e0b0 100644 --- a/workers/site/providers/cloudflareAi.ts +++ b/workers/site/providers/cloudflareAi.ts @@ -7,8 +7,122 @@ import { UnionStringArray, } from "mobx-state-tree"; import ChatSdk from "../lib/chat-sdk"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; + +export class CloudflareAiChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + const cfAiURL = 
`https://api.cloudflare.com/client/v4/accounts/${param.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`; + + return new OpenAI({ + apiKey: param.env.CLOUDFLARE_API_KEY, + baseURL: cfAiURL, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + const generationParams: Record = { + model: this.getModelWithPrefix(param.model), + messages: safeMessages, + stream: true, + }; + + // Set max_tokens based on model + if (this.getModelPrefix(param.model) === "@cf/meta") { + generationParams["max_tokens"] = 4096; + } + + if (this.getModelPrefix(param.model) === "@hf/mistral") { + generationParams["max_tokens"] = 4096; + } + + if (param.model.toLowerCase().includes("hermes-2-pro-mistral-7b")) { + generationParams["max_tokens"] = 1000; + } + + if (param.model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq")) { + generationParams["max_tokens"] = 1000; + } + + if (param.model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq")) { + generationParams["max_tokens"] = 590; + } + + if (param.model.toLowerCase().includes("deepseek-math-7b-instruct")) { + generationParams["max_tokens"] = 512; + } + + if (param.model.toLowerCase().includes("neural-chat-7b-v3-1-awq")) { + generationParams["max_tokens"] = 590; + } + + if (param.model.toLowerCase().includes("openchat-3.5-0106")) { + generationParams["max_tokens"] = 2000; + } + + return generationParams; + } + + private getModelPrefix(model: string): string { + let modelPrefix = `@cf/meta`; + + if (model.toLowerCase().includes("llama")) { + modelPrefix = `@cf/meta`; + } + + if (model.toLowerCase().includes("hermes-2-pro-mistral-7b")) { + modelPrefix = `@hf/nousresearch`; + } + + if (model.toLowerCase().includes("mistral-7b-instruct")) { + modelPrefix = `@hf/mistral`; + } + + if (model.toLowerCase().includes("gemma")) { + modelPrefix = `@cf/google`; + } + + if (model.toLowerCase().includes("deepseek")) { + modelPrefix = `@cf/deepseek-ai`; + } + + if (model.toLowerCase().includes("openchat-3.5-0106")) { + modelPrefix = `@cf/openchat`; + } + + const isNueralChat = model + .toLowerCase() + .includes("neural-chat-7b-v3-1-awq"); + if ( + isNueralChat || + model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq") || + model.toLowerCase().includes("zephyr-7b-beta-awq") || + model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq") + ) { + modelPrefix = `@hf/thebloke`; + } + + return modelPrefix; + } + + private getModelWithPrefix(model: string): string { + return `${this.getModelPrefix(model)}/${model}`; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + // Check if this is the final chunk + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; // Break the stream + } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} export class CloudflareAISdk { + private static provider = new CloudflareAiChatProvider(); + static async handleCloudflareAIStream( param: { openai: OpenAI; @@ -27,148 +141,16 @@ export class CloudflareAISdk { }, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const cfAiURL = 
`https://api.cloudflare.com/client/v4/accounts/${env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`; - - console.log({ cfAiURL }); - const openai = new OpenAI({ - apiKey: env.CLOUDFLARE_API_KEY, - baseURL: cfAiURL, - }); - - return CloudflareAISdk.streamCloudflareAIResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - openai: openai, + messages: param.messages, + model: param.model, + env: param.env, }, dataCallback, ); } - private static async streamCloudflareAIResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | unknown | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => void, - ) { - const tuningParams: Record = {}; - - const llamaTuningParams = { - temperature: 0.86, - top_p: 0.98, - presence_penalty: 0.1, - frequency_penalty: 0.3, - max_tokens: opts.maxTokens, - }; - - const getLlamaTuningParams = () => { - return llamaTuningParams; - }; - - let modelPrefix = `@cf/meta`; - - if (opts.model.toLowerCase().includes("llama")) { - modelPrefix = `@cf/meta`; - } - - if (opts.model.toLowerCase().includes("hermes-2-pro-mistral-7b")) { - modelPrefix = `@hf/nousresearch`; - } - - if (opts.model.toLowerCase().includes("mistral-7b-instruct")) { - modelPrefix = `@hf/mistral`; - } - - if (opts.model.toLowerCase().includes("gemma")) { - modelPrefix = `@cf/google`; - } - - if (opts.model.toLowerCase().includes("deepseek")) { - modelPrefix = `@cf/deepseek-ai`; - } - - if (opts.model.toLowerCase().includes("openchat-3.5-0106")) { - modelPrefix = `@cf/openchat`; - } - - const isNueralChat = opts.model - .toLowerCase() - .includes("neural-chat-7b-v3-1-awq"); - if ( - isNueralChat || - opts.model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq") || - opts.model.toLowerCase().includes("zephyr-7b-beta-awq") || - opts.model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq") - ) { - modelPrefix = `@hf/thebloke`; - } - - const generationParams: Record = { - model: `${modelPrefix}/${opts.model}`, - messages: messages, - stream: true, - }; - - if (modelPrefix === "@cf/meta") { - generationParams["max_tokens"] = 4096; - } - - if (modelPrefix === "@hf/mistral") { - generationParams["max_tokens"] = 4096; - } - - if (opts.model.toLowerCase().includes("hermes-2-pro-mistral-7b")) { - generationParams["max_tokens"] = 1000; - } - - if (opts.model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq")) { - generationParams["max_tokens"] = 1000; - } - - if (opts.model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq")) { - generationParams["max_tokens"] = 590; - } - - if (opts.model.toLowerCase().includes("deepseek-math-7b-instruct")) { - generationParams["max_tokens"] = 512; - } - - if (opts.model.toLowerCase().includes("neural-chat-7b-v3-1-awq")) { - generationParams["max_tokens"] = 590; - } - - if (opts.model.toLowerCase().includes("openchat-3.5-0106")) { - generationParams["max_tokens"] = 2000; - } - - const cloudflareAiStream = await opts.openai.chat.completions.create({ - ...generationParams, - }); - - for await (const chunk of cloudflareAiStream) { - dataCallback({ type: "chat", data: chunk }); - } - } } diff --git a/workers/site/providers/fireworks.ts b/workers/site/providers/fireworks.ts index 76a186f..a5a21a8 100644 --- a/workers/site/providers/fireworks.ts +++ b/workers/site/providers/fireworks.ts @@ -10,33 +10,44 @@ import { } from "mobx-state-tree"; import Message from "../models/Message"; import ChatSdk from 
"../lib/chat-sdk"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; -export class FireworksAiChatSdk { - private static async streamFireworksResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | unknown | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => void, - ) { +export class FireworksAiChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + apiKey: param.env.FIREWORKS_API_KEY, + baseURL: "https://api.fireworks.ai/inference/v1", + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { let modelPrefix = "accounts/fireworks/models/"; - if (opts.model.toLowerCase().includes("yi-")) { + if (param.model.toLowerCase().includes("yi-")) { modelPrefix = "accounts/yi-01-ai/models/"; } - const fireworksStream = await opts.openai.chat.completions.create({ - model: `${modelPrefix}${opts.model}`, - messages: messages, + return { + model: `${modelPrefix}${param.model}`, + messages: safeMessages, stream: true, - }); - - for await (const chunk of fireworksStream) { - dataCallback({ type: "chat", data: chunk }); - } + }; } + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + // Check if this is the final chunk + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; // Break the stream + } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} + +export class FireworksAiChatSdk { + private static provider = new FireworksAiChatProvider(); + static async handleFireworksStream( param: { openai: OpenAI; @@ -55,36 +66,14 @@ export class FireworksAiChatSdk { }, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const fireworksOpenAIClient = new OpenAI({ - apiKey: param.env.FIREWORKS_API_KEY, - baseURL: "https://api.fireworks.ai/inference/v1", - }); - return FireworksAiChatSdk.streamFireworksResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - openai: fireworksOpenAIClient, + messages: param.messages, + model: param.model, + env: param.env, }, dataCallback, ); diff --git a/workers/site/providers/google.ts b/workers/site/providers/google.ts index b7c3c98..f306bd8 100644 --- a/workers/site/providers/google.ts +++ b/workers/site/providers/google.ts @@ -1,97 +1,75 @@ import { OpenAI } from "openai"; import ChatSdk from "../lib/chat-sdk"; import { StreamParams } from "../services/ChatService"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; + +export class GoogleChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + baseURL: "https://generativelanguage.googleapis.com/v1beta/openai", + apiKey: param.env.GEMINI_API_KEY, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + return { + model: param.model, + messages: safeMessages, + stream: true, + }; + } + + async processChunk(chunk: any, 
dataCallback: (data: any) => void): Promise { + if (chunk.choices?.[0]?.finish_reason === "stop") { + dataCallback({ + type: "chat", + data: { + choices: [ + { + delta: { content: chunk.choices[0].delta.content || "" }, + finish_reason: "stop", + index: chunk.choices[0].index, + }, + ], + }, + }); + return true; // Break the stream + } else { + dataCallback({ + type: "chat", + data: { + choices: [ + { + delta: { content: chunk.choices?.[0]?.delta?.content || "" }, + finish_reason: null, + index: chunk.choices?.[0]?.index || 0, + }, + ], + }, + }); + return false; // Continue the stream + } + } +} export class GoogleChatSdk { + private static provider = new GoogleChatProvider(); + static async handleGoogleStream( param: StreamParams, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const openai = new OpenAI({ - baseURL: "https://generativelanguage.googleapis.com/v1beta/openai", - apiKey: param.env.GEMINI_API_KEY, - }); - - return GoogleChatSdk.streamGoogleResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - openai: openai, + messages: param.messages, + model: param.model, + env: param.env, + disableWebhookGeneration: param.disableWebhookGeneration, }, dataCallback, ); } - private static async streamGoogleResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | unknown | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => void, - ) { - const chatReq = JSON.stringify({ - model: opts.model, - messages: messages, - stream: true, - }); - - const googleStream = await opts.openai.chat.completions.create( - JSON.parse(chatReq), - ); - - for await (const chunk of googleStream) { - console.log(JSON.stringify(chunk)); - - if (chunk.choices?.[0]?.finishReason === "stop") { - dataCallback({ - type: "chat", - data: { - choices: [ - { - delta: { content: chunk.choices[0].delta.content || "" }, - finish_reason: "stop", - index: chunk.choices[0].index, - }, - ], - }, - }); - break; - } else { - dataCallback({ - type: "chat", - data: { - choices: [ - { - delta: { content: chunk.choices?.[0]?.delta?.content || "" }, - finish_reason: null, - index: chunk.choices?.[0]?.index || 0, - }, - ], - }, - }); - } - } - } } diff --git a/workers/site/providers/groq.ts b/workers/site/providers/groq.ts index 28b2daa..3d9b73f 100644 --- a/workers/site/providers/groq.ts +++ b/workers/site/providers/groq.ts @@ -6,9 +6,49 @@ import { ModelSnapshotType2, UnionStringArray, } from "mobx-state-tree"; -import ChatSdk from "../lib/chat-sdk"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; +export class GroqChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + baseURL: "https://api.groq.com/openai/v1", + apiKey: param.env.GROQ_API_KEY, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + const llamaTuningParams = { + temperature: 0.86, + top_p: 0.98, + presence_penalty: 0.1, + frequency_penalty: 0.3, + max_tokens: param.maxTokens as number, + }; + + return { + model: param.model, + 
messages: safeMessages, + stream: true, + ...llamaTuningParams + }; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + // Check if this is the final chunk + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; // Break the stream + } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} + +// Legacy class for backward compatibility export class GroqChatSdk { + private static provider = new GroqChatProvider(); + static async handleGroqStream( param: { openai: OpenAI; @@ -27,73 +67,16 @@ export class GroqChatSdk { }, dataCallback: (data) => void, ) { - const { - preprocessedContext, - messages, - env, - maxTokens, - systemPrompt, - model, - } = param; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const openai = new OpenAI({ - baseURL: "https://api.groq.com/openai/v1", - apiKey: param.env.GROQ_API_KEY, - }); - - return GroqChatSdk.streamGroqResponse( - safeMessages, + return this.provider.handleStream( { - model: param.model, + systemPrompt: param.systemPrompt, + preprocessedContext: param.preprocessedContext, maxTokens: param.maxTokens, - openai: openai, + messages: param.messages, + model: param.model, + env: param.env, }, dataCallback, ); } - private static async streamGroqResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | unknown | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => void, - ) { - const tuningParams: Record = {}; - - const llamaTuningParams = { - temperature: 0.86, - top_p: 0.98, - presence_penalty: 0.1, - frequency_penalty: 0.3, - max_tokens: opts.maxTokens, - }; - - const getLlamaTuningParams = () => { - return llamaTuningParams; - }; - - const groqStream = await opts.openai.chat.completions.create({ - model: opts.model, - messages: messages, - frequency_penalty: 2, - stream: true, - temperature: 0.78, - }); - - for await (const chunk of groqStream) { - dataCallback({ type: "chat", data: chunk }); - } - } } diff --git a/workers/site/providers/openai.ts b/workers/site/providers/openai.ts index 42008f8..eb1fdc0 100644 --- a/workers/site/providers/openai.ts +++ b/workers/site/providers/openai.ts @@ -1,65 +1,16 @@ import { OpenAI } from "openai"; -import ChatSdk from "../lib/chat-sdk"; import { Utils } from "../lib/utils"; -import {ChatCompletionCreateParamsStreaming} from "openai/resources/chat/completions/completions"; +import { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions/completions"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; -export class OpenAiChatSdk { - static async handleOpenAiStream( - ctx: { - openai: OpenAI; - systemPrompt: any; - preprocessedContext: any; - maxTokens: unknown | number | undefined; - messages: any; - model: any; - }, - dataCallback: (data: any) => any, - ) { - const { - openai, - systemPrompt, - maxTokens, - messages, - model, - preprocessedContext, - } = ctx; - - if (!messages?.length) { - return new Response("No messages provided", { status: 400 }); - } - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: 
preprocessedContext, - }); - - return OpenAiChatSdk.streamOpenAiResponse( - safeMessages, - { - model, - maxTokens: maxTokens as number, - openai: openai, - }, - dataCallback, - ); +export class OpenAiChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return param.openai as OpenAI; } - private static async streamOpenAiResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => any, - ) { + getStreamParams(param: CommonProviderParams, safeMessages: any[]): ChatCompletionCreateParamsStreaming { const isO1 = () => { - if (opts.model === "o1-preview" || opts.model === "o1-mini") { + if (param.model === "o1-preview" || param.model === "o1-mini") { return true; } }; @@ -71,55 +22,93 @@ export class OpenAiChatSdk { top_p: 0.98, presence_penalty: 0.1, frequency_penalty: 0.3, - max_tokens: opts.maxTokens, + max_tokens: param.maxTokens as number, }; const getTuningParams = () => { if (isO1()) { tuningParams["temperature"] = 1; - tuningParams["max_completion_tokens"] = opts.maxTokens + 10000; + tuningParams["max_completion_tokens"] = (param.maxTokens as number) + 10000; return tuningParams; } return gpt4oTuningParams; }; let completionRequest: ChatCompletionCreateParamsStreaming = { - model: opts.model, + model: param.model, stream: true, - messages: messages + messages: safeMessages }; - const isLocal = opts.openai.baseURL.includes("localhost"); - + const client = this.getOpenAIClient(param); + const isLocal = client.baseURL.includes("localhost"); if(isLocal) { - completionRequest["messages"] = Utils.normalizeWithBlanks(messages) - completionRequest["stream_options"] = { + completionRequest["messages"] = Utils.normalizeWithBlanks(safeMessages); + completionRequest["stream_options"] = { include_usage: true - } + }; } else { - completionRequest = {...completionRequest, ...getTuningParams()} + completionRequest = {...completionRequest, ...getTuningParams()}; } - const openAIStream = await opts.openai.chat.completions.create(completionRequest); + return completionRequest; + } - for await (const chunk of openAIStream) { - if (isLocal && chunk.usage) { - dataCallback({ - type: "chat", - data: { - choices: [ - { - delta: { content: "" }, - logprobs: null, - finish_reason: "stop", - }, - ], - }, - }); - break; - } - dataCallback({ type: "chat", data: chunk }); + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + const isLocal = chunk.usage !== undefined; + + if (isLocal && chunk.usage) { + dataCallback({ + type: "chat", + data: { + choices: [ + { + delta: { content: "" }, + logprobs: null, + finish_reason: "stop", + }, + ], + }, + }); + return true; // Break the stream } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} + +// Legacy class for backward compatibility +export class OpenAiChatSdk { + private static provider = new OpenAiChatProvider(); + + static async handleOpenAiStream( + ctx: { + openai: OpenAI; + systemPrompt: any; + preprocessedContext: any; + maxTokens: unknown | number | undefined; + messages: any; + model: any; + }, + dataCallback: (data: any) => any, + ) { + if (!ctx.messages?.length) { + return new Response("No messages provided", { status: 400 }); + } + + return this.provider.handleStream( + { + openai: ctx.openai, + systemPrompt: ctx.systemPrompt, + preprocessedContext: ctx.preprocessedContext, + maxTokens: ctx.maxTokens, + messages: ctx.messages, + model: ctx.model, + env: {} as 
Env, // This is not used in OpenAI provider + }, + dataCallback, + ); } } diff --git a/workers/site/providers/xai.ts b/workers/site/providers/xai.ts index 8029377..73825cd 100644 --- a/workers/site/providers/xai.ts +++ b/workers/site/providers/xai.ts @@ -1,86 +1,18 @@ import { OpenAI } from "openai"; import ChatSdk from "../lib/chat-sdk"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider"; -export class XaiChatSdk { - static async handleXaiStream( - ctx: { - openai: OpenAI; - systemPrompt: any; - preprocessedContext: any; - maxTokens: unknown | number | undefined; - messages: any; - disableWebhookGeneration: boolean; - model: any; - env: Env; - }, - dataCallback: (data: any) => any, - ) { - const { - openai, - systemPrompt, - maxTokens, - messages, - env, - model, - preprocessedContext, - } = ctx; - - if (!messages?.length) { - return new Response("No messages provided", { status: 400 }); - } - - const getMaxTokens = async (mt) => { - if (mt) { - return await ChatSdk.calculateMaxTokens( - JSON.parse(JSON.stringify(messages)), - { - env, - maxTokens: mt, - }, - ); - } else { - return undefined; - } - }; - - const assistantPrompt = ChatSdk.buildAssistantPrompt({ - maxTokens: maxTokens, - }); - - const safeMessages = ChatSdk.buildMessageChain(messages, { - systemPrompt: systemPrompt, - model, - assistantPrompt, - toolResults: preprocessedContext, - }); - - const xAiClient = new OpenAI({ +export class XaiChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ baseURL: "https://api.x.ai/v1", - apiKey: env.XAI_API_KEY, + apiKey: param.env.XAI_API_KEY, }); - - return XaiChatSdk.streamOpenAiResponse( - safeMessages, - { - model, - maxTokens: maxTokens as number, - openai: xAiClient, - }, - dataCallback, - ); } - private static async streamOpenAiResponse( - messages: any[], - opts: { - model: string; - maxTokens: number | undefined; - openai: OpenAI; - }, - dataCallback: (data: any) => any, - ) { + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { const isO1 = () => { - if (opts.model === "o1-preview" || opts.model === "o1-mini") { + if (param.model === "o1-preview" || param.model === "o1-mini") { return true; } }; @@ -94,21 +26,63 @@ export class XaiChatSdk { const getTuningParams = () => { if (isO1()) { tuningParams["temperature"] = 1; - tuningParams["max_completion_tokens"] = opts.maxTokens + 10000; + tuningParams["max_completion_tokens"] = (param.maxTokens as number) + 10000; return tuningParams; } return gpt4oTuningParams; }; - const xAIStream = await opts.openai.chat.completions.create({ - model: opts.model, - messages: messages, + return { + model: param.model, + messages: safeMessages, stream: true, ...getTuningParams(), - }); + }; + } - for await (const chunk of xAIStream) { + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + // Check if this is the final chunk + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { dataCallback({ type: "chat", data: chunk }); + return true; // Break the stream } + + dataCallback({ type: "chat", data: chunk }); + return false; // Continue the stream + } +} + +export class XaiChatSdk { + private static provider = new XaiChatProvider(); + + static async handleXaiStream( + ctx: { + openai: OpenAI; + systemPrompt: any; + preprocessedContext: any; + maxTokens: unknown | number | undefined; + messages: any; + disableWebhookGeneration: boolean; + model: any; + env: Env; + }, + dataCallback: (data: any) => 
any, + ) { + if (!ctx.messages?.length) { + return new Response("No messages provided", { status: 400 }); + } + + return this.provider.handleStream( + { + systemPrompt: ctx.systemPrompt, + preprocessedContext: ctx.preprocessedContext, + maxTokens: ctx.maxTokens, + messages: ctx.messages, + model: ctx.model, + env: ctx.env, + disableWebhookGeneration: ctx.disableWebhookGeneration, + }, + dataCallback, + ); } }
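Every backend in this change reduces to the same three hooks on BaseChatProvider: build a client, shape the request, and decide per chunk whether the stream is finished. As a rough sketch of how a further OpenAI-compatible backend could plug into that contract, the code below mirrors the Groq and Cerebras providers introduced above; ExampleChatProvider, the EXAMPLE_API_KEY binding, the base URL, the model id, and streamExample are hypothetical placeholders, not part of this diff.

// Hypothetical example, not part of the change above: the minimal surface
// a new OpenAI-compatible backend would need once it extends BaseChatProvider.
import { OpenAI } from "openai";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";

export class ExampleChatProvider extends BaseChatProvider {
  // Build the SDK client for this backend; the base URL and the
  // EXAMPLE_API_KEY binding are placeholders for illustration.
  getOpenAIClient(param: CommonProviderParams): OpenAI {
    return new OpenAI({
      baseURL: "https://api.example.com/v1", // placeholder endpoint
      apiKey: (param.env as Record<string, any>).EXAMPLE_API_KEY, // hypothetical binding; declare it on Env in a real provider
    });
  }

  // Translate the common params into this backend's request shape.
  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
    return {
      model: param.model,
      messages: safeMessages,
      stream: true,
      max_tokens: param.maxTokens as number,
    };
  }

  // Forward each chunk and stop once the backend reports completion,
  // following the same convention as the providers above.
  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
    dataCallback({ type: "chat", data: chunk });
    return chunk.choices?.[0]?.finish_reason === "stop";
  }
}

// Hypothetical caller: streams one completion and logs each {type: "chat", data: chunk} envelope.
export async function streamExample(env: Env): Promise<void> {
  const provider = new ExampleChatProvider();
  await provider.handleStream(
    {
      systemPrompt: "You are a helpful assistant.",
      preprocessedContext: undefined,
      maxTokens: 1024,
      messages: [{ role: "user", content: "Hello" }],
      model: "example-model-id", // placeholder model name
      env,
    },
    (data) => console.log(JSON.stringify(data)),
  );
}

A static wrapper class in the style of CerebrasSdk or GroqChatSdk could then delegate to a shared instance of such a provider, exactly as the legacy classes above do.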