Fixes an issue with model selection

geoffsee
2025-09-04 13:42:30 -04:00
parent ff55d882c7
commit 1e02b12cda
6 changed files with 209 additions and 63 deletions

View File

@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
- Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+ Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

View File

@@ -4,7 +4,7 @@
"": {
"name": "predict-otron-9000",
},
"crates/cli/package": {
"integration/cli/package": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
@@ -13,7 +13,7 @@
},
},
"packages": {
"cli": ["cli@workspace:crates/cli/package"],
"cli": ["cli@workspace:integration/cli/package"],
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],

View File

@@ -365,7 +365,7 @@ fn ChatPage() -> impl IntoView {
// State for available models and selected model
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
let selected_model = RwSignal::new(String::from("")); // Default model
// State for streaming response
let streaming_content = RwSignal::new(String::new());
@@ -382,6 +382,7 @@ fn ChatPage() -> impl IntoView {
match fetch_models().await {
Ok(models) => {
available_models.set(models);
+ selected_model.set(String::from("gemma-3-1b-it"));
}
Err(error) => {
console::log_1(&format!("Failed to fetch models: {}", error).into());

View File

@@ -7,6 +7,7 @@ use axum::{
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
+ use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -20,7 +21,7 @@ use crate::openai_types::{
use crate::Which;
use either::Either;
use embeddings_engine::models_list;
- use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
+ use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// -------------------------
@@ -35,12 +36,13 @@ pub enum ModelType {
#[derive(Clone)]
pub struct AppState {
- pub model_type: ModelType,
+ pub model_type: Option<ModelType>,
pub model_id: String,
pub gemma_config: Option<GemmaInferenceConfig>,
pub llama_config: Option<LlamaInferenceConfig>,
}
impl Default for AppState {
fn default() -> Self {
// Configure a default model to prevent 503 errors from the chat-ui
@@ -48,12 +50,12 @@ impl Default for AppState {
let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
let gemma_config = GemmaInferenceConfig {
- model: gemma_runner::WhichModel::InstructV3_1B,
+ model: None,
..Default::default()
};
Self {
- model_type: ModelType::Gemma,
+ model_type: None,
model_id: default_model_id,
gemma_config: Some(gemma_config),
llama_config: None,
@@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
"gemma-2-9b-it" => Some(Which::InstructV2_9B),
"gemma-3-1b" => Some(Which::BaseV3_1B),
"gemma-3-1b-it" => Some(Which::InstructV3_1B),
"llama-3.2-1b" => Some(Which::Llama32_1B),
"llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
"llama-3.2-3b" => Some(Which::Llama32_3B),
"llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
_ => None,
}
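With the AppState default above no longer pinning a runner (model_type and the Gemma model are both None), each request's model field decides where it is dispatched. A minimal sketch of that resolution, assuming ModelType also has a Llama variant (not shown in this diff):

    // Sketch only: map a request's model id to a runner family.
    // model_id_to_which, Which and is_llama_model come from this file;
    // ModelType::Llama is an assumption.
    fn resolve_runner(model_id: &str) -> Option<ModelType> {
        let which = model_id_to_which(model_id)?;
        Some(if which.is_llama_model() {
            ModelType::Llama
        } else {
            ModelType::Gemma
        })
    }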
@@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy(
// Get streaming receiver based on model type
let rx = if which_model.is_llama_model() {
// Create Llama configuration dynamically
- let mut config = LlamaInferenceConfig::default();
+ let llama_model = match which_model {
+ Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+ Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+ Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+ Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+ _ => {
+ return Err((
+ StatusCode::BAD_REQUEST,
+ Json(serde_json::json!({
+ "error": { "message": format!("Model {} is not a Llama model", model_id) }
+ }))
+ ));
+ }
+ };
+ let mut config = LlamaInferenceConfig::new(llama_model);
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
@@ -201,14 +219,35 @@ pub async fn chat_completions_non_streaming_proxy(
))?
} else {
// Create Gemma configuration dynamically
- let gemma_model = if which_model.is_v3_model() {
- gemma_runner::WhichModel::InstructV3_1B
- } else {
- gemma_runner::WhichModel::InstructV3_1B // Default fallback
+ let gemma_model = match which_model {
+ Which::Base2B => gemma_runner::WhichModel::Base2B,
+ Which::Base7B => gemma_runner::WhichModel::Base7B,
+ Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+ Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+ Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+ Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+ Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+ Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+ Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+ Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+ Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+ Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+ Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+ Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+ Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+ Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+ _ => {
+ return Err((
+ StatusCode::BAD_REQUEST,
+ Json(serde_json::json!({
+ "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+ }))
+ ));
+ }
};
let mut config = GemmaInferenceConfig {
- model: gemma_model,
+ model: Some(gemma_model),
..Default::default()
};
config.prompt = prompt.clone();
@@ -348,7 +387,21 @@ async fn handle_streaming_request(
// Get streaming receiver based on model type
let model_rx = if which_model.is_llama_model() {
// Create Llama configuration dynamically
- let mut config = LlamaInferenceConfig::default();
+ let llama_model = match which_model {
+ Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+ Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+ Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+ Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+ _ => {
+ return Err((
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": { "message": format!("Model {} is not a Llama model", model_id) }
+ }))
+ ));
+ }
+ };
+ let mut config = LlamaInferenceConfig::new(llama_model);
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
match run_llama_inference(config) {
@@ -364,14 +417,35 @@ async fn handle_streaming_request(
}
} else {
// Create Gemma configuration dynamically
- let gemma_model = if which_model.is_v3_model() {
- gemma_runner::WhichModel::InstructV3_1B
- } else {
- gemma_runner::WhichModel::InstructV3_1B // Default fallback
+ let gemma_model = match which_model {
+ Which::Base2B => gemma_runner::WhichModel::Base2B,
+ Which::Base7B => gemma_runner::WhichModel::Base7B,
+ Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+ Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+ Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+ Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+ Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+ Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+ Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+ Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+ Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+ Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+ Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+ Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+ Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+ Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+ _ => {
+ return Err((
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+ }))
+ ));
+ }
};
let mut config = GemmaInferenceConfig {
- model: gemma_model,
+ model: Some(gemma_model),
..Default::default()
};
config.prompt = prompt.clone();
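To make the per-request construction concrete, a condensed sketch of what the handlers above now do for a Llama model id (a sketch only; it assumes max_tokens is a plain integer, as in the constructor added to llama_runner below):

    // Sketch: build a Llama config per request instead of relying on Default.
    let mut config = LlamaInferenceConfig::new(llama_runner::WhichModel::Llama32_1BInstruct);
    config.prompt = "Why is the sky blue?".to_string();
    config.max_tokens = 128;
    let rx = run_llama_inference(config)?; // receiver of generated text chunks, streamed back to the client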

View File

@@ -1,13 +1,8 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
- use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
@@ -21,8 +16,10 @@ use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
+ use std::str::FromStr;
+ use std::fmt;
- #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+ #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
@@ -58,6 +55,56 @@ pub enum WhichModel {
InstructV3_1B,
}
+ impl FromStr for WhichModel {
+ type Err = String;
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "gemma-2b" => Ok(Self::Base2B),
+ "gemma-7b" => Ok(Self::Base7B),
+ "gemma-2b-it" => Ok(Self::Instruct2B),
+ "gemma-7b-it" => Ok(Self::Instruct7B),
+ "gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
+ "gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
+ "codegemma-2b" => Ok(Self::CodeBase2B),
+ "codegemma-7b" => Ok(Self::CodeBase7B),
+ "codegemma-2b-it" => Ok(Self::CodeInstruct2B),
+ "codegemma-7b-it" => Ok(Self::CodeInstruct7B),
+ "gemma-2-2b" => Ok(Self::BaseV2_2B),
+ "gemma-2-2b-it" => Ok(Self::InstructV2_2B),
+ "gemma-2-9b" => Ok(Self::BaseV2_9B),
+ "gemma-2-9b-it" => Ok(Self::InstructV2_9B),
+ "gemma-3-1b" => Ok(Self::BaseV3_1B),
+ "gemma-3-1b-it" => Ok(Self::InstructV3_1B),
+ _ => Err(format!("Unknown model: {}", s)),
+ }
+ }
+ }
+ impl fmt::Display for WhichModel {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let name = match self {
+ Self::Base2B => "gemma-2b",
+ Self::Base7B => "gemma-7b",
+ Self::Instruct2B => "gemma-2b-it",
+ Self::Instruct7B => "gemma-7b-it",
+ Self::InstructV1_1_2B => "gemma-1.1-2b-it",
+ Self::InstructV1_1_7B => "gemma-1.1-7b-it",
+ Self::CodeBase2B => "codegemma-2b",
+ Self::CodeBase7B => "codegemma-7b",
+ Self::CodeInstruct2B => "codegemma-2b-it",
+ Self::CodeInstruct7B => "codegemma-7b-it",
+ Self::BaseV2_2B => "gemma-2-2b",
+ Self::InstructV2_2B => "gemma-2-2b-it",
+ Self::BaseV2_9B => "gemma-2-9b",
+ Self::InstructV2_9B => "gemma-2-9b-it",
+ Self::BaseV3_1B => "gemma-3-1b",
+ Self::InstructV3_1B => "gemma-3-1b-it",
+ };
+ write!(f, "{}", name)
+ }
+ }
enum Model {
V1(Model1),
V2(Model2),
@@ -145,7 +192,7 @@ impl TextGeneration {
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
- let mut generated_tokens = 0usize;
+ let mut _generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
@@ -183,7 +230,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
- generated_tokens += 1;
+ _generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
@@ -210,7 +257,7 @@ impl TextGeneration {
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
- pub model: WhichModel,
+ pub model: Option<WhichModel>,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
@@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig {
Self {
tracing: false,
prompt: "Hello".to_string(),
- model: WhichModel::InstructV2_2B,
+ model: Some(WhichModel::InstructV2_2B),
cpu: false,
dtype: None,
model_id: None,
@@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
}
};
println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
WhichModel::Base2B => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
Some(WhichModel::Base2B) => "google/gemma-2b",
Some(WhichModel::Base7B) => "google/gemma-7b",
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
None => "google/gemma-2-2b-it", // default fallback
}
.to_string()
});
@@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
- WhichModel::Base2B
- | WhichModel::Base7B
- | WhichModel::Instruct2B
- | WhichModel::Instruct7B
- | WhichModel::InstructV1_1_2B
- | WhichModel::InstructV1_1_7B
- | WhichModel::CodeBase2B
- | WhichModel::CodeBase7B
- | WhichModel::CodeInstruct2B
- | WhichModel::CodeInstruct7B => {
+ Some(WhichModel::Base2B)
+ | Some(WhichModel::Base7B)
+ | Some(WhichModel::Instruct2B)
+ | Some(WhichModel::Instruct7B)
+ | Some(WhichModel::InstructV1_1_2B)
+ | Some(WhichModel::InstructV1_1_7B)
+ | Some(WhichModel::CodeBase2B)
+ | Some(WhichModel::CodeBase7B)
+ | Some(WhichModel::CodeInstruct2B)
+ | Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
- WhichModel::BaseV2_2B
- | WhichModel::InstructV2_2B
- | WhichModel::BaseV2_9B
- | WhichModel::InstructV2_9B => {
+ Some(WhichModel::BaseV2_2B)
+ | Some(WhichModel::InstructV2_2B)
+ | Some(WhichModel::BaseV2_9B)
+ | Some(WhichModel::InstructV2_9B)
+ | None => { // default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
- WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
+ Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
@@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
);
let prompt = match cfg.model {
- WhichModel::InstructV3_1B => {
+ Some(WhichModel::InstructV3_1B) => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
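A short usage sketch for the conversions and the optional model added above; everything referenced here appears in this file, and the None fallback mirrors the match arm that resolves to google/gemma-2-2b-it:

    use std::str::FromStr;

    // Round-trip between the API/CLI model names and the enum.
    let m = WhichModel::from_str("gemma-2-2b-it").expect("known model id");
    assert_eq!(m, WhichModel::InstructV2_2B);
    assert_eq!(m.to_string(), "gemma-2-2b-it");

    // Leaving `model` unset now falls back to the gemma-2-2b-it repo
    // unless an explicit model_id is provided.
    let _cfg = GemmaInferenceConfig { model: None, ..Default::default() };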

View File

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
pub repeat_last_n: usize,
}
+ impl LlamaInferenceConfig {
+ pub fn new(model: WhichModel) -> Self {
+ Self {
+ prompt: String::new(),
+ model,
+ cpu: false,
+ temperature: 1.0,
+ top_p: None,
+ top_k: None,
+ seed: 42,
+ max_tokens: 512,
+ no_kv_cache: false,
+ dtype: None,
+ model_id: None,
+ revision: None,
+ use_flash_attn: true,
+ repeat_penalty: 1.1,
+ repeat_last_n: 64,
+ }
+ }
+ }
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {