Files
open-gsio/workers/site/providers/xai.ts
2025-06-01 08:12:40 -04:00

89 lines
2.3 KiB
TypeScript

import { OpenAI } from "openai";
import ChatSdk from "../lib/chat-sdk";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
/**
 * Chat provider that talks to the xAI endpoint through the OpenAI SDK client.
 *
 * This class supplies the client, the streaming request parameters and the
 * per-chunk handling; anything else (e.g. `handleStream`) is inherited from
 * `BaseChatProvider`.
 */
export class XaiChatProvider extends BaseChatProvider {
  /**
   * Builds an OpenAI SDK client pointed at the xAI base URL.
   *
   * @param param - Provider parameters; `param.env.XAI_API_KEY` is used as the API key.
   * @returns A client configured for `https://api.x.ai/v1`.
   */
  getOpenAIClient(param: CommonProviderParams): OpenAI {
    return new OpenAI({
      baseURL: "https://api.x.ai/v1",
      apiKey: param.env.XAI_API_KEY,
    });
  }

  /**
   * Assembles the streaming chat-completion request body.
   *
   * @param param - Provider parameters; `model` and `maxTokens` are read here.
   * @param safeMessages - Messages to send, passed through unchanged.
   * @returns An object with `model`, `messages`, `stream: true` and the tuning
   *   parameters selected for the model.
   */
  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
    // NOTE(review): these are OpenAI model names; the branch looks carried over
    // from another provider — confirm whether it is still needed for xAI.
    const isO1 = param.model === "o1-preview" || param.model === "o1-mini";

    const tuningParams: Record<string, any> = {};
    if (isO1) {
      tuningParams["temperature"] = 1;
      // Only add the completion budget when maxTokens is a real number.
      // Blindly adding 10000 to undefined yields NaN (serialized as null), and
      // adding it to a string yields string concatenation.
      const maxTokens: unknown = param.maxTokens;
      if (typeof maxTokens === "number" && Number.isFinite(maxTokens)) {
        tuningParams["max_completion_tokens"] = maxTokens + 10000;
      }
    } else {
      tuningParams["temperature"] = 0.75;
    }

    return {
      model: param.model,
      messages: safeMessages,
      stream: true,
      ...tuningParams,
    };
  }

  /**
   * Forwards one streamed chunk to the consumer and reports whether the
   * stream should end.
   *
   * @param chunk - A streamed chunk; `choices[0].finish_reason` is inspected when present.
   * @param dataCallback - Receives `{ type: "chat", data: chunk }` for every chunk,
   *   including the final one.
   * @returns `true` when the first choice's `finish_reason` is `"stop"`, otherwise `false`.
   */
  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
    // Evaluate before emitting so a malformed (null/undefined) chunk still
    // fails before anything is forwarded.
    const isFinalChunk = chunk.choices?.[0]?.finish_reason === "stop";
    dataCallback({ type: "chat", data: chunk });
    return isFinalChunk;
  }
}
/**
 * Static entry point for streaming a chat through a shared `XaiChatProvider`.
 */
export class XaiChatSdk {
  // One provider instance reused by every call.
  private static provider = new XaiChatProvider();

  /**
   * Validates the request and delegates streaming to the provider.
   *
   * @param ctx - Request context. Every field except `openai` is forwarded to
   *   the provider's `handleStream`.
   * @param dataCallback - Passed through to `handleStream` unchanged.
   * @returns A 400 `Response` with the body "No messages provided" when
   *   `ctx.messages` is missing or empty; otherwise whatever `handleStream` returns.
   */
  static async handleXaiStream(
    ctx: {
      openai: OpenAI;
      systemPrompt: any;
      preprocessedContext: any;
      maxTokens: unknown | number | undefined;
      messages: any;
      disableWebhookGeneration: boolean;
      model: any;
      env: Env;
    },
    dataCallback: (data: any) => any,
  ) {
    const hasMessages = Boolean(ctx.messages?.length);
    if (!hasMessages) {
      return new Response("No messages provided", { status: 400 });
    }

    const {
      systemPrompt,
      preprocessedContext,
      maxTokens,
      messages,
      model,
      env,
      disableWebhookGeneration,
    } = ctx;

    const streamParams = {
      systemPrompt,
      preprocessedContext,
      maxTokens,
      messages,
      model,
      env,
      disableWebhookGeneration,
    };

    return this.provider.handleStream(streamParams, dataCallback);
  }
}