From 1e02b12cda8f415e3939ee8ee638f55cd9c1e45c Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Thu, 4 Sep 2025 13:42:30 -0400 Subject: [PATCH] fixes issue with model selection --- README.md | 2 +- bun.lock | 4 +- crates/chat-ui/src/app.rs | 3 +- crates/inference-engine/src/server.rs | 106 ++++++++++++++--- integration/gemma-runner/src/gemma_api.rs | 136 +++++++++++++++------- integration/llama-runner/src/llama_api.rs | 21 ++++ 6 files changed, 209 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 8d85b2d..0a0c3f0 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features) > This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks. > By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems. -Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name. +Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name. A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces. diff --git a/bun.lock b/bun.lock index 365afba..cccf7ca 100644 --- a/bun.lock +++ b/bun.lock @@ -4,7 +4,7 @@ "": { "name": "predict-otron-9000", }, - "crates/cli/package": { + "integration/cli/package": { "name": "cli", "dependencies": { "install": "^0.13.0", @@ -13,7 +13,7 @@ }, }, "packages": { - "cli": ["cli@workspace:crates/cli/package"], + "cli": ["cli@workspace:integration/cli/package"], "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="], diff --git a/crates/chat-ui/src/app.rs b/crates/chat-ui/src/app.rs index 3d45dea..0b1ad11 100644 --- a/crates/chat-ui/src/app.rs +++ b/crates/chat-ui/src/app.rs @@ -365,7 +365,7 @@ fn ChatPage() -> impl IntoView { // State for available models and selected model let available_models = RwSignal::new(Vec::::new()); - let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model + let selected_model = RwSignal::new(String::from("")); // Default model // State for streaming response let streaming_content = RwSignal::new(String::new()); @@ -382,6 +382,7 @@ fn ChatPage() -> impl IntoView { match fetch_models().await { Ok(models) => { available_models.set(models); + selected_model.set(String::from("gemma-3-1b-it")); } Err(error) => { console::log_1(&format!("Failed to fetch models: {}", error).into()); diff --git a/crates/inference-engine/src/server.rs b/crates/inference-engine/src/server.rs index 613a14e..bd2e91e 100644 --- a/crates/inference-engine/src/server.rs +++ b/crates/inference-engine/src/server.rs @@ -7,6 +7,7 @@ use axum::{ }; use futures_util::stream::{self, Stream}; use std::convert::Infallible; +use std::str::FromStr; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -20,7 +21,7 @@ use crate::openai_types::{ use crate::Which; use either::Either; use 
embeddings_engine::models_list; -use gemma_runner::{run_gemma_api, GemmaInferenceConfig}; +use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel}; use llama_runner::{run_llama_inference, LlamaInferenceConfig}; use serde_json::Value; // ------------------------- @@ -35,12 +36,13 @@ pub enum ModelType { #[derive(Clone)] pub struct AppState { - pub model_type: ModelType, + pub model_type: Option, pub model_id: String, pub gemma_config: Option, pub llama_config: Option, } + impl Default for AppState { fn default() -> Self { // Configure a default model to prevent 503 errors from the chat-ui @@ -48,12 +50,12 @@ impl Default for AppState { let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string()); let gemma_config = GemmaInferenceConfig { - model: gemma_runner::WhichModel::InstructV3_1B, + model: None, ..Default::default() }; Self { - model_type: ModelType::Gemma, + model_type: None, model_id: default_model_id, gemma_config: Some(gemma_config), llama_config: None, @@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option { "gemma-2-9b-it" => Some(Which::InstructV2_9B), "gemma-3-1b" => Some(Which::BaseV3_1B), "gemma-3-1b-it" => Some(Which::InstructV3_1B), + "llama-3.2-1b" => Some(Which::Llama32_1B), "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct), + "llama-3.2-3b" => Some(Which::Llama32_3B), "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct), _ => None, } @@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy( // Get streaming receiver based on model type let rx = if which_model.is_llama_model() { // Create Llama configuration dynamically - let mut config = LlamaInferenceConfig::default(); + let llama_model = match which_model { + Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B, + Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct, + Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B, + Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct, + _ => { + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": { "message": format!("Model {} is not a Llama model", model_id) } + })) + )); + } + }; + let mut config = LlamaInferenceConfig::new(llama_model); config.prompt = prompt.clone(); config.max_tokens = max_tokens; run_llama_inference(config).map_err(|e| ( @@ -201,14 +219,35 @@ pub async fn chat_completions_non_streaming_proxy( ))? 
} else { // Create Gemma configuration dynamically - let gemma_model = if which_model.is_v3_model() { - gemma_runner::WhichModel::InstructV3_1B - } else { - gemma_runner::WhichModel::InstructV3_1B // Default fallback + let gemma_model = match which_model { + Which::Base2B => gemma_runner::WhichModel::Base2B, + Which::Base7B => gemma_runner::WhichModel::Base7B, + Which::Instruct2B => gemma_runner::WhichModel::Instruct2B, + Which::Instruct7B => gemma_runner::WhichModel::Instruct7B, + Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B, + Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B, + Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B, + Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B, + Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B, + Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B, + Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B, + Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B, + Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B, + Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B, + Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B, + Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B, + _ => { + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": { "message": format!("Model {} is not a Gemma model", model_id) } + })) + )); + } }; let mut config = GemmaInferenceConfig { - model: gemma_model, + model: Some(gemma_model), ..Default::default() }; config.prompt = prompt.clone(); @@ -348,7 +387,21 @@ async fn handle_streaming_request( // Get streaming receiver based on model type let model_rx = if which_model.is_llama_model() { // Create Llama configuration dynamically - let mut config = LlamaInferenceConfig::default(); + let llama_model = match which_model { + Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B, + Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct, + Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B, + Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct, + _ => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Model {} is not a Llama model", model_id) } + })) + )); + } + }; + let mut config = LlamaInferenceConfig::new(llama_model); config.prompt = prompt.clone(); config.max_tokens = max_tokens; match run_llama_inference(config) { @@ -364,14 +417,35 @@ async fn handle_streaming_request( } } else { // Create Gemma configuration dynamically - let gemma_model = if which_model.is_v3_model() { - gemma_runner::WhichModel::InstructV3_1B - } else { - gemma_runner::WhichModel::InstructV3_1B // Default fallback + let gemma_model = match which_model { + Which::Base2B => gemma_runner::WhichModel::Base2B, + Which::Base7B => gemma_runner::WhichModel::Base7B, + Which::Instruct2B => gemma_runner::WhichModel::Instruct2B, + Which::Instruct7B => gemma_runner::WhichModel::Instruct7B, + Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B, + Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B, + Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B, + Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B, + Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B, + Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B, + Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B, + Which::InstructV2_2B => 
gemma_runner::WhichModel::InstructV2_2B, + Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B, + Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B, + Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B, + Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B, + _ => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Model {} is not a Gemma model", model_id) } + })) + )); + } }; let mut config = GemmaInferenceConfig { - model: gemma_model, + model: Some(gemma_model), ..Default::default() }; config.prompt = prompt.clone(); diff --git a/integration/gemma-runner/src/gemma_api.rs b/integration/gemma-runner/src/gemma_api.rs index 5f91ee5..4dfd7d9 100644 --- a/integration/gemma-runner/src/gemma_api.rs +++ b/integration/gemma-runner/src/gemma_api.rs @@ -1,13 +1,8 @@ -#[cfg(feature = "accelerate")] -extern crate accelerate_src; -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; -use clap::ValueEnum; // Removed gemma_cli import as it's not needed for the API use candle_core::{DType, Device, Tensor}; @@ -21,8 +16,10 @@ use std::thread; use tokenizers::Tokenizer; use utils::hub_load_safetensors; use utils::token_output_stream::TokenOutputStream; +use std::str::FromStr; +use std::fmt; -#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] pub enum WhichModel { #[value(name = "gemma-2b")] Base2B, @@ -58,6 +55,56 @@ pub enum WhichModel { InstructV3_1B, } +impl FromStr for WhichModel { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "gemma-2b" => Ok(Self::Base2B), + "gemma-7b" => Ok(Self::Base7B), + "gemma-2b-it" => Ok(Self::Instruct2B), + "gemma-7b-it" => Ok(Self::Instruct7B), + "gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B), + "gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B), + "codegemma-2b" => Ok(Self::CodeBase2B), + "codegemma-7b" => Ok(Self::CodeBase7B), + "codegemma-2b-it" => Ok(Self::CodeInstruct2B), + "codegemma-7b-it" => Ok(Self::CodeInstruct7B), + "gemma-2-2b" => Ok(Self::BaseV2_2B), + "gemma-2-2b-it" => Ok(Self::InstructV2_2B), + "gemma-2-9b" => Ok(Self::BaseV2_9B), + "gemma-2-9b-it" => Ok(Self::InstructV2_9B), + "gemma-3-1b" => Ok(Self::BaseV3_1B), + "gemma-3-1b-it" => Ok(Self::InstructV3_1B), + _ => Err(format!("Unknown model: {}", s)), + } + } +} + +impl fmt::Display for WhichModel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + Self::Base2B => "gemma-2b", + Self::Base7B => "gemma-7b", + Self::Instruct2B => "gemma-2b-it", + Self::Instruct7B => "gemma-7b-it", + Self::InstructV1_1_2B => "gemma-1.1-2b-it", + Self::InstructV1_1_7B => "gemma-1.1-7b-it", + Self::CodeBase2B => "codegemma-2b", + Self::CodeBase7B => "codegemma-7b", + Self::CodeInstruct2B => "codegemma-2b-it", + Self::CodeInstruct7B => "codegemma-7b-it", + Self::BaseV2_2B => "gemma-2-2b", + Self::InstructV2_2B => "gemma-2-2b-it", + Self::BaseV2_9B => "gemma-2-9b", + Self::InstructV2_9B => "gemma-2-9b-it", + Self::BaseV3_1B => "gemma-3-1b", + Self::InstructV3_1B => "gemma-3-1b-it", + }; + write!(f, "{}", name) + } +} + enum Model { V1(Model1), V2(Model2), @@ -145,7 +192,7 @@ impl TextGeneration { // Make sure stdout isn't holding anything (if 
caller also prints). std::io::stdout().flush()?; - let mut generated_tokens = 0usize; + let mut _generated_tokens = 0usize; let eos_token = match self.tokenizer.get_token("") { Some(token) => token, @@ -183,7 +230,7 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); - generated_tokens += 1; + _generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; @@ -210,7 +257,7 @@ impl TextGeneration { pub struct GemmaInferenceConfig { pub tracing: bool, pub prompt: String, - pub model: WhichModel, + pub model: Option, pub cpu: bool, pub dtype: Option, pub model_id: Option, @@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig { Self { tracing: false, prompt: "Hello".to_string(), - model: WhichModel::InstructV2_2B, + model: Some(WhichModel::InstructV2_2B), cpu: false, dtype: None, model_id: None, @@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result "google/gemma-2b", - WhichModel::Base7B => "google/gemma-7b", - WhichModel::Instruct2B => "google/gemma-2b-it", - WhichModel::Instruct7B => "google/gemma-7b-it", - WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it", - WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it", - WhichModel::CodeBase2B => "google/codegemma-2b", - WhichModel::CodeBase7B => "google/codegemma-7b", - WhichModel::CodeInstruct2B => "google/codegemma-2b-it", - WhichModel::CodeInstruct7B => "google/codegemma-7b-it", - WhichModel::BaseV2_2B => "google/gemma-2-2b", - WhichModel::InstructV2_2B => "google/gemma-2-2b-it", - WhichModel::BaseV2_9B => "google/gemma-2-9b", - WhichModel::InstructV2_9B => "google/gemma-2-9b-it", - WhichModel::BaseV3_1B => "google/gemma-3-1b-pt", - WhichModel::InstructV3_1B => "google/gemma-3-1b-it", + Some(WhichModel::Base2B) => "google/gemma-2b", + Some(WhichModel::Base7B) => "google/gemma-7b", + Some(WhichModel::Instruct2B) => "google/gemma-2b-it", + Some(WhichModel::Instruct7B) => "google/gemma-7b-it", + Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it", + Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it", + Some(WhichModel::CodeBase2B) => "google/codegemma-2b", + Some(WhichModel::CodeBase7B) => "google/codegemma-7b", + Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it", + Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it", + Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b", + Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it", + Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b", + Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it", + Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt", + Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it", + None => "google/gemma-2-2b-it", // default fallback } .to_string() }); @@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result vec![repo.get("model.safetensors")?], + Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?], _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?, }; println!("Retrieved files in {:?}", start.elapsed()); @@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result { + Some(WhichModel::Base2B) + | Some(WhichModel::Base7B) + | Some(WhichModel::Instruct2B) + | Some(WhichModel::Instruct7B) + | Some(WhichModel::InstructV1_1_2B) + | Some(WhichModel::InstructV1_1_7B) + | Some(WhichModel::CodeBase2B) + | Some(WhichModel::CodeBase7B) + | Some(WhichModel::CodeInstruct2B) + | Some(WhichModel::CodeInstruct7B) => { let 
config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let model = Model1::new(cfg.use_flash_attn, &config, vb)?; Model::V1(model) } - WhichModel::BaseV2_2B - | WhichModel::InstructV2_2B - | WhichModel::BaseV2_9B - | WhichModel::InstructV2_9B => { + Some(WhichModel::BaseV2_2B) + | Some(WhichModel::InstructV2_2B) + | Some(WhichModel::BaseV2_9B) + | Some(WhichModel::InstructV2_9B) + | None => { // default to V2 model let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let model = Model2::new(cfg.use_flash_attn, &config, vb)?; Model::V2(model) } - WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => { + Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => { let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let model = Model3::new(cfg.use_flash_attn, &config, vb)?; Model::V3(model) @@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result { + Some(WhichModel::InstructV3_1B) => { format!( "user\n{}\nmodel\n", cfg.prompt diff --git a/integration/llama-runner/src/llama_api.rs b/integration/llama-runner/src/llama_api.rs index 59024c9..41aacd8 100644 --- a/integration/llama-runner/src/llama_api.rs +++ b/integration/llama-runner/src/llama_api.rs @@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig { pub repeat_last_n: usize, } +impl LlamaInferenceConfig { + pub fn new(model: WhichModel) -> Self { + Self { + prompt: String::new(), + model, + cpu: false, + temperature: 1.0, + top_p: None, + top_k: None, + seed: 42, + max_tokens: 512, + no_kv_cache: false, + dtype: None, + model_id: None, + revision: None, + use_flash_attn: true, + repeat_penalty: 1.1, + repeat_last_n: 64, + } + } +} impl Default for LlamaInferenceConfig { fn default() -> Self { Self {
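
The sketch below shows how the reworked model-selection API introduced by this patch might be exercised from a downstream caller. It is an illustration only, not part of the diff, and it assumes the exports referenced in server.rs (gemma_runner::{GemmaInferenceConfig, WhichModel}, llama_runner::{LlamaInferenceConfig, WhichModel}), the new FromStr/Display impls for the Gemma WhichModel, the now-optional `model` field on GemmaInferenceConfig, and the new LlamaInferenceConfig::new constructor; any field not shown is left to ..Default::default().

use std::str::FromStr;

use gemma_runner::{GemmaInferenceConfig, WhichModel};
use llama_runner::{LlamaInferenceConfig, WhichModel as LlamaWhichModel};

fn main() -> Result<(), String> {
    // Parse an OpenAI-style model id into the Gemma enum; unknown ids surface as Err(String).
    let gemma_model = WhichModel::from_str("gemma-3-1b-it")?;
    // Display round-trips back to the same model id string.
    assert_eq!(gemma_model.to_string(), "gemma-3-1b-it");

    // `model` is now Option<WhichModel>; passing None lets run_gemma_api fall back to its
    // default repo ("google/gemma-2-2b-it" per the match in gemma_api.rs).
    let gemma_config = GemmaInferenceConfig {
        prompt: "Explain borrowing in Rust.".to_string(),
        model: Some(gemma_model),
        ..Default::default()
    };

    // Llama configs are built through the new constructor, which requires an explicit model,
    // instead of relying on a Default impl.
    let mut llama_config = LlamaInferenceConfig::new(LlamaWhichModel::Llama32_1BInstruct);
    llama_config.prompt = "Explain borrowing in Rust.".to_string();
    llama_config.max_tokens = 256;

    // server.rs would now hand these configs to run_gemma_api / run_llama_inference and
    // stream tokens from the returned receivers; here we only demonstrate construction.
    let _ = (gemma_config, llama_config);
    Ok(())
}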