import { OpenAI } from "openai";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";

export class XaiChatProvider extends BaseChatProvider {
  getOpenAIClient(param: CommonProviderParams): OpenAI {
    return new OpenAI({
      baseURL: "https://api.x.ai/v1",
      apiKey: param.env.XAI_API_KEY,
    });
  }

  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
    // Special-case o1-style reasoning models, which pin temperature and use
    // max_completion_tokens (padded for reasoning tokens) instead of max_tokens.
    const isO1 = () =>
      param.model === "o1-preview" || param.model === "o1-mini";

    const tuningParams: Record<string, number> = {};
    const gpt4oTuningParams = {
      temperature: 0.75,
    };

    const getTuningParams = () => {
      if (isO1()) {
        tuningParams["temperature"] = 1;
        tuningParams["max_completion_tokens"] = (param.maxTokens as number) + 10000;
        return tuningParams;
      }
      return gpt4oTuningParams;
    };

    return {
      model: param.model,
      messages: safeMessages,
      stream: true,
      ...getTuningParams(),
    };
  }

  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
    // Forward every chunk to the callback; return true on the final chunk
    // (finish_reason === "stop") so the caller stops reading the stream.
    if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
      dataCallback({ type: "chat", data: chunk });
      return true; // Break the stream
    }
    dataCallback({ type: "chat", data: chunk });
    return false; // Continue the stream
  }
}

export class XaiChatSdk {
  private static provider = new XaiChatProvider();

  static async handleXaiStream(
    ctx: {
      openai: OpenAI;
      systemPrompt: any;
      preprocessedContext: any;
      maxTokens: number | undefined;
      messages: any;
      disableWebhookGeneration: boolean;
      model: any;
      env: Env;
    },
    dataCallback: (data: any) => any,
  ) {
    if (!ctx.messages?.length) {
      return new Response("No messages provided", { status: 400 });
    }

    return this.provider.handleStream(
      {
        systemPrompt: ctx.systemPrompt,
        preprocessedContext: ctx.preprocessedContext,
        maxTokens: ctx.maxTokens,
        messages: ctx.messages,
        model: ctx.model,
        env: ctx.env,
        disableWebhookGeneration: ctx.disableWebhookGeneration,
      },
      dataCallback,
    );
  }
}
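
/*
 * Usage sketch (illustrative only, not part of this module): wiring
 * XaiChatSdk into a Cloudflare Worker-style fetch handler. The request body
 * shape, the `Env` binding carrying XAI_API_KEY, and the console.log callback
 * are assumptions made for the example; `handleStream` is assumed to resolve
 * to a `Response` produced by the base provider.
 */
// export default {
//   async fetch(request: Request, env: Env): Promise<Response> {
//     const { messages, model, maxTokens } = (await request.json()) as any;
//     return XaiChatSdk.handleXaiStream(
//       {
//         openai: new OpenAI({ baseURL: "https://api.x.ai/v1", apiKey: env.XAI_API_KEY }),
//         systemPrompt: "You are a helpful assistant.",
//         preprocessedContext: undefined,
//         maxTokens,
//         messages,
//         disableWebhookGeneration: true,
//         model,
//         env,
//       },
//       (data) => console.log(data), // e.g. forward each chunk as an SSE event
//     );
//   },
// };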