diff --git a/README.md b/README.md
index 0e7ea5f..56e25ad 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,22 @@
 
 > Note: Subsequent deployments should omit `bun run deploy:secrets`
 
+## Local Inference (Apple Silicon Only)
+~~~bash
+# install mlx-omni-server (custom homebrew wrapper)
+brew tap seemueller-io/tap
+brew install seemueller-io/tap/mlx-omni-server
+
+# Run mlx-omni-server
+bun run openai:local
+
+# Override OPENAI_* variables in .dev.vars
+sed -i '' '/^OPENAI_API_KEY=/d' .dev.vars; echo 'OPENAI_API_KEY=not-needed' >> .dev.vars
+sed -i '' '/^OPENAI_API_ENDPOINT=/d' .dev.vars; echo 'OPENAI_API_ENDPOINT=http://localhost:10240' >> .dev.vars
+
+# Restart the open-gsio server so it picks up the new variables
+bun run server:dev
+~~~
 
 History
 ---
diff --git a/bun.lock b/bun.lock
index 230a0f5..31ca23c 100644
--- a/bun.lock
+++ b/bun.lock
@@ -25,7 +25,7 @@
         "mobx-react-lite": "^4.0.7",
         "mobx-state-tree": "^6.0.1",
         "moo": "^0.5.2",
-        "openai": "^4.76.0",
+        "openai": "^5.0.1",
         "qrcode.react": "^4.1.0",
         "react": "^18.3.1",
         "react-dom": "^18.3.1",
@@ -676,7 +676,7 @@
 
     "oniguruma-to-es": ["oniguruma-to-es@2.3.0", "", { "dependencies": { "emoji-regex-xs": "^1.0.0", "regex": "^5.1.1", "regex-recursion": "^5.1.1" } }, "sha512-bwALDxriqfKGfUufKGGepCzu9x7nJQuoRoAFp4AnwehhC2crqrDIAP/uN2qdlsAvSMpeRC3+Yzhqc7hLmle5+g=="],
 
-    "openai": ["openai@4.91.1", "", { "dependencies": { "@types/node": "^18.11.18", "@types/node-fetch": "^2.6.4", "abort-controller": "^3.0.0", "agentkeepalive": "^4.2.1", "form-data-encoder": "1.7.2", "formdata-node": "^4.3.2", "node-fetch": "^2.6.7" }, "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-DbjrR0hIMQFbxz8+3qBsfPJnh3+I/skPgoSlT7f9eiZuhGBUissPQULNgx6gHNkLoZ3uS0uYS6eXPUdtg4nHzw=="],
+    "openai": ["openai@5.0.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-Do6vxhbDv7cXhji/4ct1lrpZYMAOmjYbhyA9LJTuG7OfpbWMpuS+EIXkRT7R+XxpRB1OZhU/op4FU3p3uxU6gw=="],
 
     "parent-module": ["parent-module@1.0.1", "", { "dependencies": { "callsites": "^3.0.0" } }, "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g=="],
 
@@ -872,7 +872,7 @@
 
     "wrangler": ["wrangler@4.16.1", "", { "dependencies": { "@cloudflare/kv-asset-handler": "0.4.0", "@cloudflare/unenv-preset": "2.3.2", "blake3-wasm": "2.1.5", "esbuild": "0.25.4", "miniflare": "4.20250508.3", "path-to-regexp": "6.3.0", "unenv": "2.0.0-rc.17", "workerd": "1.20250508.0" }, "optionalDependencies": { "fsevents": "~2.3.2", "sharp": "^0.33.5" }, "peerDependencies": { "@cloudflare/workers-types": "^4.20250508.0" }, "optionalPeers": ["@cloudflare/workers-types"], "bin": { "wrangler": "bin/wrangler.js", "wrangler2": "bin/wrangler.js" } }, "sha512-YiLdWXcaQva2K/bqokpsZbySPmoT8TJFyJPsQPZumnkgimM9+/g/yoXArByA+pf+xU8jhw7ybQ8X1yBGXv731g=="],
 
-    "ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="],
+    "ws": ["ws@8.18.1", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-RKW2aJZMXeMxVpnZ6bck+RswznaxmzdULiBr6KY7XkTnW8uvt0iT9H5DkHUChXrc+uurzwa0rVI16n/Xzjdz1w=="],
 
     "xml-name-validator": ["xml-name-validator@5.0.0", "", {}, "sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg=="],
 
@@ -908,10 +908,10 @@
 
     "jsdom/whatwg-url": ["whatwg-url@14.2.0", "", { "dependencies": { "tr46": "^5.1.0", "webidl-conversions": "^7.0.0" } }, "sha512-De72GdQZzNTUBBChsXueQUnPKDkg/5A5zp7pFDuQAj5UFoENpiACU0wlCvzpAGnTkj++ihpKwKyYewn/XNUbKw=="],
 
-    "jsdom/ws": ["ws@8.18.1", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-RKW2aJZMXeMxVpnZ6bck+RswznaxmzdULiBr6KY7XkTnW8uvt0iT9H5DkHUChXrc+uurzwa0rVI16n/Xzjdz1w=="],
-
-    "miniflare/acorn": ["acorn@8.14.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA=="],
+    "miniflare/ws": ["ws@8.18.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw=="],
+
+    "miniflare/zod": ["zod@3.22.3", "", {}, "sha512-EjIevzuJRiRPbVH4mGc8nApb/lVLKVpmUhAaR5R5doKGfAnGJ6Gr3CViAVjP+4FWSxCsybeWQdcgCtbX+7oZug=="],
 
     "source-map-support/source-map": ["source-map@0.6.1", "", {}, "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g=="],
diff --git a/package.json b/package.json
index f23f530..43455b9 100644
--- a/package.json
+++ b/package.json
@@ -17,7 +17,8 @@
     "tail": "wrangler tail",
     "tail:email-service": "wrangler tail -c workers/email/wrangler-email.toml",
     "tail:analytics-service": "wrangler tail -c workers/analytics/wrangler-analytics.toml",
-    "tail:session-proxy": "wrangler tail -c workers/session-proxy/wrangler-session-proxy.toml --env production"
+    "tail:session-proxy": "wrangler tail -c workers/session-proxy/wrangler-session-proxy.toml --env production",
+    "openai:local": "./scripts/start_inference_server.sh"
   },
   "devDependencies": {
     "@anthropic-ai/sdk": "^0.32.1",
@@ -42,7 +43,7 @@
     "mobx-react-lite": "^4.0.7",
     "mobx-state-tree": "^6.0.1",
     "moo": "^0.5.2",
-    "openai": "^4.76.0",
+    "openai": "^5.0.1",
     "qrcode.react": "^4.1.0",
     "react": "^18.3.1",
     "react-dom": "^18.3.1",
diff --git a/scripts/killport.js b/scripts/killport.js
new file mode 100644
index 0000000..9f7a859
--- /dev/null
+++ b/scripts/killport.js
@@ -0,0 +1,51 @@
+#!/usr/bin/env node
+
+import * as child_process from "node:child_process";
+
+const args = process.argv.slice(2);
+const port = args.length > 0 ? parseInt(args[0], 10) : null;
+
+if (!port || isNaN(port)) {
+  console.error('Please provide a valid port number');
+  process.exit(1);
+}
+
+
+export const killProcessOnPort = (port) => {
+  return new Promise((resolve, reject) => {
+    // Find the PID of the process using the specified port
+    child_process.exec(`lsof -t -i :${port}`.trim(), (err, stdout) => {
+      if (err) {
+        // Handle command error (such as permission denied)
+        if (err.code !== 1) {
+          console.error(`Error finding process on port ${port}:`, err);
+          return reject(err);
+        } else {
+          // If code is 1, it generally means no process is using the port
+          console.log(`No process found on port ${port}`);
+          return resolve();
+        }
+      }
+
+      // If stdout is empty, no process is using the port
+      const pid = stdout.trim();
+      if (!pid) {
+        console.log(`No process is currently running on port ${port}`);
+        return resolve();
+      }
+
+      // Kill the process using the specified PID
+      child_process.exec(`kill -9 ${pid}`.trim(), (killErr) => {
+        if (killErr) {
+          console.error(`Failed to kill process ${pid} on port ${port}`, killErr);
+          return reject(killErr);
+        }
+
+        console.log(`Successfully killed process ${pid} on port ${port}`);
+        resolve();
+      });
+    });
+  });
+};
+
+await killProcessOnPort(port);
diff --git a/scripts/start_inference_server.sh b/scripts/start_inference_server.sh
new file mode 100755
index 0000000..111e57c
--- /dev/null
+++ b/scripts/start_inference_server.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+SERVER_TYPE="mlx-omni-server"
+
+printf "Starting Inference Server: %s\n" "${SERVER_TYPE}"
+
+
+mlx-omni-server
\ No newline at end of file
diff --git a/src/components/chat/input-menu/InputMenu.tsx b/src/components/chat/input-menu/InputMenu.tsx
index 7f473e9..c10d7f5 100644
--- a/src/components/chat/input-menu/InputMenu.tsx
+++ b/src/components/chat/input-menu/InputMenu.tsx
@@ -53,7 +53,17 @@ const InputMenu: React.FC<{ isDisabled?: boolean }> = observer(
       setControlledOpen(isOpen);
     }, [isOpen]);
 
-    const textModels = SUPPORTED_MODELS;
+
+    const getSupportedModels = async () => {
+      return await (await fetch("/api/models")).json();
+    };
+
+    useEffect(() => {
+      getSupportedModels().then((supportedModels) => {
+        ClientChatStore.setSupportedModels(supportedModels);
+      });
+    }, []);
+
     const handleClose = useCallback(() => {
       onClose();
@@ -75,9 +85,7 @@ const InputMenu: React.FC<{ isDisabled?: boolean }> = observer(
     }, [onClose]);
 
     async function selectModelFn({ name, value }) {
-      if (getModelFamily(value)) {
         ClientChatStore.setModel(value);
-      }
     }
 
     function isSelectedModelFn({ name, value }) {
@@ -144,7 +152,7 @@ const InputMenu: React.FC<{ isDisabled?: boolean }> = observer(
             >
-              flyoutMenuOptions={textModels.map((m) => ({ name: m, value: m }))}
+              flyoutMenuOptions={ClientChatStore.supportedModels.map((m) => ({ name: m, value: m }))}
              onClose={onClose}
              parentIsOpen={isOpen}
              setMenuState={setMenuState}
diff --git a/src/stores/ClientChatStore.ts b/src/stores/ClientChatStore.ts
index 1f9617c..d6bd156 100644
--- a/src/stores/ClientChatStore.ts
+++ b/src/stores/ClientChatStore.ts
@@ -9,6 +9,7 @@ const ClientChatStore = types
     isLoading: types.optional(types.boolean, false),
     model: types.optional(types.string, "meta-llama/llama-4-scout-17b-16e-instruct"),
     imageModel: types.optional(types.string, "black-forest-labs/flux-1.1-pro"),
+    supportedModels: types.optional(types.array(types.string), [])
   })
   .actions((self) => ({
     cleanup() {
@@ -17,6 +18,12 @@ const ClientChatStore = types
         self.eventSource = null;
       }
     },
+    setSupportedModels(modelsList: string[]) {
+      self.supportedModels = modelsList;
+      if (!modelsList.includes(self.model)) {
+        self.model = modelsList.pop();
+      }
+    },
     sendMessage: flow(function* () {
       if (!self.input.trim() || self.isLoading) return;
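Reviewer note, not part of the patch: the `/api/models` route added below returns a plain JSON array of model id strings, which is exactly the shape `setSupportedModels` consumes. A small sketch of that client-side contract; the example model id is illustrative only.

```typescript
// Mirrors the fetch InputMenu performs on mount. Against mlx-omni-server the endpoint
// returns something like ["mlx-community/Llama-3.2-3B-Instruct-4bit"] (illustrative id).
async function getSupportedModels(): Promise<string[]> {
  const res = await fetch("/api/models");
  return res.json();
}

// The array is handed to ClientChatStore.setSupportedModels, which keeps the current
// model when it is still listed and otherwise falls back to the list's last entry,
// so the picker never offers a model the backend cannot serve.
```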
diff --git a/workers/site/api-router.ts b/workers/site/api-router.ts
index 3db389f..0133fa3 100644
--- a/workers/site/api-router.ts
+++ b/workers/site/api-router.ts
@@ -28,6 +28,13 @@ export function createRouter() {
       },
     )
 
+    .get("/api/models",
+      async (req, env, ctx) => {
+        const { chatService } = createRequestContext(env, ctx);
+        return chatService.getSupportedModels();
+      },
+    )
+
     .post("/api/feedback", async (r, e, c) => {
       const { feedbackService } = createRequestContext(e, c);
       return feedbackService.handleFeedback(r);
diff --git a/workers/site/lib/utils.ts b/workers/site/lib/utils.ts
index 1b03448..379fcfc 100644
--- a/workers/site/lib/utils.ts
+++ b/workers/site/lib/utils.ts
@@ -59,4 +59,38 @@ export class Utils {
     return result;
   }
 
+
+  static normalizeWithBlanks<T extends Normalize.ChatMessage>(msgs: T[]): T[] {
+    const out: T[] = [];
+
+    // In local mode the first turn is expected to come from the user.
+    let expected: Normalize.Role = "user";
+
+    for (const m of msgs) {
+      while (m.role !== expected) {
+        // Insert blanks to match the expected user/assistant/user... sequence
+        out.push(Normalize.makeBlank(expected) as T);
+        expected = expected === "user" ? "assistant" : "user";
+      }
+
+      out.push(m);
+      expected = expected === "user" ? "assistant" : "user";
+    }
+
+    return out;
+  }
+
+}
+
+module Normalize {
+  export type Role = "user" | "assistant";
+
+  export interface ChatMessage extends Record<string, unknown> {
+    role: Role;
+  }
+
+  export const makeBlank = (role: Role): ChatMessage => ({
+    role,
+    content: ""
+  });
 }
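Reviewer note, not part of the patch: `Utils.normalizeWithBlanks` exists because some local chat templates only accept strictly alternating user/assistant turns that start with a user turn. A standalone sketch of the same padding rule, using shapes that mirror the `Normalize` helpers above; the example transcript is made up.

```typescript
type Role = "user" | "assistant";
interface ChatMessage { role: Role; content: string }

const makeBlank = (role: Role): ChatMessage => ({ role, content: "" });

// Whenever the next message does not have the expected role, insert an empty turn of
// that role and keep alternating; this is the rule Utils.normalizeWithBlanks applies.
function normalizeWithBlanks(msgs: ChatMessage[]): ChatMessage[] {
  const out: ChatMessage[] = [];
  let expected: Role = "user";
  for (const m of msgs) {
    while (m.role !== expected) {
      out.push(makeBlank(expected));
      expected = expected === "user" ? "assistant" : "user";
    }
    out.push(m);
    expected = expected === "user" ? "assistant" : "user";
  }
  return out;
}

// A transcript that opens with an assistant greeting gains a leading blank user turn:
console.log(normalizeWithBlanks([
  { role: "assistant", content: "Hi, how can I help?" },
  { role: "user", content: "List the local models." },
]));
// -> [{ role: "user", content: "" }, { role: "assistant", ... }, { role: "user", ... }]
```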
"../../../src/components/chat/lib/SupportedModels"; +import {getModelFamily, ModelFamily, SUPPORTED_MODELS} from "../../../src/components/chat/lib/SupportedModels"; import {OpenAiChatSdk} from "../providers/openai"; import {GroqChatSdk} from "../providers/groq"; import {ClaudeChatSdk} from "../providers/claude"; @@ -73,11 +73,21 @@ const ChatService = types throw new Error('Unsupported message format'); }; + const getSupportedModels = async () => { + if(self.env.OPENAI_API_ENDPOINT.includes("localhost")) { + const openaiClient = new OpenAI({baseURL: self.env.OPENAI_API_ENDPOINT}) + const models = await openaiClient.models.list(); + return Response.json(models.data.map(model => model.id)); + } + return Response.json(SUPPORTED_MODELS); + }; + const createStreamParams = async ( streamConfig: any, dynamicContext: any, durableObject: any ): Promise => { + return { env: self.env, openai: self.openai, @@ -112,6 +122,7 @@ const ChatService = types }; return { + getSupportedModels, setActiveStream(streamId: string, stream: any) { const validStream = { name: stream?.name || "Unnamed Stream", @@ -129,10 +140,18 @@ const ChatService = types }, setEnv(env: Env) { self.env = env; - self.openai = new OpenAI({ - apiKey: self.openAIApiKey, - baseURL: self.openAIBaseURL, - }); + + if(env.OPENAI_API_ENDPOINT.includes("localhost")) { + self.openai = new OpenAI({ + apiKey: self.env.OPENAI_API_KEY, + baseURL: self.env.OPENAI_API_ENDPOINT, + }); + } else{ + self.openai = new OpenAI({ + apiKey: self.openAIApiKey, + baseURL: self.openAIBaseURL, + }); + } }, handleChatRequest: async (request: Request) => { @@ -154,12 +173,12 @@ const ChatService = types }) { const {streamConfig, streamParams, controller, encoder, streamId} = params; - const modelFamily = getModelFamily(streamConfig.model); + const modelFamily = !self.env.OPENAI_API_ENDPOINT.includes("localhost") ? getModelFamily(streamConfig.model) : "openai"; + + const handler = !self.env.OPENAI_API_ENDPOINT.includes("localhost") ? modelHandlers[modelFamily as ModelFamily] : modelHandlers.openai; - const handler = modelHandlers[modelFamily as ModelFamily]; if (handler) { try { - await handler(streamParams, handleStreamData(controller, encoder)); } catch (error) {