creates a common abstraction for providers

geoffsee
2025-05-31 20:21:22 -04:00
committed by Geoff Seemueller
parent 5a7691a9af
commit 87e083682c
9 changed files with 569 additions and 618 deletions
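
The hunks below rewrite the OpenAI integration against a shared base class. chat-stream-provider.ts itself is one of the other files in this commit and is not shown here; a minimal sketch of the contract it likely defines, inferred purely from how OpenAiChatProvider overrides and calls it, might look like this (everything beyond BaseChatProvider, CommonProviderParams, and the three overridden methods is an assumption):

import { OpenAI } from "openai";

// Stand-in for the project's Workers-style env bindings (assumption).
type Env = Record<string, unknown>;

// Field names inferred from the call sites in this diff.
export interface CommonProviderParams {
  openai?: OpenAI;
  systemPrompt: any;
  preprocessedContext: any;
  maxTokens: unknown | number | undefined;
  messages: any;
  model: string;
  env: Env;
}

export abstract class BaseChatProvider {
  // Subclasses supply the concrete client (OpenAI itself, or an
  // OpenAI-compatible client pointed at another backend).
  abstract getOpenAIClient(param: CommonProviderParams): OpenAI;

  // Subclasses build the streaming request from the sanitized messages.
  abstract getStreamParams(param: CommonProviderParams, safeMessages: any[]): any;

  // Returns true to stop consuming the stream, false to continue.
  abstract processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean>;

  // Template method: owns the stream loop, delegates the variable parts.
  async handleStream(param: CommonProviderParams, dataCallback: (data: any) => void) {
    // The real base class presumably also rebuilds the message chain here
    // (the ChatSdk.buildAssistantPrompt / buildMessageChain calls removed below).
    const safeMessages = param.messages;
    const client = this.getOpenAIClient(param);
    const streamParams = this.getStreamParams(param, safeMessages);
    const stream = await client.chat.completions.create(streamParams);
    for await (const chunk of stream) {
      if (await this.processChunk(chunk, dataCallback)) break;
    }
  }
}

This is the template-method pattern: the loop lives in one place, and each backend only answers three questions.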


@@ -1,65 +1,16 @@
 import { OpenAI } from "openai";
-import ChatSdk from "../lib/chat-sdk";
 import { Utils } from "../lib/utils";
-import {ChatCompletionCreateParamsStreaming} from "openai/resources/chat/completions/completions";
+import { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions/completions";
+import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
 
-export class OpenAiChatSdk {
-  static async handleOpenAiStream(
-    ctx: {
-      openai: OpenAI;
-      systemPrompt: any;
-      preprocessedContext: any;
-      maxTokens: unknown | number | undefined;
-      messages: any;
-      model: any;
-    },
-    dataCallback: (data: any) => any,
-  ) {
-    const {
-      openai,
-      systemPrompt,
-      maxTokens,
-      messages,
-      model,
-      preprocessedContext,
-    } = ctx;
-    if (!messages?.length) {
-      return new Response("No messages provided", { status: 400 });
-    }
-    const assistantPrompt = ChatSdk.buildAssistantPrompt({
-      maxTokens: maxTokens,
-    });
-    const safeMessages = ChatSdk.buildMessageChain(messages, {
-      systemPrompt: systemPrompt,
-      model,
-      assistantPrompt,
-      toolResults: preprocessedContext,
-    });
-    return OpenAiChatSdk.streamOpenAiResponse(
-      safeMessages,
-      {
-        model,
-        maxTokens: maxTokens as number,
-        openai: openai,
-      },
-      dataCallback,
-    );
+export class OpenAiChatProvider extends BaseChatProvider {
+  getOpenAIClient(param: CommonProviderParams): OpenAI {
+    return param.openai as OpenAI;
   }
 
-  private static async streamOpenAiResponse(
-    messages: any[],
-    opts: {
-      model: string;
-      maxTokens: number | undefined;
-      openai: OpenAI;
-    },
-    dataCallback: (data: any) => any,
-  ) {
+  getStreamParams(param: CommonProviderParams, safeMessages: any[]): ChatCompletionCreateParamsStreaming {
     const isO1 = () => {
-      if (opts.model === "o1-preview" || opts.model === "o1-mini") {
+      if (param.model === "o1-preview" || param.model === "o1-mini") {
         return true;
       }
     };
@@ -71,55 +22,93 @@ export class OpenAiChatSdk {
       top_p: 0.98,
       presence_penalty: 0.1,
       frequency_penalty: 0.3,
-      max_tokens: opts.maxTokens,
+      max_tokens: param.maxTokens as number,
     };
     const getTuningParams = () => {
       if (isO1()) {
         tuningParams["temperature"] = 1;
-        tuningParams["max_completion_tokens"] = opts.maxTokens + 10000;
+        tuningParams["max_completion_tokens"] = (param.maxTokens as number) + 10000;
         return tuningParams;
       }
       return gpt4oTuningParams;
     };
     let completionRequest: ChatCompletionCreateParamsStreaming = {
-      model: opts.model,
+      model: param.model,
       stream: true,
-      messages: messages
+      messages: safeMessages
     };
-    const isLocal = opts.openai.baseURL.includes("localhost");
+    const client = this.getOpenAIClient(param);
+    const isLocal = client.baseURL.includes("localhost");
     if(isLocal) {
-      completionRequest["messages"] = Utils.normalizeWithBlanks(messages)
-      completionRequest["stream_options"] = {
+      completionRequest["messages"] = Utils.normalizeWithBlanks(safeMessages);
+      completionRequest["stream_options"] = {
         include_usage: true
-      }
+      };
     } else {
-      completionRequest = {...completionRequest, ...getTuningParams()}
+      completionRequest = {...completionRequest, ...getTuningParams()};
     }
-    const openAIStream = await opts.openai.chat.completions.create(completionRequest);
-
-    for await (const chunk of openAIStream) {
-      if (isLocal && chunk.usage) {
-        dataCallback({
-          type: "chat",
-          data: {
-            choices: [
-              {
-                delta: { content: "" },
-                logprobs: null,
-                finish_reason: "stop",
-              },
-            ],
-          },
-        });
-        break;
-      }
-      dataCallback({ type: "chat", data: chunk });
-    }
+    return completionRequest;
+  }
+
+  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
+    const isLocal = chunk.usage !== undefined;
+    if (isLocal && chunk.usage) {
+      dataCallback({
+        type: "chat",
+        data: {
+          choices: [
+            {
+              delta: { content: "" },
+              logprobs: null,
+              finish_reason: "stop",
+            },
+          ],
+        },
+      });
+      return true; // Break the stream
+    }
+    dataCallback({ type: "chat", data: chunk });
+    return false; // Continue the stream
   }
 }
+
+// Legacy class for backward compatibility
+export class OpenAiChatSdk {
+  private static provider = new OpenAiChatProvider();
+
+  static async handleOpenAiStream(
+    ctx: {
+      openai: OpenAI;
+      systemPrompt: any;
+      preprocessedContext: any;
+      maxTokens: unknown | number | undefined;
+      messages: any;
+      model: any;
+    },
+    dataCallback: (data: any) => any,
+  ) {
+    if (!ctx.messages?.length) {
+      return new Response("No messages provided", { status: 400 });
+    }
+
+    return this.provider.handleStream(
+      {
+        openai: ctx.openai,
+        systemPrompt: ctx.systemPrompt,
+        preprocessedContext: ctx.preprocessedContext,
+        maxTokens: ctx.maxTokens,
+        messages: ctx.messages,
+        model: ctx.model,
+        env: {} as Env, // This is not used in OpenAI provider
+      },
+      dataCallback,
+    );
+  }
+}
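
With the provider in place, new call sites can hold an OpenAiChatProvider directly while existing ones keep the static facade. A hypothetical usage sketch follows; the client construction, prompt, messages, and callback are illustrative, not taken from this commit:

import { OpenAI } from "openai";

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
const onData = (data: any) => console.log(data); // receives { type: "chat", data: chunk } events

// New path: drive the stream through the provider abstraction.
const provider = new OpenAiChatProvider();
await provider.handleStream(
  {
    openai,
    systemPrompt: "You are a helpful assistant.",
    preprocessedContext: undefined,
    maxTokens: 1024,
    messages: [{ role: "user", content: "Hello!" }],
    model: "gpt-4o",
    env: {} as any, // unused by the OpenAI provider, per the diff
  },
  onData,
);

// Legacy path: unchanged external API, now delegating to the provider.
await OpenAiChatSdk.handleOpenAiStream(
  {
    openai,
    systemPrompt: "You are a helpful assistant.",
    preprocessedContext: undefined,
    maxTokens: 1024,
    messages: [{ role: "user", content: "Hello!" }],
    model: "gpt-4o",
  },
  onData,
);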