use axum::{
    extract::State,
    http::StatusCode,
    response::{sse::Event, sse::Sse, IntoResponse},
    routing::{get, post},
    Json, Router,
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::openai_types::{
    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
    ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use embeddings_engine::models_list;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;

// -------------------------
// Shared app state
// -------------------------

#[derive(Clone, Debug)]
pub enum ModelType {
    Gemma,
    Llama,
}

#[derive(Clone)]
pub struct AppState {
    pub model_type: ModelType,
    pub model_id: String,
    pub gemma_config: Option<GemmaInferenceConfig>,
    pub llama_config: Option<LlamaInferenceConfig>,
}

impl Default for AppState {
    fn default() -> Self {
        // Configure a default model to prevent 503 errors from the chat-ui.
        // This can be overridden by environment variables if needed.
        let default_model_id =
            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());

        let gemma_config = GemmaInferenceConfig {
            model: gemma_runner::WhichModel::InstructV3_1B,
            ..Default::default()
        };

        Self {
            model_type: ModelType::Gemma,
            model_id: default_model_id,
            gemma_config: Some(gemma_config),
            llama_config: None,
        }
    }
}

// -------------------------
// Helper functions
// -------------------------

fn model_id_to_which(model_id: &str) -> Option<Which> {
    let normalized = normalize_model_id(model_id);
    match normalized.as_str() {
        "gemma-2b" => Some(Which::Base2B),
        "gemma-7b" => Some(Which::Base7B),
        "gemma-2b-it" => Some(Which::Instruct2B),
        "gemma-7b-it" => Some(Which::Instruct7B),
        "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
        "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
        "codegemma-2b" => Some(Which::CodeBase2B),
        "codegemma-7b" => Some(Which::CodeBase7B),
        "codegemma-2b-it" => Some(Which::CodeInstruct2B),
        "codegemma-7b-it" => Some(Which::CodeInstruct7B),
        "gemma-2-2b" => Some(Which::BaseV2_2B),
        "gemma-2-2b-it" => Some(Which::InstructV2_2B),
        "gemma-2-9b" => Some(Which::BaseV2_9B),
        "gemma-2-9b-it" => Some(Which::InstructV2_9B),
        "gemma-3-1b" => Some(Which::BaseV3_1B),
        "gemma-3-1b-it" => Some(Which::InstructV3_1B),
        "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
        "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
        _ => None,
    }
}

fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}

fn build_gemma_prompt(messages: &[Message]) -> String {
    let mut prompt = String::new();

    for message in messages {
        match message.role.as_str() {
            "system" => {
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
                }
            }
            "user" => {
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>user\n{}<end_of_turn>\n", content));
                }
            }
            "assistant" => {
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
                }
            }
            _ => {}
        }
    }

    prompt.push_str("<start_of_turn>model\n");
    prompt
}
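// Illustration of the prompt layout produced above (sketch only; the exact turn
// markers follow the format strings in `build_gemma_prompt`): for a single user
// message "Hello", the prompt sent to the Gemma runner looks like
//
//   <start_of_turn>user
//   Hello<end_of_turn>
//   <start_of_turn>model
//
// with the final turn left open so the model generates the assistant reply.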
// -------------------------
// OpenAI-compatible handler
// -------------------------

pub async fn chat_completions(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    if !request.stream.unwrap_or(false) {
        return Ok(chat_completions_non_streaming_proxy(state, request)
            .await
            .into_response());
    }

    Ok(chat_completions_stream(state, request).await.into_response())
}

pub async fn chat_completions_non_streaming_proxy(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
    // Use the model specified in the request
    let model_id = request.model.clone();
    let which_model = model_id_to_which(&model_id);

    // Validate that the requested model is supported
    let which_model = match which_model {
        Some(model) => model,
        None => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Unsupported model: {}", model_id),
                        "type": "model_not_supported"
                    }
                })),
            ));
        }
    };

    let max_tokens = request.max_tokens.unwrap_or(1000);

    // Build prompt based on model type
    let prompt = if which_model.is_llama_model() {
        // For Llama, just use the last user message for now
        request
            .messages
            .last()
            .and_then(|m| m.content.as_ref())
            .and_then(|c| match c {
                MessageContent(Either::Left(text)) => Some(text.clone()),
                _ => None,
            })
            .unwrap_or_default()
    } else {
        build_gemma_prompt(&request.messages)
    };

    // Get streaming receiver based on model type
    let rx = if which_model.is_llama_model() {
        // Create Llama configuration dynamically
        let mut config = LlamaInferenceConfig::default();
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        run_llama_inference(config).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("Error initializing Llama model: {}", e) }
                })),
            )
        })?
    } else {
        // Create Gemma configuration dynamically
        let gemma_model = if which_model.is_v3_model() {
            gemma_runner::WhichModel::InstructV3_1B
        } else {
            gemma_runner::WhichModel::InstructV3_1B // Default fallback
        };
        let mut config = GemmaInferenceConfig {
            model: gemma_model,
            ..Default::default()
        };
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        run_gemma_api(config).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
                })),
            )
        })?
    };

    // Collect all tokens from the stream
    let mut completion = String::new();
    while let Ok(token_result) = rx.recv() {
        match token_result {
            Ok(token) => completion.push_str(&token),
            Err(e) => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error generating text: {}", e) }
                    })),
                ));
            }
        }
    }

    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        model: model_id,
        choices: vec![ChatCompletionChoice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: Some(MessageContent(Either::Left(completion.clone()))),
                name: None,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            // Rough estimate: ~4 characters per token
            prompt_tokens: prompt.len() / 4,
            completion_tokens: completion.len() / 4,
            total_tokens: (prompt.len() + completion.len()) / 4,
        },
    };

    Ok(Json(response).into_response())
}
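// Usage sketch for the non-streaming route (assumptions: the server is reachable on
// localhost:8080 and the body follows the OpenAI chat-completions shape defined in
// crate::openai_types):
//
//   curl -s http://localhost:8080/v1/chat/completions \
//     -H 'Content-Type: application/json' \
//     -d '{"model": "gemma-3-1b-it", "messages": [{"role": "user", "content": "Hello"}]}'
//
// Sending the same request with "stream": true routes to the SSE handler below.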
// -------------------------
// Streaming implementation
// -------------------------

pub async fn chat_completions_stream(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
    handle_streaming_request(state, request).await
}

async fn handle_streaming_request(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
    // Use the model specified in the request
    let model_id = request.model.clone();
    let which_model = model_id_to_which(&model_id);

    // Validate that the requested model is supported
    let which_model = match which_model {
        Some(model) => model,
        None => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Unsupported model: {}", model_id),
                        "type": "model_not_supported"
                    }
                })),
            ));
        }
    };

    // Generate a unique ID and metadata
    let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let max_tokens = request.max_tokens.unwrap_or(1000);

    // Build prompt based on model type
    let prompt = if which_model.is_llama_model() {
        // For Llama, just use the last user message for now
        request
            .messages
            .last()
            .and_then(|m| m.content.as_ref())
            .and_then(|c| match c {
                MessageContent(Either::Left(text)) => Some(text.clone()),
                _ => None,
            })
            .unwrap_or_default()
    } else {
        build_gemma_prompt(&request.messages)
    };

    tracing::debug!("Formatted prompt: {}", prompt);

    // Channel for streaming SSE events
    let (tx, rx) = mpsc::unbounded_channel::<Result<Event, Infallible>>();

    // Send initial role event
    let initial_chunk = ChatCompletionChunk {
        id: response_id.clone(),
        object: "chat.completion.chunk".to_string(),
        created,
        model: model_id.clone(),
        choices: vec![ChatCompletionChunkChoice {
            index: 0,
            delta: Delta {
                role: Some("assistant".to_string()),
                content: None,
            },
            finish_reason: None,
        }],
    };
    if let Ok(json) = serde_json::to_string(&initial_chunk) {
        let _ = tx.send(Ok(Event::default().data(json)));
    }

    // Get streaming receiver based on model type
    let model_rx = if which_model.is_llama_model() {
        // Create Llama configuration dynamically
        let mut config = LlamaInferenceConfig::default();
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        match run_llama_inference(config) {
            Ok(rx) => rx,
            Err(e) => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error initializing Llama model: {}", e) }
                    })),
                ));
            }
        }
    } else {
        // Create Gemma configuration dynamically
        let gemma_model = if which_model.is_v3_model() {
            gemma_runner::WhichModel::InstructV3_1B
        } else {
            gemma_runner::WhichModel::InstructV3_1B // Default fallback
        };
        let mut config = GemmaInferenceConfig {
            model: gemma_model,
            ..Default::default()
        };
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        match run_gemma_api(config) {
            Ok(rx) => rx,
            Err(e) => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
                    })),
                ));
            }
        }
    };

    // Spawn task to receive tokens from the model and forward them as SSE events
    let response_id_clone = response_id.clone();
    let model_id_clone = model_id.clone();
    tokio::spawn(async move {
        // Stream tokens with repetition detection
        let mut recent_tokens = Vec::new();
        let mut repetition_count = 0;
        const MAX_REPETITION_COUNT: usize = 5;
        const REPETITION_WINDOW: usize = 8;

        while let Ok(token_result) = model_rx.recv() {
            match token_result {
                Ok(token) => {
                    // Skip sending empty tokens
                    if token.is_empty() {
                        continue;
                    }

                    // Add token to recent history for repetition detection
                    recent_tokens.push(token.clone());
                    if recent_tokens.len() > REPETITION_WINDOW {
                        recent_tokens.remove(0);
                    }

                    // Check for repetitive patterns
                    if recent_tokens.len() >= 4 {
                        let last_token = &recent_tokens[recent_tokens.len() - 1];
                        let second_last = &recent_tokens[recent_tokens.len() - 2];

                        if last_token == second_last {
                            repetition_count += 1;
                            tracing::warn!(
                                "Detected repetition pattern: '{}' (count: {})",
                                last_token,
                                repetition_count
                            );

                            if repetition_count >= MAX_REPETITION_COUNT {
                                tracing::info!("Stopping generation due to excessive repetition");
                                break;
                            }
                        } else {
                            repetition_count = 0;
                        }
                    }

                    let chunk = ChatCompletionChunk {
                        id: response_id_clone.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model_id_clone.clone(),
                        choices: vec![ChatCompletionChunkChoice {
                            index: 0,
                            delta: Delta {
                                role: None,
                                content: Some(token),
                            },
                            finish_reason: None,
                        }],
                    };

                    if let Ok(json) = serde_json::to_string(&chunk) {
                        let _ = tx.send(Ok(Event::default().data(json)));
                    }
                }
                Err(e) => {
                    tracing::info!("Text generation stopped: {}", e);
                    break;
                }
            }
        }

        // Send final stop chunk and DONE marker
        let final_chunk = ChatCompletionChunk {
            id: response_id_clone.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model_id_clone.clone(),
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
                delta: Delta {
                    role: None,
                    content: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
        };
        if let Ok(json) = serde_json::to_string(&final_chunk) {
            let _ = tx.send(Ok(Event::default().data(json)));
        }
        let _ = tx.send(Ok(Event::default().data("[DONE]")));
    });

    // Convert receiver into a Stream for SSE
    let stream = UnboundedReceiverStream::new(rx);
    Ok(Sse::new(stream))
}
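// Wire-format sketch of what a client sees on the streaming route (each event carries
// a serialized ChatCompletionChunk; the exact token split depends on the model):
//
//   data: {"object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},...}],...}
//   data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},...}],...}
//   data: {"object":"chat.completion.chunk","choices":[{"delta":{},"finish_reason":"stop"}],...}
//   data: [DONE]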
// -------------------------
// Router
// -------------------------

pub fn create_router(app_state: AppState) -> Router {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .layer(cors)
        .with_state(app_state)
}
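// Serving sketch (assumption: axum 0.7-style setup; the bind address and entry point
// below are illustrative and not part of this crate):
//
// #[tokio::main]
// async fn main() {
//     let app = create_router(AppState::default());
//     let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
//     axum::serve(app, listener).await.unwrap();
// }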
/// Handler for GET /v1/models - returns list of available models
pub async fn list_models() -> Json<ModelListResponse> {
    // Get all available model variants from the Which enum
    let which_variants = vec![
        Which::Base2B,
        Which::Base7B,
        Which::Instruct2B,
        Which::Instruct7B,
        Which::InstructV1_1_2B,
        Which::InstructV1_1_7B,
        Which::CodeBase2B,
        Which::CodeBase7B,
        Which::CodeInstruct2B,
        Which::CodeInstruct7B,
        Which::BaseV2_2B,
        Which::InstructV2_2B,
        Which::BaseV2_9B,
        Which::InstructV2_9B,
        Which::BaseV3_1B,
        Which::InstructV3_1B,
        Which::Llama32_1B,
        Which::Llama32_1BInstruct,
        Which::Llama32_3B,
        Which::Llama32_3BInstruct,
    ];

    let mut models: Vec<Model> = which_variants
        .into_iter()
        .map(|which| {
            let meta = which.meta();
            let model_id = match which {
                Which::Base2B => "gemma-2b",
                Which::Base7B => "gemma-7b",
                Which::Instruct2B => "gemma-2b-it",
                Which::Instruct7B => "gemma-7b-it",
                Which::InstructV1_1_2B => "gemma-1.1-2b-it",
                Which::InstructV1_1_7B => "gemma-1.1-7b-it",
                Which::CodeBase2B => "codegemma-2b",
                Which::CodeBase7B => "codegemma-7b",
                Which::CodeInstruct2B => "codegemma-2b-it",
                Which::CodeInstruct7B => "codegemma-7b-it",
                Which::BaseV2_2B => "gemma-2-2b",
                Which::InstructV2_2B => "gemma-2-2b-it",
                Which::BaseV2_9B => "gemma-2-9b",
                Which::InstructV2_9B => "gemma-2-9b-it",
                Which::BaseV3_1B => "gemma-3-1b",
                Which::InstructV3_1B => "gemma-3-1b-it",
                Which::Llama32_1B => "llama-3.2-1b",
                Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
                Which::Llama32_3B => "llama-3.2-3b",
                Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
            };

            let owned_by = if meta.id.starts_with("google/") {
                "google"
            } else if meta.id.starts_with("meta-llama/") {
                "meta"
            } else {
                "unknown"
            };

            Model {
                id: model_id.to_string(),
                object: "model".to_string(),
                created: 1686935002,
                owned_by: owned_by.to_string(),
            }
        })
        .collect();

    // Get embeddings models and convert them to the inference Model format
    let embeddings_response = models_list().await;
    let embeddings_models: Vec<Model> = embeddings_response
        .0
        .data
        .into_iter()
        .map(|embedding_model| Model {
            id: embedding_model.id,
            object: embedding_model.object,
            created: 1686935002,
            owned_by: format!(
                "{} - {}",
                embedding_model.owned_by, embedding_model.description
            ),
        })
        .collect();

    // Add embeddings models to the main models list
    models.extend(embeddings_models);

    Json(ModelListResponse {
        object: "list".to_string(),
        data: models,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai_types::{Message, MessageContent};
    use either::Either;

    #[test]
    fn test_build_gemma_prompt() {
        let messages = vec![
            Message {
                role: "system".to_string(),
                content: Some(MessageContent(Either::Left("System message".to_string()))),
                name: None,
            },
            Message {
                role: "user".to_string(),
                content: Some(MessageContent(Either::Left("Knock knock.".to_string()))),
                name: None,
            },
            Message {
                role: "assistant".to_string(),
                content: Some(MessageContent(Either::Left("Who's there?".to_string()))),
                name: None,
            },
            Message {
                role: "user".to_string(),
                content: Some(MessageContent(Either::Left("Gemma.".to_string()))),
                name: None,
            },
        ];

        let prompt = build_gemma_prompt(&messages);

        let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
        assert_eq!(prompt, expected);
    }

    #[test]
    fn test_empty_messages() {
        let messages: Vec<Message> = vec![];
        let prompt = build_gemma_prompt(&messages);
        assert_eq!(prompt, "<start_of_turn>model\n");
    }

    #[test]
    fn test_missing_content() {
        let messages = vec![Message {
            role: "user".to_string(),
            content: None,
            name: None,
        }];
        let prompt = build_gemma_prompt(&messages);
        assert_eq!(prompt, "<start_of_turn>model\n");
    }
}
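// Sketch tests for the model-id helpers; the expected values follow the match arms
// in `model_id_to_which` and the lowercase/underscore handling in `normalize_model_id`.
#[cfg(test)]
mod model_id_tests {
    use super::*;

    #[test]
    fn test_normalize_model_id() {
        assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
        assert_eq!(
            normalize_model_id("llama-3.2-1b-instruct"),
            "llama-3.2-1b-instruct"
        );
    }

    #[test]
    fn test_model_id_to_which_rejects_unknown() {
        assert!(model_id_to_which("gpt-4o").is_none());
    }
}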