diff --git a/packages/client/src/components/chat/input-menu/InputMenu.tsx b/packages/client/src/components/chat/input-menu/InputMenu.tsx index 535b31a..f87b8b8 100644 --- a/packages/client/src/components/chat/input-menu/InputMenu.tsx +++ b/packages/client/src/components/chat/input-menu/InputMenu.tsx @@ -1,215 +1,197 @@ -import React, { useCallback, useEffect, useRef, useState } from "react"; +import React, {useCallback, useEffect, useRef, useState} from "react"; import { - Box, - Button, - Divider, - Flex, - IconButton, - Menu, - MenuButton, - MenuItem, - MenuList, - Text, - useDisclosure, - useOutsideClick, + Box, + Button, + Divider, + Flex, + IconButton, + Menu, + MenuButton, + MenuItem, + MenuList, + Text, + useDisclosure, + useOutsideClick, } from "@chakra-ui/react"; -import { observer } from "mobx-react-lite"; -import { ChevronDown, Copy, RefreshCcw, Settings } from "lucide-react"; -import ClientChatStore from "../../../stores/ClientChatStore"; +import {observer} from "mobx-react-lite"; +import {ChevronDown, Copy, RefreshCcw, Settings} from "lucide-react"; import clientChatStore from "../../../stores/ClientChatStore"; import FlyoutSubMenu from "./FlyoutSubMenu"; -import { useIsMobile } from "../../contexts/MobileContext"; -import { useIsMobile as useIsMobileUserAgent } from "../../../hooks/_IsMobileHook"; -import { getModelFamily, SUPPORTED_MODELS } from "../lib/SupportedModels"; -import { formatConversationMarkdown } from "../lib/exportConversationAsMarkdown"; +import {useIsMobile} from "../../contexts/MobileContext"; +import {useIsMobile as useIsMobileUserAgent} from "../../../hooks/_IsMobileHook"; +import {formatConversationMarkdown} from "../lib/exportConversationAsMarkdown"; export const MsM_commonButtonStyles = { - bg: "transparent", - color: "text.primary", - borderRadius: "full", - padding: 2, - border: "none", - _hover: { bg: "rgba(255, 255, 255, 0.2)" }, - _active: { bg: "rgba(255, 255, 255, 0.3)" }, - _focus: { boxShadow: "none" }, + bg: "transparent", + color: "text.primary", + borderRadius: "full", + padding: 2, + border: "none", + _hover: {bg: "rgba(255, 255, 255, 0.2)"}, + _active: {bg: "rgba(255, 255, 255, 0.3)"}, + _focus: {boxShadow: "none"}, }; const InputMenu: React.FC<{ isDisabled?: boolean }> = observer( - ({ isDisabled }) => { - const isMobile = useIsMobile(); - const isMobileUserAgent = useIsMobileUserAgent(); - const { - isOpen, - onOpen, - onClose, - onToggle, - getDisclosureProps, - getButtonProps, - } = useDisclosure(); + ({isDisabled}) => { + const isMobile = useIsMobile(); + const isMobileUserAgent = useIsMobileUserAgent(); + const { + isOpen, + onOpen, + onClose, + onToggle, + getDisclosureProps, + getButtonProps, + } = useDisclosure(); - const [controlledOpen, setControlledOpen] = useState(false); + const [controlledOpen, setControlledOpen] = useState(false); + const [supportedModels, setSupportedModels] = useState([]); - useEffect(() => { - setControlledOpen(isOpen); - }, [isOpen]); + useEffect(() => { + setControlledOpen(isOpen); + }, [isOpen]); + + useEffect(() => { + fetch("/api/models").then(response => response.json()).then((models) => { + setSupportedModels(models); + }).catch((err) => { + console.error("Could not fetch models: ", err); + }); + }, []); - const getSupportedModels = async () => { - // Check if fetch is available (browser environment) - if (typeof fetch !== 'undefined') { - try { - return await (await fetch("/api/models")).json(); - } catch (error) { - console.error("Error fetching models:", error); - return []; - } - } else { - // 
In test environment or where fetch is not available - console.log("Fetch not available, using default models"); - return []; + const handleClose = useCallback(() => { + onClose(); + }, [isOpen]); + + const handleCopyConversation = useCallback(() => { + navigator.clipboard + .writeText(formatConversationMarkdown(clientChatStore.items)) + .then(() => { + window.alert( + "Conversation copied to clipboard. \n\nPaste it somewhere safe!", + ); + onClose(); + }) + .catch((err) => { + console.error("Could not copy text to clipboard: ", err); + window.alert("Failed to copy conversation. Please try again."); + }); + }, [onClose]); + + async function selectModelFn({name, value}) { + clientChatStore.setModel(value); } - } - useEffect(() => { - getSupportedModels().then((supportedModels) => { - // Check if setSupportedModels method exists before calling it - if (clientChatStore.setSupportedModels) { - clientChatStore.setSupportedModels(supportedModels); - } else { - console.log("setSupportedModels method not available in this environment"); - } - }); - }, []); + function isSelectedModelFn({name, value}) { + return clientChatStore.model === value; + } + const menuRef = useRef(); + const [menuState, setMenuState] = useState(); - const handleClose = useCallback(() => { - onClose(); - }, [isOpen]); - - const handleCopyConversation = useCallback(() => { - navigator.clipboard - .writeText(formatConversationMarkdown(clientChatStore.items)) - .then(() => { - window.alert( - "Conversation copied to clipboard. \n\nPaste it somewhere safe!", - ); - onClose(); - }) - .catch((err) => { - console.error("Could not copy text to clipboard: ", err); - window.alert("Failed to copy conversation. Please try again."); + useOutsideClick({ + enabled: !isMobile && isOpen, + ref: menuRef, + handler: () => { + handleClose(); + }, }); - }, [onClose]); - async function selectModelFn({ name, value }) { - clientChatStore.setModel(value); - } - - function isSelectedModelFn({ name, value }) { - return clientChatStore.model === value; - } - - const menuRef = useRef(); - const [menuState, setMenuState] = useState(); - - useOutsideClick({ - enabled: !isMobile && isOpen, - ref: menuRef, - handler: () => { - handleClose(); - }, - }); - - return ( - - {isMobile ? ( - } - isDisabled={isDisabled} - aria-label="Settings" - _hover={{ bg: "rgba(255, 255, 255, 0.2)" }} - _focus={{ boxShadow: "none" }} - {...MsM_commonButtonStyles} - /> - ) : ( - } - isDisabled={isDisabled} - variant="ghost" - display="flex" - justifyContent="space-between" - alignItems="center" - minW="auto" - {...MsM_commonButtonStyles} - > - - {clientChatStore.model} - - - )} - - ({ name: m, value: m }))} - onClose={onClose} - parentIsOpen={isOpen} - setMenuState={setMenuState} - handleSelect={selectModelFn} - isSelected={isSelectedModelFn} - /> - - {/*Export conversation button*/} - - - - Export - - - {/*New conversation button*/} - { - clientChatStore.setActiveConversation("conversation:new"); - onClose(); - }} - _hover={{ bg: "rgba(0, 0, 0, 0.05)" }} - _focus={{ bg: "rgba(0, 0, 0, 0.1)" }} - > - - - New - - - - - ); - }, + return ( + + {isMobile ? 
( + } + isDisabled={isDisabled} + aria-label="Settings" + _hover={{bg: "rgba(255, 255, 255, 0.2)"}} + _focus={{boxShadow: "none"}} + {...MsM_commonButtonStyles} + /> + ) : ( + } + isDisabled={isDisabled} + variant="ghost" + display="flex" + justifyContent="space-between" + alignItems="center" + minW="auto" + {...MsM_commonButtonStyles} + > + + {clientChatStore.model} + + + )} + + ({ + name: modelData.id.split('/').pop() || modelData.id, + value: modelData.id + }))} + onClose={onClose} + parentIsOpen={isOpen} + setMenuState={setMenuState} + handleSelect={selectModelFn} + isSelected={isSelectedModelFn} + /> + + {/*Export conversation button*/} + + + + Export + + + {/*New conversation button*/} + { + clientChatStore.setActiveConversation("conversation:new"); + onClose(); + }} + _hover={{bg: "rgba(0, 0, 0, 0.05)"}} + _focus={{bg: "rgba(0, 0, 0, 0.1)"}} + > + + + New + + + + + ); + }, ); export default InputMenu; diff --git a/packages/server/ServerCoordinator.ts b/packages/server/ServerCoordinator.ts index 7874aa5..cd6e770 100644 --- a/packages/server/ServerCoordinator.ts +++ b/packages/server/ServerCoordinator.ts @@ -1,6 +1,9 @@ import { DurableObject } from "cloudflare:workers"; +import {ProviderRepository} from "./providers/_ProviderRepository"; export default class ServerCoordinator extends DurableObject { + env; + state; constructor(state, env) { super(state, env); this.state = state; @@ -8,20 +11,24 @@ export default class ServerCoordinator extends DurableObject { } // Public method to calculate dynamic max tokens - async dynamicMaxTokens(input, maxOuputTokens) { - return 2000; - // const baseTokenLimit = 1024; - // - // - // const { encode } = await import("gpt-tokenizer/esm/model/gpt-4o"); - // - // const inputTokens = Array.isArray(input) - // ? encode(input.map(i => i.content).join(' ')) - // : encode(input); - // - // const scalingFactor = inputTokens.length > 300 ? 
1.5 : 1;
-    //
-    //   return Math.min(baseTokenLimit + Math.floor(inputTokens.length * scalingFactor^2), maxOuputTokens);
+  async dynamicMaxTokens(model, input, maxOutputTokens) {
+
+    const modelMeta = await ProviderRepository.getModelMeta(model, this.env);
+
+    // Providers report the token limit under one of three keys:
+    //   max_completion_tokens
+    //   context_window
+    //   context_length
+
+    if ('max_completion_tokens' in modelMeta) {
+      return modelMeta.max_completion_tokens;
+    } else if ('context_window' in modelMeta) {
+      return modelMeta.context_window;
+    } else if ('context_length' in modelMeta) {
+      return modelMeta.context_length;
+    } else {
+      return 8096;
+    }
   }

   // Public method to retrieve conversation history
diff --git a/packages/server/lib/chat-sdk.ts b/packages/server/lib/chat-sdk.ts
index 00a8b1b..f057327 100644
--- a/packages/server/lib/chat-sdk.ts
+++ b/packages/server/lib/chat-sdk.ts
@@ -1,8 +1,8 @@
 import {OpenAI} from "openai";
 import Message from "../models/Message.ts";
 import {AssistantSdk} from "./assistant-sdk.ts";
-import {getModelFamily} from "@open-gsio/ai/supported-models.ts";
 import type {Instance} from "mobx-state-tree";
+import {ProviderRepository} from "../providers/_ProviderRepository";

 export class ChatSdk {
   static async preprocess({
@@ -95,9 +95,10 @@ export class ChatSdk {
       assistantPrompt: string;
       toolResults: Instance;
       model: any;
+      env: Env;
     },
   ) {
-    const modelFamily = getModelFamily(opts.model);
+    const modelFamily = await ProviderRepository.getModelFamily(opts.model, opts.env);

     const messagesToSend = [];

diff --git a/packages/server/providers/_ProviderRepository.ts b/packages/server/providers/_ProviderRepository.ts
new file mode 100644
index 0000000..ade4f51
--- /dev/null
+++ b/packages/server/providers/_ProviderRepository.ts
@@ -0,0 +1,76 @@
+
+export class ProviderRepository {
+  #providers: {name: string, key: string, endpoint: string}[] = [];
+  constructor(env: Record<string, any>) {
+    this.setProviders(env);
+  }
+
+  static OPENAI_COMPAT_ENDPOINTS = {
+    xai: 'https://api.x.ai/v1',
+    groq: 'https://api.groq.com/openai/v1',
+    google: 'https://generativelanguage.googleapis.com/v1beta/openai',
+    fireworks: 'https://api.fireworks.ai/inference/v1',
+    cohere: 'https://api.cohere.ai/compatibility/v1',
+    cloudflare: 'https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_ACCOUNT_ID}/ai/v1',
+    anthropic: 'https://api.anthropic.com/v1/',
+    openai: 'https://api.openai.com/v1/',
+    cerebras: 'https://api.cerebras.ai/v1',
+    ollama: "http://localhost:11434",
+    mlx: "http://localhost:10240",
+  }
+
+  static async getModelFamily(model, env: Env) {
+    const allModels = await env.KV_STORAGE.get("supportedModels");
+    const models = JSON.parse(allModels);
+    const modelData = models.filter(m => m.id === model);
+    console.log({modelData});
+    return modelData[0].provider;
+  }
+
+  static async getModelMeta(model, env) {
+    const allModels = await env.KV_STORAGE.get("supportedModels");
+    const models = JSON.parse(allModels);
+    return models.filter(m => m.id === model).pop();
+  }
+
+  getProviders(): {name: string, key: string, endpoint: string}[] {
+    return this.#providers;
+  }
+
+  // Every env var ending in "KEY" is treated as a provider credential,
+  // e.g. GROQ_API_KEY registers the "groq" provider with its endpoint.
+  setProviders(env: Record<string, any>) {
+    const envKeys = Object.keys(env);
+    for (let i = 0; i < envKeys.length; i++) {
+      if (envKeys[i].endsWith('KEY')) {
+        const detectedProvider = envKeys[i].split('_')[0].toLowerCase();
+        switch (detectedProvider) {
+          case 'anthropic':
+            this.#providers.push({
+              name: 'anthropic',
+              key: env.ANTHROPIC_API_KEY,
+              endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['anthropic']
+            });
+            break;
+          case 'gemini':
+            this.#providers.push({
+              name: 'google',
+              key: env.GEMINI_API_KEY,
+              endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS['google']
+            });
+            break;
+          case 'cloudflare':
+            this.#providers.push({
+              name: 'cloudflare',
+              key: env.CLOUDFLARE_API_KEY,
+              endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider].replace("{CLOUDFLARE_ACCOUNT_ID}", env.CLOUDFLARE_ACCOUNT_ID)
+            });
+            break;
+          default:
+            this.#providers.push({
+              name: detectedProvider,
+              key: env[envKeys[i]],
+              endpoint: ProviderRepository.OPENAI_COMPAT_ENDPOINTS[detectedProvider]
+            });
+        }
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/packages/server/providers/cerebras.ts b/packages/server/providers/cerebras.ts
index a77e980..f6a6350 100644
--- a/packages/server/providers/cerebras.ts
+++ b/packages/server/providers/cerebras.ts
@@ -1,10 +1,11 @@
 import {OpenAI} from "openai";
 import {BaseChatProvider, CommonProviderParams} from "./chat-stream-provider.ts";
+import {ProviderRepository} from "./_ProviderRepository";

 export class CerebrasChatProvider extends BaseChatProvider {
   getOpenAIClient(param: CommonProviderParams): OpenAI {
     return new OpenAI({
-      baseURL: "https://api.cerebras.ai/v1",
+      baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.cerebras,
       apiKey: param.env.CEREBRAS_API_KEY,
     });
   }
diff --git a/packages/server/providers/chat-stream-provider.ts b/packages/server/providers/chat-stream-provider.ts
index f03da17..ca1a131 100644
--- a/packages/server/providers/chat-stream-provider.ts
+++ b/packages/server/providers/chat-stream-provider.ts
@@ -35,6 +35,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
       model: param.model,
       assistantPrompt,
       toolResults: param.preprocessedContext,
+      env: param.env
     });

     const client = this.getOpenAIClient(param);
diff --git a/packages/server/providers/claude.ts b/packages/server/providers/claude.ts
index 96487d4..9546fed 100644
--- a/packages/server/providers/claude.ts
+++ b/packages/server/providers/claude.ts
@@ -69,6 +69,7 @@ export class ClaudeChatProvider extends BaseChatProvider {
       model: param.model,
       assistantPrompt,
       toolResults: param.preprocessedContext,
+      env: param.env,
     });

     const streamParams = this.getStreamParams(param, safeMessages);
diff --git a/packages/server/providers/cloudflareAi.ts b/packages/server/providers/cloudflareAi.ts
index d2c2c96..031d30a 100644
--- a/packages/server/providers/cloudflareAi.ts
+++ b/packages/server/providers/cloudflareAi.ts
@@ -1,13 +1,12 @@
 import {OpenAI} from "openai";
 import {BaseChatProvider, CommonProviderParams} from "./chat-stream-provider.ts";
+import {ProviderRepository} from "./_ProviderRepository";

 export class CloudflareAiChatProvider extends BaseChatProvider {
   getOpenAIClient(param: CommonProviderParams): OpenAI {
-    const cfAiURL = `https://api.cloudflare.com/client/v4/accounts/${param.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`;
-
     return new OpenAI({
       apiKey: param.env.CLOUDFLARE_API_KEY,
-      baseURL: cfAiURL,
+      baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.cloudflare.replace("{CLOUDFLARE_ACCOUNT_ID}", param.env.CLOUDFLARE_ACCOUNT_ID),
     });
   }
diff --git a/packages/server/providers/fireworks.ts b/packages/server/providers/fireworks.ts
index f250838..c2624e8 100644
--- a/packages/server/providers/fireworks.ts
+++ b/packages/server/providers/fireworks.ts
@@ -11,12 +11,13 @@ import {
 import Message from "../models/Message.ts";
 import ChatSdk from "../lib/chat-sdk.ts";
 import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
+import {ProviderRepository} from "./_ProviderRepository";

 export class FireworksAiChatProvider extends BaseChatProvider {
getOpenAIClient(param: CommonProviderParams): OpenAI { return new OpenAI({ apiKey: param.env.FIREWORKS_API_KEY, - baseURL: "https://api.fireworks.ai/inference/v1", + baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.fireworks, }); } diff --git a/packages/server/providers/google.ts b/packages/server/providers/google.ts index 7d73288..6cfead9 100644 --- a/packages/server/providers/google.ts +++ b/packages/server/providers/google.ts @@ -2,11 +2,12 @@ import { OpenAI } from "openai"; import ChatSdk from "../lib/chat-sdk.ts"; import { StreamParams } from "../services/ChatService.ts"; import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts"; +import {ProviderRepository} from "./_ProviderRepository"; export class GoogleChatProvider extends BaseChatProvider { getOpenAIClient(param: CommonProviderParams): OpenAI { return new OpenAI({ - baseURL: "https://generativelanguage.googleapis.com/v1beta/openai", + baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.google, apiKey: param.env.GEMINI_API_KEY, }); } diff --git a/packages/server/providers/groq.ts b/packages/server/providers/groq.ts index 83c3c6a..7e88164 100644 --- a/packages/server/providers/groq.ts +++ b/packages/server/providers/groq.ts @@ -7,11 +7,12 @@ import { UnionStringArray, } from "mobx-state-tree"; import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts"; +import {ProviderRepository} from "./_ProviderRepository"; export class GroqChatProvider extends BaseChatProvider { getOpenAIClient(param: CommonProviderParams): OpenAI { return new OpenAI({ - baseURL: "https://api.groq.com/openai/v1", + baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.groq, apiKey: param.env.GROQ_API_KEY, }); } diff --git a/packages/server/providers/mlx-omni.ts b/packages/server/providers/mlx-omni.ts new file mode 100644 index 0000000..f3abdc8 --- /dev/null +++ b/packages/server/providers/mlx-omni.ts @@ -0,0 +1,73 @@ +import { OpenAI } from "openai"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts"; + +export class MlxOmniChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + baseURL: param.env.MLX_API_ENDPOINT ?? 
"http://localhost:10240", + apiKey: param.env.MLX_API_KEY, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + const tuningParams = { + temperature: 0.75, + }; + + const getTuningParams = () => { + return tuningParams; + }; + + return { + model: param.model, + messages: safeMessages, + stream: true, + ...getTuningParams(), + }; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; + } + + dataCallback({ type: "chat", data: chunk }); + return false; + } +} + +export class MlxOmniChatSdk { + private static provider = new MlxOmniChatProvider(); + + static async handleMlxOmniStream( + ctx: { + openai: OpenAI; + systemPrompt: any; + preprocessedContext: any; + maxTokens: unknown | number | undefined; + messages: any; + disableWebhookGeneration: boolean; + model: any; + env: Env; + }, + dataCallback: (data: any) => any, + ) { + if (!ctx.messages?.length) { + return new Response("No messages provided", { status: 400 }); + } + + return this.provider.handleStream( + { + systemPrompt: ctx.systemPrompt, + preprocessedContext: ctx.preprocessedContext, + maxTokens: ctx.maxTokens, + messages: ctx.messages, + model: ctx.model, + env: ctx.env, + disableWebhookGeneration: ctx.disableWebhookGeneration, + }, + dataCallback, + ); + } +} diff --git a/packages/server/providers/ollama.ts b/packages/server/providers/ollama.ts new file mode 100644 index 0000000..061c1b8 --- /dev/null +++ b/packages/server/providers/ollama.ts @@ -0,0 +1,73 @@ +import { OpenAI } from "openai"; +import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts"; + +export class OllamaChatProvider extends BaseChatProvider { + getOpenAIClient(param: CommonProviderParams): OpenAI { + return new OpenAI({ + baseURL: param.env.OLLAMA_API_ENDPOINT ?? 
, + apiKey: param.env.OLLAMA_API_KEY, + }); + } + + getStreamParams(param: CommonProviderParams, safeMessages: any[]): any { + const tuningParams = { + temperature: 0.75, + }; + + const getTuningParams = () => { + return tuningParams; + }; + + return { + model: param.model, + messages: safeMessages, + stream: true, + ...getTuningParams(), + }; + } + + async processChunk(chunk: any, dataCallback: (data: any) => void): Promise { + if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") { + dataCallback({ type: "chat", data: chunk }); + return true; + } + + dataCallback({ type: "chat", data: chunk }); + return false; + } +} + +export class OllamaChatSdk { + private static provider = new OllamaChatProvider(); + + static async handleOllamaStream( + ctx: { + openai: OpenAI; + systemPrompt: any; + preprocessedContext: any; + maxTokens: unknown | number | undefined; + messages: any; + disableWebhookGeneration: boolean; + model: any; + env: Env; + }, + dataCallback: (data: any) => any, + ) { + if (!ctx.messages?.length) { + return new Response("No messages provided", { status: 400 }); + } + + return this.provider.handleStream( + { + systemPrompt: ctx.systemPrompt, + preprocessedContext: ctx.preprocessedContext, + maxTokens: ctx.maxTokens, + messages: ctx.messages, + model: ctx.model, + env: ctx.env, + disableWebhookGeneration: ctx.disableWebhookGeneration, + }, + dataCallback, + ); + } +} diff --git a/packages/server/services/ChatService.ts b/packages/server/services/ChatService.ts index a9eed4c..9c0221c 100644 --- a/packages/server/services/ChatService.ts +++ b/packages/server/services/ChatService.ts @@ -3,7 +3,6 @@ import OpenAI from 'openai'; import ChatSdk from '../lib/chat-sdk.ts'; import Message from "../models/Message.ts"; import O1Message from "../models/O1Message.ts"; -import {getModelFamily, ModelFamily, SUPPORTED_MODELS} from "@open-gsio/ai/supported-models"; import {OpenAiChatSdk} from "../providers/openai.ts"; import {GroqChatSdk} from "../providers/groq.ts"; import {ClaudeChatSdk} from "../providers/claude.ts"; @@ -13,6 +12,9 @@ import {GoogleChatSdk} from "../providers/google.ts"; import {XaiChatSdk} from "../providers/xai.ts"; import {CerebrasSdk} from "../providers/cerebras.ts"; import {CloudflareAISdk} from "../providers/cloudflareAi.ts"; +import {OllamaChatSdk} from "../providers/ollama"; +import {MlxOmniChatSdk} from "../providers/mlx-omni"; +import {ProviderRepository} from "../providers/_ProviderRepository"; export interface StreamParams { env: Env; @@ -110,24 +112,97 @@ const ChatService = types cerebras: (params: StreamParams, dataHandler: Function) => CerebrasSdk.handleCerebrasStream(params, dataHandler), cloudflareAI: (params: StreamParams, dataHandler: Function) => - CloudflareAISdk.handleCloudflareAIStream(params, dataHandler) + CloudflareAISdk.handleCloudflareAIStream(params, dataHandler), + ollama: (params: StreamParams, dataHandler: Function) => + OllamaChatSdk.handleOllamaStream(params, dataHandler), + mlx: (params: StreamParams, dataHandler: Function) => + MlxOmniChatSdk.handleMlxOmniStream(params, dataHandler), }; return { - async getSupportedModels() { - const isLocal = self.env.OPENAI_API_ENDPOINT && self.env.OPENAI_API_ENDPOINT.includes("localhost"); - console.log({isLocal}) - if(isLocal) { - console.log("getting local models") - const openaiClient = new OpenAI({baseURL: self.env.OPENAI_API_ENDPOINT}) - const models = await openaiClient.models.list(); - return Response.json( - models.data - .filter(model => model.id.includes("mlx")) - .map(model => 
model.id)); + getSupportedModels: flow(function* (): + Generator, Response, unknown> { + + // ----- Helpers ---------------------------------------------------------- + const logger = console; + + // ----- 1. Try cached value --------------------------------------------- + try { + const cached = yield self.env.KV_STORAGE.get('supportedModels'); + if (cached) { + const parsed = JSON.parse(cached as string); + if (Array.isArray(parsed)) { + logger.info('Cache hit – returning supportedModels from KV'); + return new Response(JSON.stringify(parsed), { status: 200 }); + } + logger.warn('Cache entry malformed – refreshing'); + } + } catch (err) { + logger.error('Error reading/parsing supportedModels cache', err); } - return Response.json(SUPPORTED_MODELS); - }, + + // ----- 2. Build fresh list --------------------------------------------- + const providerRepo = new ProviderRepository(self.env); + const providers = providerRepo.getProviders(); + const providerModels = new Map(); + const modelMeta = new Map(); + + for (const provider of providers) { + if (!provider.key) continue; + + logger.info(`Fetching models for provider «${provider.name}»`); + + const openai = new OpenAI({ apiKey: provider.key, baseURL: provider.endpoint }); + + // 2‑a. List models + try { + const listResp = yield openai.models.list(); // <‑‑ async + const models = ('data' in listResp) ? listResp.data : listResp; + providerModels.set(provider.name, models); + + // 2‑b. Retrieve metadata + for (const mdl of models) { + try { + const meta = yield openai.models.retrieve(mdl.id); // <‑‑ async + modelMeta.set(mdl.id, { ...mdl, ...meta }); + } catch (err) { + // logger.error(`Metadata fetch failed for ${mdl.id}`, err); + modelMeta.set(mdl.id, {provider: provider.name, mdl}); + } + } + } catch (err) { + logger.error(`Model list failed for provider «${provider.name}»`, err); + } + } + + // ----- 3. Merge results ------------------------------------------------- + const resultMap = new Map(); + for (const [provName, models] of providerModels) { + for (const mdl of models) { + resultMap.set(mdl.id, { + id: mdl.id, + provider: provName, + ...(modelMeta.get(mdl.id) ?? mdl), + }); + } + } + const resultArr = Array.from(resultMap.values()); + + // ----- 4. Cache fresh list --------------------------------------------- + try { + yield self.env.KV_STORAGE.put( + 'supportedModels', + JSON.stringify(resultArr), + { expirationTtl: 60 * 60 * 24 }, // 24 h + ); + logger.info('supportedModels cache refreshed'); + } catch (err) { + logger.error('KV put failed for supportedModels', err); + } + + // ----- 5. Return -------------------------------------------------------- + return new Response(JSON.stringify(resultArr), { status: 200 }); + }), setActiveStream(streamId: string, stream: any) { const validStream = { name: stream?.name || "Unnamed Stream", @@ -179,13 +254,13 @@ const ChatService = types const {streamConfig, streamParams, controller, encoder, streamId} = params; const useModelFamily = () => { - return !self.env.OPENAI_API_ENDPOINT || !self.env.OPENAI_API_ENDPOINT.includes("localhost") ? getModelFamily(streamConfig.model) : "openai"; + return ProviderRepository.getModelFamily(streamConfig.model, self.env) } - const modelFamily = useModelFamily(); + const modelFamily = await useModelFamily(); const useModelHandler = () => { - return !self.env.OPENAI_API_ENDPOINT || !self.env.OPENAI_API_ENDPOINT.includes("localhost") ? 
modelHandlers[modelFamily as ModelFamily] : modelHandlers.openai; + return modelHandlers[modelFamily] } const handler = useModelHandler();
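
Below is a minimal sketch of how the pieces introduced in this patch are intended to compose at request time. ProviderRepository, getProviders, getModelFamily, getModelMeta, and the KV_STORAGE binding are taken from the diff above; the env var values, model id, metadata, and import path are placeholders for illustration only, not part of the change.

// sketch.ts - illustrative only; assumes the patched ProviderRepository above.
import { ProviderRepository } from "./packages/server/providers/_ProviderRepository";

// Typed as `any` to stand in for the Workers Env binding used in the patch.
const env: any = {
  GROQ_API_KEY: "placeholder",
  GEMINI_API_KEY: "placeholder",
  // Minimal in-memory stand-in for the KV binding the repository reads from.
  KV_STORAGE: {
    store: new Map<string, string>(),
    async get(key: string) { return this.store.get(key) ?? null; },
    async put(key: string, value: string) { this.store.set(key, value); },
  },
};

async function main() {
  // 1. Provider discovery: every *_KEY env var maps to a provider name and endpoint.
  const repo = new ProviderRepository(env);
  console.log(repo.getProviders());
  // -> [{ name: "groq", ... }, { name: "google", ... }]

  // 2. Simulate the model list that getSupportedModels caches in KV for 24 h.
  await env.KV_STORAGE.put("supportedModels", JSON.stringify([
    { id: "llama-3.3-70b", provider: "groq", context_window: 128000 },
  ]));

  // 3. Routing and token limits are then resolved from that cached list.
  console.log(await ProviderRepository.getModelFamily("llama-3.3-70b", env)); // "groq"
  console.log(await ProviderRepository.getModelMeta("llama-3.3-70b", env));   // { id, provider, context_window }
  // ServerCoordinator.dynamicMaxTokens would resolve 128000 here via context_window.
}

main();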