creates a common abstraction for providers

This commit is contained in:
geoffsee
2025-05-31 20:21:22 -04:00
committed by Geoff Seemueller
parent 5a7691a9af
commit 87e083682c
9 changed files with 569 additions and 618 deletions


@@ -1,86 +1,18 @@
 import { OpenAI } from "openai";
-import ChatSdk from "../lib/chat-sdk";
+import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
-export class XaiChatSdk {
-  static async handleXaiStream(
-    ctx: {
-      openai: OpenAI;
-      systemPrompt: any;
-      preprocessedContext: any;
-      maxTokens: unknown | number | undefined;
-      messages: any;
-      disableWebhookGeneration: boolean;
-      model: any;
-      env: Env;
-    },
-    dataCallback: (data: any) => any,
-  ) {
-    const {
-      openai,
-      systemPrompt,
-      maxTokens,
-      messages,
-      env,
-      model,
-      preprocessedContext,
-    } = ctx;
-    if (!messages?.length) {
-      return new Response("No messages provided", { status: 400 });
-    }
-    const getMaxTokens = async (mt) => {
-      if (mt) {
-        return await ChatSdk.calculateMaxTokens(
-          JSON.parse(JSON.stringify(messages)),
-          {
-            env,
-            maxTokens: mt,
-          },
-        );
-      } else {
-        return undefined;
-      }
-    };
-    const assistantPrompt = ChatSdk.buildAssistantPrompt({
-      maxTokens: maxTokens,
-    });
-    const safeMessages = ChatSdk.buildMessageChain(messages, {
-      systemPrompt: systemPrompt,
-      model,
-      assistantPrompt,
-      toolResults: preprocessedContext,
-    });
-    const xAiClient = new OpenAI({
+export class XaiChatProvider extends BaseChatProvider {
+  getOpenAIClient(param: CommonProviderParams): OpenAI {
+    return new OpenAI({
       baseURL: "https://api.x.ai/v1",
-      apiKey: env.XAI_API_KEY,
+      apiKey: param.env.XAI_API_KEY,
     });
-    return XaiChatSdk.streamOpenAiResponse(
-      safeMessages,
-      {
-        model,
-        maxTokens: maxTokens as number,
-        openai: xAiClient,
-      },
-      dataCallback,
-    );
   }
-  private static async streamOpenAiResponse(
-    messages: any[],
-    opts: {
-      model: string;
-      maxTokens: number | undefined;
-      openai: OpenAI;
-    },
-    dataCallback: (data: any) => any,
-  ) {
+  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
     const isO1 = () => {
-      if (opts.model === "o1-preview" || opts.model === "o1-mini") {
+      if (param.model === "o1-preview" || param.model === "o1-mini") {
         return true;
       }
     };
@@ -94,21 +26,63 @@ export class XaiChatSdk {
     const getTuningParams = () => {
       if (isO1()) {
         tuningParams["temperature"] = 1;
-        tuningParams["max_completion_tokens"] = opts.maxTokens + 10000;
+        tuningParams["max_completion_tokens"] = (param.maxTokens as number) + 10000;
         return tuningParams;
       }
       return gpt4oTuningParams;
     };
-    const xAIStream = await opts.openai.chat.completions.create({
-      model: opts.model,
-      messages: messages,
+    return {
+      model: param.model,
+      messages: safeMessages,
       stream: true,
       ...getTuningParams(),
-    });
-    for await (const chunk of xAIStream) {
+    };
+  }
+
+  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
+    // Check if this is the final chunk
+    if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
+      dataCallback({ type: "chat", data: chunk });
+      return true; // Break the stream
+    }
+    dataCallback({ type: "chat", data: chunk });
+    return false; // Continue the stream
+  }
+}
+
+export class XaiChatSdk {
+  private static provider = new XaiChatProvider();
+
+  static async handleXaiStream(
+    ctx: {
+      openai: OpenAI;
+      systemPrompt: any;
+      preprocessedContext: any;
+      maxTokens: unknown | number | undefined;
+      messages: any;
+      disableWebhookGeneration: boolean;
+      model: any;
+      env: Env;
+    },
+    dataCallback: (data: any) => any,
+  ) {
+    if (!ctx.messages?.length) {
+      return new Response("No messages provided", { status: 400 });
+    }
+
+    return this.provider.handleStream(
+      {
+        systemPrompt: ctx.systemPrompt,
+        preprocessedContext: ctx.preprocessedContext,
+        maxTokens: ctx.maxTokens,
+        messages: ctx.messages,
+        model: ctx.model,
+        env: ctx.env,
+        disableWebhookGeneration: ctx.disableWebhookGeneration,
+      },
+      dataCallback,
+    );
+  }
+}
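
The shared contract this commit introduces — BaseChatProvider and CommonProviderParams from ./chat-stream-provider — is among the nine changed files but is not shown in this diff. Read purely from how XaiChatProvider uses it, the abstraction presumably looks something like the sketch below; the field types and the handleStream body are assumptions inferred from usage, not the committed code.

// Hypothetical sketch of ./chat-stream-provider, inferred from usage in this diff.
// Field types and the handleStream body are assumptions, not the committed code.
import { OpenAI } from "openai";

export interface CommonProviderParams {
  systemPrompt: any;
  preprocessedContext: any;
  maxTokens: unknown | number | undefined;
  messages: any;
  model: any;
  env: Env;
  disableWebhookGeneration?: boolean;
}

export abstract class BaseChatProvider {
  // Provider-specific client: base URL and API key differ per vendor.
  abstract getOpenAIClient(param: CommonProviderParams): OpenAI;

  // Provider-specific request body for chat.completions.create.
  abstract getStreamParams(param: CommonProviderParams, safeMessages: any[]): any;

  // Returns true to stop consuming the stream, false to continue.
  abstract processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean>;

  // Shared template method: open the stream, then hand each chunk to
  // processChunk until the provider signals completion.
  async handleStream(param: CommonProviderParams, dataCallback: (data: any) => void) {
    const client = this.getOpenAIClient(param);
    // Assumed: the ChatSdk.buildMessageChain logic removed from XaiChatSdk
    // above now lives here; a simple passthrough stands in for it in this sketch.
    const safeMessages = param.messages;
    const stream = await client.chat.completions.create(
      this.getStreamParams(param, safeMessages),
    );
    for await (const chunk of stream as unknown as AsyncIterable<any>) {
      if (await this.processChunk(chunk, dataCallback)) break;
    }
  }
}

Under this reading the commit is a template-method refactor: per-provider differences (base URL, API key, tuning parameters, stop condition) are confined to three abstract hooks, while validation, message building, and the streaming loop are written once in the base class.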