creates a common abstraction for providers

This commit is contained in:
geoffsee
2025-05-31 20:21:22 -04:00
committed by Geoff Seemueller
parent 5a7691a9af
commit 87e083682c
9 changed files with 569 additions and 618 deletions

View File

@@ -6,9 +6,49 @@ import {
ModelSnapshotType2,
UnionStringArray,
} from "mobx-state-tree";
import ChatSdk from "../lib/chat-sdk";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
export class GroqChatProvider extends BaseChatProvider {
getOpenAIClient(param: CommonProviderParams): OpenAI {
return new OpenAI({
baseURL: "https://api.groq.com/openai/v1",
apiKey: param.env.GROQ_API_KEY,
});
}
getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
const llamaTuningParams = {
temperature: 0.86,
top_p: 0.98,
presence_penalty: 0.1,
frequency_penalty: 0.3,
max_tokens: param.maxTokens as number,
};
return {
model: param.model,
messages: safeMessages,
stream: true,
...llamaTuningParams
};
}
async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
// Check if this is the final chunk
if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
dataCallback({ type: "chat", data: chunk });
return true; // Break the stream
}
dataCallback({ type: "chat", data: chunk });
return false; // Continue the stream
}
}
// Legacy class for backward compatibility
export class GroqChatSdk {
private static provider = new GroqChatProvider();
static async handleGroqStream(
param: {
openai: OpenAI;
@@ -27,73 +67,16 @@ export class GroqChatSdk {
},
dataCallback: (data) => void,
) {
const {
preprocessedContext,
messages,
env,
maxTokens,
systemPrompt,
model,
} = param;
const assistantPrompt = ChatSdk.buildAssistantPrompt({
maxTokens: maxTokens,
});
const safeMessages = ChatSdk.buildMessageChain(messages, {
systemPrompt: systemPrompt,
model,
assistantPrompt,
toolResults: preprocessedContext,
});
const openai = new OpenAI({
baseURL: "https://api.groq.com/openai/v1",
apiKey: param.env.GROQ_API_KEY,
});
return GroqChatSdk.streamGroqResponse(
safeMessages,
return this.provider.handleStream(
{
model: param.model,
systemPrompt: param.systemPrompt,
preprocessedContext: param.preprocessedContext,
maxTokens: param.maxTokens,
openai: openai,
messages: param.messages,
model: param.model,
env: param.env,
},
dataCallback,
);
}
private static async streamGroqResponse(
messages: any[],
opts: {
model: string;
maxTokens: number | unknown | undefined;
openai: OpenAI;
},
dataCallback: (data: any) => void,
) {
const tuningParams: Record<string, any> = {};
const llamaTuningParams = {
temperature: 0.86,
top_p: 0.98,
presence_penalty: 0.1,
frequency_penalty: 0.3,
max_tokens: opts.maxTokens,
};
const getLlamaTuningParams = () => {
return llamaTuningParams;
};
const groqStream = await opts.openai.chat.completions.create({
model: opts.model,
messages: messages,
frequency_penalty: 2,
stream: true,
temperature: 0.78,
});
for await (const chunk of groqStream) {
dataCallback({ type: "chat", data: chunk });
}
}
}