creates a common abstraction for providers

This commit is contained in:
geoffsee
2025-05-31 20:21:22 -04:00
committed by Geoff Seemueller
parent 5a7691a9af
commit 87e083682c
9 changed files with 569 additions and 618 deletions

View File

@@ -10,33 +10,44 @@ import {
} from "mobx-state-tree";
import Message from "../models/Message";
import ChatSdk from "../lib/chat-sdk";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
export class FireworksAiChatSdk {
private static async streamFireworksResponse(
messages: any[],
opts: {
model: string;
maxTokens: number | unknown | undefined;
openai: OpenAI;
},
dataCallback: (data: any) => void,
) {
export class FireworksAiChatProvider extends BaseChatProvider {
getOpenAIClient(param: CommonProviderParams): OpenAI {
return new OpenAI({
apiKey: param.env.FIREWORKS_API_KEY,
baseURL: "https://api.fireworks.ai/inference/v1",
});
}
getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
let modelPrefix = "accounts/fireworks/models/";
if (opts.model.toLowerCase().includes("yi-")) {
if (param.model.toLowerCase().includes("yi-")) {
modelPrefix = "accounts/yi-01-ai/models/";
}
const fireworksStream = await opts.openai.chat.completions.create({
model: `${modelPrefix}${opts.model}`,
messages: messages,
return {
model: `${modelPrefix}${param.model}`,
messages: safeMessages,
stream: true,
});
for await (const chunk of fireworksStream) {
dataCallback({ type: "chat", data: chunk });
}
};
}
async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
// Check if this is the final chunk
if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
dataCallback({ type: "chat", data: chunk });
return true; // Break the stream
}
dataCallback({ type: "chat", data: chunk });
return false; // Continue the stream
}
}
export class FireworksAiChatSdk {
private static provider = new FireworksAiChatProvider();
static async handleFireworksStream(
param: {
openai: OpenAI;
@@ -55,36 +66,14 @@ export class FireworksAiChatSdk {
},
dataCallback: (data) => void,
) {
const {
preprocessedContext,
messages,
env,
maxTokens,
systemPrompt,
model,
} = param;
const assistantPrompt = ChatSdk.buildAssistantPrompt({
maxTokens: maxTokens,
});
const safeMessages = ChatSdk.buildMessageChain(messages, {
systemPrompt: systemPrompt,
model,
assistantPrompt,
toolResults: preprocessedContext,
});
const fireworksOpenAIClient = new OpenAI({
apiKey: param.env.FIREWORKS_API_KEY,
baseURL: "https://api.fireworks.ai/inference/v1",
});
return FireworksAiChatSdk.streamFireworksResponse(
safeMessages,
return this.provider.handleStream(
{
model: param.model,
systemPrompt: param.systemPrompt,
preprocessedContext: param.preprocessedContext,
maxTokens: param.maxTokens,
openai: fireworksOpenAIClient,
messages: param.messages,
model: param.model,
env: param.env,
},
dataCallback,
);