run cargo fmt

geoffsee
2025-09-04 13:45:25 -04:00
parent 1e02b12cda
commit c1c583faab
11 changed files with 241 additions and 170 deletions

View File

@@ -44,7 +44,7 @@ jobs:
      - name: Clippy
        shell: bash
-       run: cargo clippy --all-targets
+       run: cargo clippy --all
      - name: Tests
        shell: bash

View File

@@ -194,57 +194,57 @@ pub fn send_chat_completion_stream(
) {
    use wasm_bindgen::prelude::*;
    use wasm_bindgen::JsCast;
    let request = ChatRequest {
        model,
        messages,
        max_tokens: Some(1024),
        stream: Some(true),
    };
    // We need to send a POST request but EventSource only supports GET
    // So we'll use fetch with a readable stream instead
    let window = web_sys::window().unwrap();
    let request_json = serde_json::to_string(&request).unwrap();
    let opts = web_sys::RequestInit::new();
    opts.set_method("POST");
    opts.set_body(&JsValue::from_str(&request_json));
    let headers = web_sys::Headers::new().unwrap();
    headers.set("Content-Type", "application/json").unwrap();
    headers.set("Accept", "text/event-stream").unwrap();
    opts.set_headers(&headers);
    let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();
    let promise = window.fetch_with_request(&request);
    wasm_bindgen_futures::spawn_local(async move {
        match wasm_bindgen_futures::JsFuture::from(promise).await {
            Ok(resp_value) => {
                let resp: web_sys::Response = resp_value.dyn_into().unwrap();
                if !resp.ok() {
                    on_error(format!("Server error: {}", resp.status()));
                    return;
                }
                let body = resp.body();
                if body.is_none() {
                    on_error("No response body".to_string());
                    return;
                }
                let reader = body
                    .unwrap()
                    .get_reader()
                    .dyn_into::<web_sys::ReadableStreamDefaultReader>()
                    .unwrap();
                let decoder = web_sys::TextDecoder::new().unwrap();
                let mut buffer = String::new();
                loop {
                    match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
                        Ok(result) => {
@@ -252,24 +252,25 @@ pub fn send_chat_completion_stream(
                                .unwrap()
                                .as_bool()
                                .unwrap_or(false);
                            if done {
                                break;
                            }
-                           let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+                           let value =
+                               js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
                            let array = js_sys::Uint8Array::new(&value);
                            let mut bytes = vec![0; array.length() as usize];
                            array.copy_to(&mut bytes);
                            let text = decoder.decode_with_u8_array(&bytes).unwrap();
                            buffer.push_str(&text);
                            // Process complete SSE events from buffer
                            while let Some(event_end) = buffer.find("\n\n") {
                                let event = buffer[..event_end].to_string();
                                buffer = buffer[event_end + 2..].to_string();
                                // Parse SSE event
                                for line in event.lines() {
                                    if let Some(data) = line.strip_prefix("data: ") {
@@ -277,9 +278,11 @@ pub fn send_chat_completion_stream(
                                            on_complete();
                                            return;
                                        }
                                        // Parse JSON chunk
-                                       if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
+                                       if let Ok(chunk) =
+                                           serde_json::from_str::<StreamChatResponse>(data)
+                                       {
                                            if let Some(choice) = chunk.choices.first() {
                                                if let Some(content) = &choice.delta.content {
                                                    on_chunk(content.clone());
@@ -296,7 +299,7 @@ pub fn send_chat_completion_stream(
                            }
                        }
                    }
                    on_complete();
                }
                Err(e) => {
@@ -366,11 +369,11 @@ fn ChatPage() -> impl IntoView {
    // State for available models and selected model
    let available_models = RwSignal::new(Vec::<ModelInfo>::new());
    let selected_model = RwSignal::new(String::from("")); // Default model
    // State for streaming response
    let streaming_content = RwSignal::new(String::new());
    let is_streaming = RwSignal::new(false);
    // State for streaming mode toggle
    let use_streaming = RwSignal::new(true); // Default to streaming
@@ -424,7 +427,7 @@ fn ChatPage() -> impl IntoView {
                // Clear streaming content and set streaming flag
                streaming_content.set(String::new());
                is_streaming.set(true);
                // Use streaming API
                send_chat_completion_stream(
                    current_messages,
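
The hunks above only re-wrap this streaming client, but the buffering logic they touch is easy to lose in the diff noise. Below is a minimal, dependency-free sketch of the same idea: accumulate network reads in a buffer, split complete SSE events on a blank line, pull out the "data: " payloads, and stop at "[DONE]". The names SseAction and drain_sse_events are illustrative, not part of this crate.

/// What the caller should do with one SSE `data:` payload.
#[derive(Debug, PartialEq)]
enum SseAction {
    Chunk(String), // a JSON payload to hand to the chunk parser / on_chunk callback
    Done,          // the server sent "[DONE]"; stop reading
}

/// Drain every complete ("\n\n"-terminated) event from `buffer`,
/// leaving any partial event in place for the next network read.
fn drain_sse_events(buffer: &mut String) -> Vec<SseAction> {
    let mut actions = Vec::new();
    while let Some(event_end) = buffer.find("\n\n") {
        let event = buffer[..event_end].to_string();
        let rest = buffer[event_end + 2..].to_string();
        *buffer = rest;
        for line in event.lines() {
            if let Some(data) = line.strip_prefix("data: ") {
                if data == "[DONE]" {
                    actions.push(SseAction::Done);
                } else {
                    actions.push(SseAction::Chunk(data.to_string()));
                }
            }
        }
    }
    actions
}

fn main() {
    let mut buffer = String::from("data: {\"x\":1}\n\ndata: [DONE]\n\npartial");
    let actions = drain_sse_events(&mut buffer);
    assert_eq!(actions.len(), 2);
    assert_eq!(buffer, "partial"); // the incomplete event waits for the next read
    println!("{:?}", actions);
}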

View File

@@ -1,5 +1,10 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode};
+use axum::{
+    Json, Router,
+    http::StatusCode,
+    response::Json as ResponseJson,
+    routing::{get, post},
+};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
 use serde::Serialize;
@@ -9,9 +14,8 @@ use tower_http::trace::TraceLayer;
 use tracing;
 // Cache for multiple embedding models
-static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> = Lazy::new(|| {
-    RwLock::new(HashMap::new())
-});
+static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
+    Lazy::new(|| RwLock::new(HashMap::new()));
 #[derive(Serialize)]
 pub struct ModelInfo {
@@ -32,11 +36,19 @@ pub struct ModelsResponse {
 fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
     match model_name {
         // Sentence Transformers models
-        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
-        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
-        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
-        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
+        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
+            Ok(EmbeddingModel::AllMiniLML6V2)
+        }
+        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
+            Ok(EmbeddingModel::AllMiniLML6V2Q)
+        }
+        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
+            Ok(EmbeddingModel::AllMiniLML12V2)
+        }
+        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
+            Ok(EmbeddingModel::AllMiniLML12V2Q)
+        }
         // BGE models
         "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
         "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
@@ -46,41 +58,68 @@ fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
        "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
        "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
        "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
        // Nomic models
-       "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
-       "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
-       "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
+       "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
+           Ok(EmbeddingModel::NomicEmbedTextV1)
+       }
+       "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
+           Ok(EmbeddingModel::NomicEmbedTextV15)
+       }
+       "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
+           Ok(EmbeddingModel::NomicEmbedTextV15Q)
+       }
        // Paraphrase models
-       "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
-       "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
-       "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
+       "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+       | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
+       "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
+       | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
+       "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+       | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
        // ModernBert
-       "lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
+       "lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
+           Ok(EmbeddingModel::ModernBertEmbedLarge)
+       }
        // Multilingual E5 models
-       "intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
-       "intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
-       "intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
+       "intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
+           Ok(EmbeddingModel::MultilingualE5Small)
+       }
+       "intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
+           Ok(EmbeddingModel::MultilingualE5Base)
+       }
+       "intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
+           Ok(EmbeddingModel::MultilingualE5Large)
+       }
        // Mixedbread models
-       "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
-       "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
+       "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
+           Ok(EmbeddingModel::MxbaiEmbedLargeV1)
+       }
+       "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
+           Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
+       }
        // GTE models
        "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
-       "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
+       "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
+           Ok(EmbeddingModel::GTEBaseENV15Q)
+       }
        "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
-       "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
+       "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
+           Ok(EmbeddingModel::GTELargeENV15Q)
+       }
        // CLIP model
        "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
        // Jina model
-       "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
+       "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
+           Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
+       }
        _ => Err(format!("Unsupported embedding model: {}", model_name)),
    }
}
@@ -95,7 +134,9 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
        EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
        EmbeddingModel::BGESmallZHV15 => 512,
        EmbeddingModel::BGELargeZHV15 => 1024,
-       EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
+       EmbeddingModel::NomicEmbedTextV1
+       | EmbeddingModel::NomicEmbedTextV15
+       | EmbeddingModel::NomicEmbedTextV15Q => 768,
        EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
        EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
        EmbeddingModel::ModernBertEmbedLarge => 1024,
@@ -114,37 +155,41 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
    // First try to get from cache (read lock)
    {
-       let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?;
+       let cache = MODEL_CACHE
+           .read()
+           .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
        if let Some(model) = cache.get(&embedding_model) {
            tracing::debug!("Using cached model: {:?}", embedding_model);
            return Ok(Arc::clone(model));
        }
    }
    // Model not in cache, create it (write lock)
-   let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?;
+   let mut cache = MODEL_CACHE
+       .write()
+       .map_err(|e| format!("Failed to acquire write lock: {}", e))?;
    // Double-check after acquiring write lock
    if let Some(model) = cache.get(&embedding_model) {
        tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
        return Ok(Arc::clone(model));
    }
    tracing::info!("Initializing new embedding model: {:?}", embedding_model);
    let model_start_time = std::time::Instant::now();
    let model = TextEmbedding::try_new(
        InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
    )
    .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
    let model_init_time = model_start_time.elapsed();
    tracing::info!(
        "Embedding model {:?} initialized in {:.2?}",
        embedding_model,
        model_init_time
    );
    let model_arc = Arc::new(model);
    cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
    Ok(model_arc)
@@ -158,7 +203,7 @@ pub async fn embeddings_create(
    // Phase 1: Parse and get the embedding model
    let model_start_time = std::time::Instant::now();
    let embedding_model = match parse_embedding_model(&payload.model) {
        Ok(model) => model,
        Err(e) => {
@@ -166,15 +211,18 @@ pub async fn embeddings_create(
            return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
        }
    };
    let model = match get_or_create_model(embedding_model.clone()) {
        Ok(model) => model,
        Err(e) => {
            tracing::error!("Failed to get/create model: {}", e);
-           return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e)));
+           return Err((
+               StatusCode::INTERNAL_SERVER_ERROR,
+               format!("Model initialization failed: {}", e),
+           ));
        }
    };
    let model_access_time = model_start_time.elapsed();
    tracing::debug!(
        "Model access/creation completed in {:.2?}",
@@ -205,12 +253,13 @@ pub async fn embeddings_create(
    // Phase 3: Generate embeddings
    let embedding_start_time = std::time::Instant::now();
-   let embeddings = model
-       .embed(texts_from_embedding_input, None)
-       .map_err(|e| {
-           tracing::error!("Failed to generate embeddings: {}", e);
-           (StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e))
-       })?;
+   let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
+       tracing::error!("Failed to generate embeddings: {}", e);
+       (
+           StatusCode::INTERNAL_SERVER_ERROR,
+           format!("Embedding generation failed: {}", e),
+       )
+   })?;
    let embedding_generation_time = embedding_start_time.elapsed();
    tracing::info!(
@@ -287,7 +336,7 @@ pub async fn embeddings_create(
    // Use the actual model dimensions instead of hardcoded 768
    let actual_dimensions = padded_embedding.len();
    let expected_dimensions = get_model_dimensions(&embedding_model);
    if actual_dimensions != expected_dimensions {
        tracing::warn!(
            "Model {:?} produced {} dimensions but expected {}",
@@ -455,7 +504,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
            id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
            object: "model".to_string(),
            owned_by: "nomic-ai".to_string(),
-           description: "Quantized v1.5 release of the 8192 context length english model".to_string(),
+           description: "Quantized v1.5 release of the 8192 context length english model"
+               .to_string(),
            dimensions: 768,
        },
        ModelInfo {
@@ -476,7 +526,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
            id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
            object: "model".to_string(),
            owned_by: "sentence-transformers".to_string(),
-           description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(),
+           description: "Sentence-transformers model for tasks like clustering or semantic search"
+               .to_string(),
            dimensions: 768,
        },
        ModelInfo {
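
An aside on the MODEL_CACHE hunks above: the code being reformatted is a read-then-write, double-checked cache behind a Lazy<RwLock<HashMap<..>>>. Here is a stripped-down sketch of that locking pattern on its own, with a String key and an Arc<String> value standing in for the real EmbeddingModel key and Arc<TextEmbedding> value; only the locking shape mirrors the actual file.

use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

static CACHE: Lazy<RwLock<HashMap<String, Arc<String>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

fn get_or_create(key: &str) -> Result<Arc<String>, String> {
    // Fast path: shared read lock; most calls return a cached Arc clone.
    {
        let cache = CACHE
            .read()
            .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
        if let Some(value) = cache.get(key) {
            return Ok(Arc::clone(value));
        }
    } // the read guard is dropped here, before the write lock is requested

    // Slow path: exclusive write lock.
    let mut cache = CACHE
        .write()
        .map_err(|e| format!("Failed to acquire write lock: {}", e))?;

    // Double-check: another thread may have inserted while we waited.
    if let Some(value) = cache.get(key) {
        return Ok(Arc::clone(value));
    }

    // Stand-in for the expensive TextEmbedding::try_new(...) initialization.
    let value = Arc::new(format!("expensive init for {key}"));
    cache.insert(key.to_string(), Arc::clone(&value));
    Ok(value)
}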

View File

@@ -18,12 +18,10 @@ async fn embeddings_create(
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
    match embeddings_engine::embeddings_create(Json(payload)).await {
        Ok(response) => Ok(response),
-       Err((status_code, message)) => {
-           Err(axum::response::Response::builder()
-               .status(status_code)
-               .body(axum::body::Body::from(message))
-               .unwrap())
-       }
+       Err((status_code, message)) => Err(axum::response::Response::builder()
+           .status(status_code)
+           .body(axum::body::Body::from(message))
+           .unwrap()),
    }
}
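
A note on the wrapper above: the inner handler already yields (StatusCode, String) errors, and the axum versions I am aware of provide an IntoResponse impl for that tuple, so the manual Response::builder() chain could be avoided entirely. A hedged sketch under that assumption; the handler and the stand-in inner function are invented for illustration, not this repository's code.

use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};

async fn embeddings_create_wrapper(
    Json(payload): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, Response> {
    match fake_inner_create(payload).await {
        Ok(response) => Ok(Json(response)),
        // (StatusCode, String) already implements IntoResponse,
        // so no hand-built Response is needed.
        Err((status_code, message)) => Err((status_code, message).into_response()),
    }
}

// Stand-in for embeddings_engine::embeddings_create; illustrative only.
async fn fake_inner_create(
    payload: serde_json::Value,
) -> Result<serde_json::Value, (StatusCode, String)> {
    Ok(payload)
}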

View File

@@ -42,7 +42,11 @@ pub struct ModelMeta {
}
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-   ModelMeta { id, family, instruct }
+   ModelMeta {
+       id,
+       family,
+       instruct,
+   }
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
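
Because m is a const fn, tables of ModelMeta values can be assembled at compile time. The sketch below shows that usage in isolation; the Family variants and model IDs are assumptions made up for the example, since the hunk does not show the real enum, and the types are redeclared locally so the snippet stands alone.

// Hypothetical Family variants; the real enum lives in this file but is not shown above.
#[derive(Clone, Copy, Debug)]
enum Family {
    Gemma,
    Llama,
}

#[derive(Clone, Copy, Debug)]
struct ModelMeta {
    id: &'static str,
    family: Family,
    instruct: bool,
}

const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
    ModelMeta {
        id,
        family,
        instruct,
    }
}

// Because `m` is const, this whole table is built at compile time.
const MODELS: &[ModelMeta] = &[
    m("google/gemma-3-1b-it", Family::Gemma, true),
    m("meta-llama/Llama-3.2-1B-Instruct", Family::Llama, true),
];

fn main() {
    for meta in MODELS {
        println!("{} family={:?} instruct={}", meta.id, meta.family, meta.instruct);
    }
}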

View File

@@ -42,13 +42,13 @@ pub struct AppState {
    pub llama_config: Option<LlamaInferenceConfig>,
}
impl Default for AppState {
    fn default() -> Self {
        // Configure a default model to prevent 503 errors from the chat-ui
        // This can be overridden by environment variables if needed
-       let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+       let default_model_id =
+           std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
        let gemma_config = GemmaInferenceConfig {
            model: None,
            ..Default::default()
@@ -94,9 +94,6 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
    }
}
fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}
@@ -157,7 +154,7 @@ pub async fn chat_completions_non_streaming_proxy(
    // Use the model specified in the request
    let model_id = request.model.clone();
    let which_model = model_id_to_which(&model_id);
    // Validate that the requested model is supported
    let which_model = match which_model {
        Some(model) => model,
@@ -204,19 +201,21 @@ pub async fn chat_completions_non_streaming_proxy(
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                   }))
+                   })),
                ));
            }
        };
        let mut config = LlamaInferenceConfig::new(llama_model);
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
-       run_llama_inference(config).map_err(|e| (
-           StatusCode::INTERNAL_SERVER_ERROR,
-           Json(serde_json::json!({
-               "error": { "message": format!("Error initializing Llama model: {}", e) }
-           }))
-       ))?
+       run_llama_inference(config).map_err(|e| {
+           (
+               StatusCode::INTERNAL_SERVER_ERROR,
+               Json(serde_json::json!({
+                   "error": { "message": format!("Error initializing Llama model: {}", e) }
+               })),
+           )
+       })?
    } else {
        // Create Gemma configuration dynamically
        let gemma_model = match which_model {
@@ -241,23 +240,25 @@ pub async fn chat_completions_non_streaming_proxy(
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                   }))
+                   })),
                ));
            }
        };
        let mut config = GemmaInferenceConfig {
            model: Some(gemma_model),
            ..Default::default()
        };
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
-       run_gemma_api(config).map_err(|e| (
-           StatusCode::INTERNAL_SERVER_ERROR,
-           Json(serde_json::json!({
-               "error": { "message": format!("Error initializing Gemma model: {}", e) }
-           }))
-       ))?
+       run_gemma_api(config).map_err(|e| {
+           (
+               StatusCode::INTERNAL_SERVER_ERROR,
+               Json(serde_json::json!({
+                   "error": { "message": format!("Error initializing Gemma model: {}", e) }
+               })),
+           )
+       })?
    };
    // Collect all tokens from the stream
@@ -320,7 +321,7 @@ async fn handle_streaming_request(
    // Use the model specified in the request
    let model_id = request.model.clone();
    let which_model = model_id_to_which(&model_id);
    // Validate that the requested model is supported
    let which_model = match which_model {
        Some(model) => model,
@@ -397,7 +398,7 @@ async fn handle_streaming_request(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                   }))
+                   })),
                ));
            }
        };
@@ -439,11 +440,11 @@ async fn handle_streaming_request(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                   }))
+                   })),
                ));
            }
        };
        let mut config = GemmaInferenceConfig {
            model: Some(gemma_model),
            ..Default::default()
@@ -605,59 +606,66 @@ pub async fn list_models() -> Json<ModelListResponse> {
        Which::Llama32_3BInstruct,
    ];
-   let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
-       let meta = which.meta();
-       let model_id = match which {
-           Which::Base2B => "gemma-2b",
-           Which::Base7B => "gemma-7b",
-           Which::Instruct2B => "gemma-2b-it",
-           Which::Instruct7B => "gemma-7b-it",
-           Which::InstructV1_1_2B => "gemma-1.1-2b-it",
-           Which::InstructV1_1_7B => "gemma-1.1-7b-it",
-           Which::CodeBase2B => "codegemma-2b",
-           Which::CodeBase7B => "codegemma-7b",
-           Which::CodeInstruct2B => "codegemma-2b-it",
-           Which::CodeInstruct7B => "codegemma-7b-it",
-           Which::BaseV2_2B => "gemma-2-2b",
-           Which::InstructV2_2B => "gemma-2-2b-it",
-           Which::BaseV2_9B => "gemma-2-9b",
-           Which::InstructV2_9B => "gemma-2-9b-it",
-           Which::BaseV3_1B => "gemma-3-1b",
-           Which::InstructV3_1B => "gemma-3-1b-it",
-           Which::Llama32_1B => "llama-3.2-1b",
-           Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
-           Which::Llama32_3B => "llama-3.2-3b",
-           Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
-       };
-       let owned_by = if meta.id.starts_with("google/") {
-           "google"
-       } else if meta.id.starts_with("meta-llama/") {
-           "meta"
-       } else {
-           "unknown"
-       };
-       Model {
-           id: model_id.to_string(),
-           object: "model".to_string(),
-           created: 1686935002,
-           owned_by: owned_by.to_string(),
-       }
-   }).collect();
+   let mut models: Vec<Model> = which_variants
+       .into_iter()
+       .map(|which| {
+           let meta = which.meta();
+           let model_id = match which {
+               Which::Base2B => "gemma-2b",
+               Which::Base7B => "gemma-7b",
+               Which::Instruct2B => "gemma-2b-it",
+               Which::Instruct7B => "gemma-7b-it",
+               Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+               Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+               Which::CodeBase2B => "codegemma-2b",
+               Which::CodeBase7B => "codegemma-7b",
+               Which::CodeInstruct2B => "codegemma-2b-it",
+               Which::CodeInstruct7B => "codegemma-7b-it",
+               Which::BaseV2_2B => "gemma-2-2b",
+               Which::InstructV2_2B => "gemma-2-2b-it",
+               Which::BaseV2_9B => "gemma-2-9b",
+               Which::InstructV2_9B => "gemma-2-9b-it",
+               Which::BaseV3_1B => "gemma-3-1b",
+               Which::InstructV3_1B => "gemma-3-1b-it",
+               Which::Llama32_1B => "llama-3.2-1b",
+               Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+               Which::Llama32_3B => "llama-3.2-3b",
+               Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+           };
+           let owned_by = if meta.id.starts_with("google/") {
+               "google"
+           } else if meta.id.starts_with("meta-llama/") {
+               "meta"
+           } else {
+               "unknown"
+           };
+           Model {
+               id: model_id.to_string(),
+               object: "model".to_string(),
+               created: 1686935002,
+               owned_by: owned_by.to_string(),
+           }
+       })
+       .collect();
    // Get embeddings models and convert them to inference Model format
    let embeddings_response = models_list().await;
-   let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
-       Model {
+   let embeddings_models: Vec<Model> = embeddings_response
+       .0
+       .data
+       .into_iter()
+       .map(|embedding_model| Model {
            id: embedding_model.id,
            object: embedding_model.object,
            created: 1686935002,
-           owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
-       }
-   }).collect();
+           owned_by: format!(
+               "{} - {}",
+               embedding_model.owned_by, embedding_model.description
+           ),
+       })
+       .collect();
    // Add embeddings models to the main models list
    models.extend(embeddings_models);

View File

@@ -1,4 +1,3 @@
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
@@ -11,13 +10,13 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
+use std::fmt;
+use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
-use std::str::FromStr;
-use std::fmt;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
@@ -367,7 +366,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let config_filename = repo.get("config.json")?;
    let filenames = match cfg.model {
-       Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
+       Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
+           vec![repo.get("model.safetensors")?]
+       }
        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("Retrieved files in {:?}", start.elapsed());
@@ -396,7 +397,8 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
        | Some(WhichModel::InstructV2_2B)
        | Some(WhichModel::BaseV2_9B)
        | Some(WhichModel::InstructV2_9B)
-       | None => { // default to V2 model
+       | None => {
+           // default to V2 model
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
            Model::V2(model)
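
The run_gemma_api signature in these hunks returns a Receiver<Result<String>> and generation runs on a spawned thread, so callers can stream tokens as they are produced. A minimal sketch of that channel shape, with a dummy token source in place of the model loop; the function name and error type here are illustrative, not this crate's API.

use std::sync::mpsc::{self, Receiver};
use std::thread;

/// Spawn "generation" on a background thread and hand back the receiving end,
/// mirroring the Receiver<Result<String>> shape used above.
fn run_dummy_inference(prompt: String) -> Receiver<Result<String, String>> {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        for token in prompt.split_whitespace() {
            // A real implementation would sample from the model here.
            // send() only fails if the receiver was dropped, so just stop then.
            if tx.send(Ok(format!("{token} "))).is_err() {
                break;
            }
        }
        // Dropping tx closes the channel, which ends the consumer's loop below.
    });
    rx
}

fn main() {
    let rx = run_dummy_inference("hello streaming world".to_string());
    // The caller consumes tokens as they arrive, e.g. to build an SSE response.
    for token in rx {
        match token {
            Ok(t) => print!("{t}"),
            Err(e) => eprintln!("generation error: {e}"),
        }
    }
    println!();
}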

View File

@@ -105,7 +105,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
        .into_iter()
        .filter_map(|e| e.ok())
    {
-       if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") {
+       if entry.file_name() == "Cargo.toml"
+           && entry.path() != workspace_root.join("../../../Cargo.toml")
+       {
            if let Ok(service_info) = parse_cargo_toml(entry.path()) {
                services.push(service_info);
            }

View File

@@ -102,7 +102,7 @@ impl Default for LlamaInferenceConfig {
            max_tokens: 512,
            // Performance flags
            no_kv_cache: false, // keep cache ON for speed
            use_flash_attn: false, // great speed boost if supported
            // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.

View File

@@ -1,5 +1,5 @@
-use candle_transformers::models::mimi::candle;
use candle_core::{Device, Result, Tensor};
+use candle_transformers::models::mimi::candle;

pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];

View File

@@ -8,8 +8,10 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
-use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
+use candle_core::{
+    utils::{cuda_is_available, metal_is_available},
+    Device, Tensor,
+};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
    if cpu {
@@ -126,7 +128,7 @@ pub fn hub_load_safetensors(
            repo.get(v)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
        })
-       .collect::<Result<Vec<_>, std::io::Error, >>()?;
+       .collect::<Result<Vec<_>, std::io::Error>>()?;
    Ok(safetensors_files)
}
@@ -136,7 +138,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let path = path.as_ref();
    let jsfile = std::fs::File::open(path.join(json_file))?;
-   let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
+   let json: serde_json::Value =
+       serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => anyhow::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,