import { OpenAI } from "openai";
import {
  _NotCustomized,
  ISimpleType,
  ModelPropertiesDeclarationToProperties,
  ModelSnapshotType2,
  UnionStringArray,
} from "mobx-state-tree";
import ChatSdk from "../lib/chat-sdk";
import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";

/**
 * Chat provider that targets the Cerebras API through an OpenAI-style client.
 *
 * It supplies the three provider hooks defined here: client construction,
 * request-parameter assembly, and per-chunk stream handling.
 */
export class CerebrasChatProvider extends BaseChatProvider {
  /**
   * Creates an OpenAI client pointed at the Cerebras base URL.
   *
   * @param param - Provider parameters; only `param.env.CEREBRAS_API_KEY` is read.
   * @returns A new `OpenAI` client instance.
   */
  getOpenAIClient(param: CommonProviderParams): OpenAI {
    return new OpenAI({
      baseURL: "https://api.cerebras.ai/v1",
      apiKey: param.env.CEREBRAS_API_KEY,
    });
  }

  /**
   * Builds the streaming chat-completion request payload.
   *
   * Only `model`, `messages` and `stream` are sent. Sampling settings
   * (temperature, top_p, penalties) and `param.maxTokens` are NOT forwarded;
   * an earlier revision declared such tuning values locally but never
   * included them in the payload, and that behavior is preserved here.
   * NOTE(review): presumably intentional because Cerebras rejects some of
   * these fields — confirm against the Cerebras API reference before adding.
   *
   * @param param - Provider parameters; only `param.model` is read.
   * @param safeMessages - Messages to send, passed through unchanged.
   * @returns The request payload object.
   */
  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
    return {
      model: param.model,
      messages: safeMessages,
      stream: true,
    };
  }

  /**
   * Forwards one stream chunk to the callback and reports whether it was the
   * final chunk.
   *
   * @param chunk - A streamed chunk; its first choice's `finish_reason` is inspected.
   * @param dataCallback - Receives `{ type: "chat", data: chunk }` for every chunk,
   *   including the final one.
   * @returns `true` when the first choice's `finish_reason` is `"stop"`
   *   (the stream should end), otherwise `false`.
   */
  async processChunk(
    chunk: any,
    dataCallback: (data: any) => void,
  ): Promise<boolean> {
    // Evaluate the finish check before invoking the callback so a malformed
    // chunk fails before anything is emitted.
    const isFinalChunk = Boolean(
      chunk.choices && chunk.choices[0]?.finish_reason === "stop",
    );

    // Every chunk is emitted, final or not.
    dataCallback({ type: "chat", data: chunk });

    return isFinalChunk;
  }
}

/**
 * Static facade that routes Cerebras stream requests to a shared
 * {@link CerebrasChatProvider} instance.
 */
export class CerebrasSdk {
  private static readonly provider = new CerebrasChatProvider();

  /**
   * Delegates a streaming chat request to the shared provider's `handleStream`.
   *
   * `param.openai` is accepted for interface compatibility but is not
   * forwarded. NOTE(review): presumably the provider builds its own client
   * via `getOpenAIClient` — confirm in `BaseChatProvider`.
   *
   * @param param - Request settings; every field except `openai` is passed on.
   * @param dataCallback - Invoked with streamed data by the provider.
   * @returns Whatever the provider's `handleStream` resolves to.
   */
  static async handleCerebrasStream(
    param: {
      openai: OpenAI;
      systemPrompt: any;
      disableWebhookGeneration: boolean;
      preprocessedContext: ModelSnapshotType2<
        ModelPropertiesDeclarationToProperties<{
          role: ISimpleType<UnionStringArray<string[]>>;
          content: ISimpleType<unknown>;
        }>,
        _NotCustomized
      >;
      maxTokens: unknown | number | undefined;
      messages: any;
      model: string;
      env: Env;
    },
    dataCallback: (data: any) => void,
  ) {
    // Reference the class explicitly rather than `this`, so the method still
    // works when detached and passed around as a plain function.
    return CerebrasSdk.provider.handleStream(
      {
        systemPrompt: param.systemPrompt,
        preprocessedContext: param.preprocessedContext,
        maxTokens: param.maxTokens,
        messages: param.messages,
        model: param.model,
        env: param.env,
        disableWebhookGeneration: param.disableWebhookGeneration,
      },
      dataCallback,
    );
  }
}