From 766d41af78cbca932c91e172521da6d65b9bd4bf Mon Sep 17 00:00:00 2001
From: geoffsee <>
Date: Wed, 27 Aug 2025 17:53:50 -0400
Subject: [PATCH]

- Refactored `build_pipeline` usage to ensure pipeline arguments are cloned.
- Introduced `reset_state` for clearing cached state between requests.
- Enhanced chat UI with model selector and dynamic model fetching.
- Improved error logging and detailed debug messages for chat request flows.
- Added fresh instantiation of `TextGeneration` to prevent tensor shape mismatches.
---
 crates/inference-engine/src/server.rs          |  19 +-
 .../inference-engine/src/text_generation.rs    |  10 +
 crates/leptos-chat/Cargo.toml                  |   1 +
 crates/leptos-chat/src/lib.rs                  | 361 ++++++++----------
 crates/predict-otron-9000/src/main.rs          |   3 +-
 5 files changed, 185 insertions(+), 209 deletions(-)

diff --git a/crates/inference-engine/src/server.rs b/crates/inference-engine/src/server.rs
index 7ee76ae..6e26f59 100644
--- a/crates/inference-engine/src/server.rs
+++ b/crates/inference-engine/src/server.rs
@@ -36,15 +36,18 @@ use serde_json::Value;
 pub struct AppState {
     pub text_generation: Arc<Mutex<TextGeneration>>,
     pub model_id: String,
+    // Store build args to recreate TextGeneration when needed
+    pub build_args: PipelineArgs,
 }
 
 impl Default for AppState {
     fn default() -> Self {
         let args = PipelineArgs::default();
-        let text_generation = build_pipeline(args);
+        let text_generation = build_pipeline(args.clone());
         Self {
             text_generation: Arc::new(Mutex::new(text_generation)),
             model_id: String::new(),
+            build_args: args,
         }
     }
 }
@@ -318,7 +321,7 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
 pub async fn chat_completions(
     State(state): State<AppState>,
     Json(request): Json<ChatCompletionRequest>,
-) -> Result)> {
+) -> Result {
     // If streaming was requested, this function shouldn't be called
     // A separate route handles streaming requests
     if !request.stream.unwrap_or(false) {
@@ -357,7 +360,11 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
     // Generate
     let mut output = Vec::new();
     {
+        // Recreate TextGeneration instance to ensure completely fresh state
+        // This prevents KV cache persistence that causes tensor shape mismatches
+        let fresh_text_gen = build_pipeline(state.build_args.clone());
         let mut text_gen = state.text_generation.lock().await;
+        *text_gen = fresh_text_gen;
 
         let mut buffer = Vec::new();
         let max_tokens = request.max_tokens.unwrap_or(1000);
@@ -456,7 +463,12 @@ async fn handle_streaming_request(
     // Generate text using existing buffer-based approach
     let mut buffer = Vec::new();
     {
+        // Recreate TextGeneration instance to ensure completely fresh state
+        // This prevents KV cache persistence that causes tensor shape mismatches
+        let fresh_text_gen = build_pipeline(state.build_args.clone());
         let mut text_gen = state.text_generation.lock().await;
+        *text_gen = fresh_text_gen;
+
         let max_tokens = request.max_tokens.unwrap_or(1000);
 
         if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
@@ -752,10 +764,11 @@ mod tests {
         println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
 
         // This should reproduce the same conditions as the curl script
-        let text_generation = build_pipeline(args);
+        let text_generation = build_pipeline(args.clone());
         let app_state = AppState {
             text_generation: Arc::new(Mutex::new(text_generation)),
             model_id: "gemma-3-1b-it".to_string(),
+            build_args: args,
         };
 
         // Create the same request as the curl script
diff --git a/crates/inference-engine/src/text_generation.rs b/crates/inference-engine/src/text_generation.rs
index fd7ef40..94984c4 100644
--- a/crates/inference-engine/src/text_generation.rs
+++ b/crates/inference-engine/src/text_generation.rs
@@ -117,6 +117,16 @@ impl TextGeneration {
         }
     }
 
+    // Reset method to clear state between requests
+    pub fn reset_state(&mut self) {
+        // Reset the primary device flag so we try the primary device first for each new request
+        if !self.device.is_cpu() {
+            self.try_primary_device = true;
+        }
+        // Clear the penalty cache to avoid stale cached values from previous requests
+        self.penalty_cache.clear();
+    }
+
     // Helper method to apply repeat penalty with caching for optimization
     pub fn apply_cached_repeat_penalty(
         &mut self,
diff --git a/crates/leptos-chat/Cargo.toml b/crates/leptos-chat/Cargo.toml
index 130b541..ebc1b86 100644
--- a/crates/leptos-chat/Cargo.toml
+++ b/crates/leptos-chat/Cargo.toml
@@ -34,6 +34,7 @@ web-sys = { version = "0.3", features = [
     "Element",
     "HtmlElement",
     "HtmlInputElement",
+    "HtmlSelectElement",
     "HtmlTextAreaElement",
     "Event",
     "EventTarget",
diff --git a/crates/leptos-chat/src/lib.rs b/crates/leptos-chat/src/lib.rs
index ffc2783..dfe5136 100644
--- a/crates/leptos-chat/src/lib.rs
+++ b/crates/leptos-chat/src/lib.rs
@@ -10,12 +10,12 @@ use futures_util::StreamExt;
 use async_openai_wasm::{
     types::{
         ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
-        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
+        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
     },
     Client,
 };
 use async_openai_wasm::config::OpenAIConfig;
-use async_openai_wasm::types::ChatCompletionResponseStream;
+use async_openai_wasm::types::{ChatCompletionResponseStream, Model};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Message {
@@ -89,6 +89,43 @@ pub fn App() -> impl IntoView {
     }
 }
 
+async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
+    log::info!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");
+
+    let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
+    let client = Client::with_config(config);
+
+    match client.models().list().await {
+        Ok(response) => {
+            let model_count = response.data.len();
+            log::info!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);
+
+            if model_count > 0 {
+                let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
+                log::debug!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
+            } else {
+                log::warn!("[DEBUG_LOG] fetch_available_models: No models returned by server");
+            }
+
+            Ok(response.data)
+        },
+        Err(e) => {
+            log::error!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
+
+            let error_details = format!("{:?}", e);
+            if error_details.contains("400") || error_details.contains("Bad Request") {
+                log::error!("[DEBUG_LOG] fetch_available_models: HTTP 400 - Server rejected models request");
+            } else if error_details.contains("404") || error_details.contains("Not Found") {
+                log::error!("[DEBUG_LOG] fetch_available_models: HTTP 404 - Models endpoint not found");
+            } else if error_details.contains("Connection") || error_details.contains("connection") {
+                log::error!("[DEBUG_LOG] fetch_available_models: Connection error - server may be down");
+            }
+
+            Err(format!("Failed to fetch models: {}", e))
+        }
+    }
+}
+
 async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseStream {
     let config = OpenAIConfig::new().with_api_base("http://localhost:8080".to_string());
     let client = Client::with_config(config);
@@ -168,19 +205,47 @@ async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseS
 //     Err("leptos-chat chat request only supported on wasm32 target".to_string())
 // }
 
+const DEFAULT_MODEL: &str = "gemma-2b-it";
+
 #[component]
 fn ChatInterface() -> impl IntoView {
     let (messages, set_messages) = create_signal::<VecDeque<Message>>(VecDeque::new());
     let (input_value, set_input_value) = create_signal(String::new());
     let (is_loading, set_is_loading) = create_signal(false);
+    let (available_models, set_available_models) = create_signal::<Vec<OpenAIModel>>(Vec::new());
+    let (selected_model, set_selected_model) = create_signal(DEFAULT_MODEL.to_string());
+    let (models_loading, set_models_loading) = create_signal(false);
+
+    // Fetch models on component initialization
+    create_effect(move |_| {
+        spawn_local(async move {
+            set_models_loading.set(true);
+            match fetch_available_models().await {
+                Ok(models) => {
+                    set_available_models.set(models);
+                    set_models_loading.set(false);
+                }
+                Err(e) => {
+                    log::error!("Failed to fetch models: {}", e);
+                    // Set a default model if fetching fails
+                    set_available_models.set(vec![]);
+                    set_models_loading.set(false);
+                }
+            }
+        });
+    });
 
     let send_message = create_action(move |content: &String| {
         let content = content.clone();
         async move {
             if content.trim().is_empty() {
+                log::debug!("[DEBUG_LOG] send_message: Empty content, skipping");
                 return;
             }
 
+            log::info!("[DEBUG_LOG] send_message: Starting message send process");
+            log::debug!("[DEBUG_LOG] send_message: User message content length: {}", content.len());
+
             set_is_loading.set(true);
 
             // Add user message to chat
@@ -204,7 +269,8 @@ fn ChatInterface() -> impl IntoView {
             chat_messages.push(system_message.into());
 
             // Add history messages
-            messages.with(|msgs| {
+            let history_count = messages.with_untracked(|msgs| {
+                let count = msgs.len();
                 for msg in msgs.iter() {
                     let message = ChatCompletionRequestUserMessageArgs::default()
                         .content(msg.content.clone())
@@ -212,6 +278,7 @@ fn ChatInterface() -> impl IntoView {
                         .expect("failed to build message");
                     chat_messages.push(message.into());
                 }
+                count
             });
 
             // Add current user message
@@ -221,20 +288,37 @@ fn ChatInterface() -> impl IntoView {
                 .expect("failed to build user message");
             chat_messages.push(message.into());
 
+            let current_model = selected_model.get_untracked();
+            let total_messages = chat_messages.len();
+
+            log::info!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
+                current_model, history_count, total_messages);
+
             let request = CreateChatCompletionRequestArgs::default()
-                .model("gemma-2b-it")
+                .model(current_model.as_str())
                 .max_tokens(512u32)
                 .messages(chat_messages)
                 .stream(true) // ensure server streams
                 .build()
                 .expect("failed to build request");
 
+            // Log request details for debugging server issues
+            log::info!("[DEBUG_LOG] send_message: Request configuration - model: '{}', max_tokens: 512, stream: true, messages_count: {}",
+                current_model, total_messages);
+            log::debug!("[DEBUG_LOG] send_message: Request details - history_messages: {}, system_messages: 1, user_messages: {}",
+                history_count, total_messages - history_count - 1);
+
             // Send request
             let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
             let client = Client::with_config(config);
+
+            log::info!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
+
             match client.chat().create_stream(request).await {
                 Ok(mut stream) => {
+                    log::info!("[DEBUG_LOG] send_message: Successfully created stream, starting to receive response");
+
                     // Insert a placeholder assistant message to append into
                     let assistant_id = Uuid::new_v4().to_string();
                     set_messages.update(|msgs| {
@@ -246,10 +330,12 @@ fn ChatInterface() -> impl IntoView {
                         });
                     });
 
+                    let mut chunks_received = 0;
                     // Stream loop: append deltas to the last message
                     while let Some(next) = stream.next().await {
                         match next {
                             Ok(chunk) => {
+                                chunks_received += 1;
                                 // Try to pull out the content delta in a tolerant way.
                                 // async-openai 0.28.x stream chunk usually looks like:
                                 // choices[0].delta.content: Option<String>
@@ -281,12 +367,13 @@ fn ChatInterface() -> impl IntoView {
                                 }
                             }
                             Err(e) => {
-                                log::error!("Stream error: {:?}", e);
+                                log::error!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
+                                log::error!("[DEBUG_LOG] send_message: Stream error details - model: '{}', chunks_received: {}", current_model, chunks_received);
                                 set_messages.update(|msgs| {
                                     msgs.push_back(Message {
                                         id: Uuid::new_v4().to_string(),
                                         role: "system".to_string(),
-                                        content: format!("Stream error: {e}"),
+                                        content: format!("Stream error after {} chunks: {}", chunks_received, e),
                                         timestamp: Date::now(),
                                     });
                                 });
@@ -294,13 +381,39 @@ fn ChatInterface() -> impl IntoView {
                             }
                         }
                     }
+                    log::info!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
                 }
                 Err(e) => {
-                    log::error!("Failed to send request: {:?}", e);
+                    // Detailed error logging for different types of errors
+                    log::error!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
+                    log::error!("[DEBUG_LOG] send_message: Request context - model: '{}', total_messages: {}, endpoint: http://localhost:8080/v1",
+                        current_model, total_messages);
+
+                    // Try to extract more specific error information
+                    let error_details = format!("{:?}", e);
+                    let user_message = if error_details.contains("400") || error_details.contains("Bad Request") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 400 Bad Request detected - possible issues:");
+                        log::error!("[DEBUG_LOG] send_message: - Invalid model name: '{}'", current_model);
+                        log::error!("[DEBUG_LOG] send_message: - Invalid message format or content");
+                        log::error!("[DEBUG_LOG] send_message: - Server configuration issue");
+                        format!("Error: HTTP 400 Bad Request - Check model '{}' and message format. See console for details.", current_model)
+                    } else if error_details.contains("404") || error_details.contains("Not Found") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 404 Not Found - server endpoint may be incorrect");
+                        "Error: HTTP 404 Not Found - Server endpoint not found".to_string()
+                    } else if error_details.contains("500") || error_details.contains("Internal Server Error") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 500 Internal Server Error - server-side issue");
+                        "Error: HTTP 500 Internal Server Error - Server problem".to_string()
+                    } else if error_details.contains("Connection") || error_details.contains("connection") {
+                        log::error!("[DEBUG_LOG] send_message: Connection error - server may be down");
+                        "Error: Cannot connect to server at http://localhost:8080".to_string()
+                    } else {
+                        format!("Error: Request failed - {}", e)
+                    };
+
                     let error_message = Message {
                         id: Uuid::new_v4().to_string(),
                         role: "system".to_string(),
-                        content: "Error: Failed to connect to server".to_string(),
+                        content: user_message,
                         timestamp: Date::now(),
                     };
                     set_messages.update(|msgs| msgs.push_back(error_message));
@@ -330,6 +443,11 @@ fn ChatInterface() -> impl IntoView {
         }
     };
 
+    let on_model_change = move |ev| {
+        let select = event_target::<web_sys::HtmlSelectElement>(&ev);
+        set_selected_model.set(select.value());
+    };
+
     let messages_list = move || {
         messages.get()
             .into_iter()
@@ -364,6 +482,36 @@ fn ChatInterface() -> impl IntoView {
     view! {

"Chat Interface"

+
+ + +
{messages_list} {loading_indicator} @@ -390,203 +538,6 @@ fn ChatInterface() -> impl IntoView { } } -// -// #[component] -// fn ChatInterface() -> impl IntoView { -// let (messages, set_messages) = create_signal::>(VecDeque::new()); -// let (input_value, set_input_value) = create_signal(String::new()); -// let (is_loading, set_is_loading) = create_signal(false); -// -// let send_message = create_action(move |content: &String| { -// let content = content.clone(); -// async move { -// if content.trim().is_empty() { -// return; -// } -// -// set_is_loading.set(true); -// -// // Add user message to chat -// let user_message = Message { -// id: Uuid::new_v4().to_string(), -// role: "user".to_string(), -// content: content.clone(), -// timestamp: Date::now(), -// }; -// -// set_messages.update(|msgs| msgs.push_back(user_message.clone())); -// set_input_value.set(String::new()); -// -// let mut chat_messages = Vec::new(); -// -// // Add system message -// let system_message = ChatCompletionRequestSystemMessageArgs::default() -// .content("You are a helpful assistant.") -// .build() -// .expect("failed to build system message"); -// chat_messages.push(system_message.into()); -// -// // Add history messages -// messages.with(|msgs| { -// for msg in msgs.iter() { -// let message = ChatCompletionRequestUserMessageArgs::default() -// .content(msg.content.clone().into()) -// .build() -// .expect("failed to build message"); -// chat_messages.push(message.into()); -// } -// }); -// -// // Add current user message -// let message = ChatCompletionRequestUserMessageArgs::default() -// .content(user_message.content.clone().into()) -// .build() -// .expect("failed to build user message"); -// chat_messages.push(message.into()); -// -// let request = CreateChatCompletionRequestArgs::default() -// .model("gemma-2b-it") -// .max_tokens(512u32) -// .messages(chat_messages) -// .build() -// .expect("failed to build request"); -// -// // Send request -// let config = OpenAIConfig::new().with_api_base("http://localhost:8080".to_string()); -// let client = Client::with_config(config); -// -// match client -// .chat() -// .create_stream(request) -// .await -// { -// Ok(chat_response) => { -// -// -// // if let Some(choice) = chat_response { -// // // Extract content from the message -// // let content_text = match &choice.message.content { -// // Some(message_content) => { -// // match &message_content.0 { -// // either::Either::Left(text) => text.clone(), -// // either::Either::Right(_) => "Complex content not supported".to_string(), -// // } -// // } -// // None => "No content provided".to_string(), -// // }; -// // -// // let assistant_message = Message { -// // id: Uuid::new_v4().to_string(), -// // role: "assistant".to_string(), -// // content: content_text, -// // timestamp: Date::now(), -// // }; -// // set_messages.update(|msgs| msgs.push_back(assistant_message)); -// // -// // -// // -// // // Log token usage information -// // log::debug!("Token usage - Prompt: {}, Completion: {}, Total: {}", -// // chat_response.usage.prompt_tokens, -// // chat_response.usage.completion_tokens, -// // chat_response.usage.total_tokens); -// // } -// } -// Err(e) => { -// log::error!("Failed to send request: {:?}", e); -// let error_message = Message { -// id: Uuid::new_v4().to_string(), -// role: "system".to_string(), -// content: "Error: Failed to connect to server".to_string(), -// timestamp: Date::now(), -// }; -// set_messages.update(|msgs| msgs.push_back(error_message)); -// } -// } -// -// set_is_loading.set(false); 
-//         }
-//     });
-//
-//     let on_input = move |ev| {
-//         let input = event_target::(&ev);
-//         set_input_value.set(input.value());
-//     };
-//
-//     let on_submit = move |ev: SubmitEvent| {
-//         ev.prevent_default();
-//         let content = input_value.get();
-//         send_message.dispatch(content);
-//     };
-//
-//     let on_keypress = move |ev: KeyboardEvent| {
-//         if ev.key() == "Enter" && !ev.shift_key() {
-//             ev.prevent_default();
-//             let content = input_value.get();
-//             send_message.dispatch(content);
-//         }
-//     };
-//
-//     let messages_list = move || {
-//         messages.get()
-//             .into_iter()
-//             .map(|message| {
-//                 let role_class = match message.role.as_str() {
-//                     "user" => "user-message",
-//                     "assistant" => "assistant-message",
-//                     _ => "system-message",
-//                 };
-//
-//                 view! {
-//
-//
{message.role}
-//
{message.content}
-//
-//                 }
-//             })
-//             .collect_view()
-//     };
-//
-//     let loading_indicator = move || {
-//         is_loading.get().then(|| {
-//             view! {
-//
-//
"assistant"
-//
"Thinking..."
-//
-//             }
-//         })
-//     };
-//
-//     view! {
-//
-//
-//
-//                 "Chat Interface"
-//
-//
-//                 {messages_list}
-//                 {loading_indicator}
-//
-//
-//
-//
-//
-//
-//     }
-// }
-
 #[wasm_bindgen::prelude::wasm_bindgen(start)]
 pub fn main() {
     // Set up error handling and logging for WebAssembly
diff --git a/crates/predict-otron-9000/src/main.rs b/crates/predict-otron-9000/src/main.rs
index 7a11dbd..5868298 100644
--- a/crates/predict-otron-9000/src/main.rs
+++ b/crates/predict-otron-9000/src/main.rs
@@ -53,10 +53,11 @@ async fn main() {
     pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
     pipeline_args.which = Which::InstructV3_1B;
 
-    let text_generation = build_pipeline(pipeline_args);
+    let text_generation = build_pipeline(pipeline_args.clone());
     let app_state = AppState {
         text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
         model_id: "google/gemma-3-1b-it".to_string(),
+        build_args: pipeline_args,
     };
 
     // Get the inference router directly from the inference engine
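The handlers above avoid stale-state problems by rebuilding the whole pipeline with `build_pipeline(state.build_args.clone())` on every request, while the new `TextGeneration::reset_state` is a cheaper way to clear per-request state. A minimal sketch of how a handler could use it instead, assuming the `AppState` from `server.rs` in this patch; `reset_generation_state` is a hypothetical helper, not part of the diff:

```rust
// Hypothetical helper (not in this patch): clear cached per-request state
// instead of rebuilding the entire TextGeneration pipeline.
async fn reset_generation_state(state: &AppState) {
    // AppState holds Arc<tokio::sync::Mutex<TextGeneration>>, as shown above.
    let mut text_gen = state.text_generation.lock().await;
    // reset_state() re-enables the primary device (when not on CPU) and clears
    // the penalty cache, per the method added in text_generation.rs.
    text_gen.reset_state();
}
```

Whether this is sufficient depends on where the KV cache lives: if it is held inside the model rather than in `TextGeneration`'s penalty cache, the full rebuild used in the handlers remains the safer choice.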