import { OpenAI } from "openai";

import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
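
/**
 * Chat provider that streams completions from Cloudflare Workers AI
 * through its OpenAI-compatible endpoint.
 */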
export class CloudflareAiChatProvider extends BaseChatProvider {
  getOpenAIClient(param: CommonProviderParams): OpenAI {
    // Cloudflare exposes an OpenAI-compatible API under each account's /ai/v1 path.
    const cfAiURL = `https://api.cloudflare.com/client/v4/accounts/${param.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`;

    return new OpenAI({
      apiKey: param.env.CLOUDFLARE_API_KEY,
      baseURL: cfAiURL,
    });
  }
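
  // Builds the OpenAI-style streaming request body, applying per-model
  // max_tokens caps.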
  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
    const generationParams: Record<string, any> = {
      model: this.getModelWithPrefix(param.model),
      messages: safeMessages,
      stream: true,
    };

    // Set max_tokens based on model
    if (this.getModelPrefix(param.model) === "@cf/meta") {
      generationParams["max_tokens"] = 4096;
    }

    if (this.getModelPrefix(param.model) === "@hf/mistral") {
      generationParams["max_tokens"] = 4096;
    }

    if (param.model.toLowerCase().includes("hermes-2-pro-mistral-7b")) {
      generationParams["max_tokens"] = 1000;
    }

    if (param.model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq")) {
      generationParams["max_tokens"] = 1000;
    }

    if (param.model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq")) {
      generationParams["max_tokens"] = 590;
    }

    if (param.model.toLowerCase().includes("deepseek-math-7b-instruct")) {
      generationParams["max_tokens"] = 512;
    }

    if (param.model.toLowerCase().includes("neural-chat-7b-v3-1-awq")) {
      generationParams["max_tokens"] = 590;
    }

    if (param.model.toLowerCase().includes("openchat-3.5-0106")) {
      generationParams["max_tokens"] = 2000;
    }

    return generationParams;
  }
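
  // Maps a bare model name to the Workers AI catalog namespace it is
  // published under (e.g. "@cf/meta", "@hf/thebloke"); "@cf/meta" is the default.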
  private getModelPrefix(model: string): string {
    let modelPrefix = `@cf/meta`;

    if (model.toLowerCase().includes("llama")) {
      modelPrefix = `@cf/meta`;
    }

    if (model.toLowerCase().includes("hermes-2-pro-mistral-7b")) {
      modelPrefix = `@hf/nousresearch`;
    }

    if (model.toLowerCase().includes("mistral-7b-instruct")) {
      modelPrefix = `@hf/mistral`;
    }

    if (model.toLowerCase().includes("gemma")) {
      modelPrefix = `@cf/google`;
    }

    if (model.toLowerCase().includes("deepseek")) {
      modelPrefix = `@cf/deepseek-ai`;
    }

    if (model.toLowerCase().includes("openchat-3.5-0106")) {
      modelPrefix = `@cf/openchat`;
    }

    const isNeuralChat = model.toLowerCase().includes("neural-chat-7b-v3-1-awq");
    if (
      isNeuralChat ||
      model.toLowerCase().includes("openhermes-2.5-mistral-7b-awq") ||
      model.toLowerCase().includes("zephyr-7b-beta-awq") ||
      model.toLowerCase().includes("deepseek-coder-6.7b-instruct-awq")
    ) {
      modelPrefix = `@hf/thebloke`;
    }

    return modelPrefix;
  }
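
  // Fully qualified Workers AI model id, e.g. "@cf/meta/<model>".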
  private getModelWithPrefix(model: string): string {
    return `${this.getModelPrefix(model)}/${model}`;
  }
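
  // Forwards each chunk to the caller; returns true once the model reports
  // finish_reason "stop" so the stream can be closed.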
  async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
    if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
      dataCallback({ type: "chat", data: chunk });
      return true;
    }

    dataCallback({ type: "chat", data: chunk });
    return false;
  }
}
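
/**
 * Thin static facade over CloudflareAiChatProvider for callers that only
 * need the streaming entry point.
 */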
export class CloudflareAISdk {
  private static provider = new CloudflareAiChatProvider();

  static async handleCloudflareAIStream(
    param: {
      openai: OpenAI;
      systemPrompt: any;
      preprocessedContext: any;
      maxTokens: number | undefined;
      messages: any;
      model: string;
      env: Env;
    },
    dataCallback: (data: any) => void,
  ) {
    return this.provider.handleStream(
      {
        systemPrompt: param.systemPrompt,
        preprocessedContext: param.preprocessedContext,
        maxTokens: param.maxTokens,
        messages: param.messages,
        model: param.model,
        env: param.env,
      },
      dataCallback,
    );
  }
}
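
// A minimal usage sketch (hypothetical call site; `env` is assumed to carry
// the CLOUDFLARE_ACCOUNT_ID / CLOUDFLARE_API_KEY bindings and the model name
// is illustrative):
//
//   await CloudflareAISdk.handleCloudflareAIStream(
//     {
//       openai: client, // accepted by the signature but not forwarded; the provider builds its own client
//       systemPrompt: "You are a helpful assistant.",
//       preprocessedContext: {},
//       maxTokens: undefined,
//       messages: [{ role: "user", content: "Hello" }],
//       model: "llama-3.1-8b-instruct",
//       env,
//     },
//     (data) => console.log(data),
//   );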