use anyhow::{Error as E, Result}; use candle_core::{DType, Device, Tensor}; use candle_transformers::generation::LogitsProcessor; use tokenizers::Tokenizer; use std::collections::HashMap; use crate::model::Model; use crate::token_output_stream::TokenOutputStream; pub struct TextGeneration { model: Model, device: Device, // CPU device for fallback when operations are unsupported on primary device cpu_device: Option, // Flag to indicate if we should try to use the primary device first try_primary_device: bool, tokenizer: TokenOutputStream, logits_processor: LogitsProcessor, repeat_penalty: f32, repeat_last_n: usize, // Cache for repeat penalty computation to avoid redundant calculations penalty_cache: HashMap, // Context window size for sliding window context (default: 64 tokens) context_window_size: usize, } impl TextGeneration { #[allow(clippy::too_many_arguments)] pub fn new( model: Model, tokenizer: Tokenizer, seed: u64, temp: Option, top_p: Option, repeat_penalty: f32, repeat_last_n: usize, device: &Device, ) -> Self { let logits_processor = LogitsProcessor::new(seed, temp, top_p); // Initialize CPU device only if the primary device is not already CPU let (cpu_device, try_primary_device) = if device.is_cpu() { // If already on CPU, no need for a fallback device (None, false) } else { // Store CPU device for fallback and set flag to try primary device first (Some(Device::Cpu), true) }; Self { model, tokenizer: TokenOutputStream::new(tokenizer), logits_processor, repeat_penalty, repeat_last_n, device: device.clone(), cpu_device, try_primary_device, penalty_cache: HashMap::new(), context_window_size: 64, // Default sliding window size for better context preservation } } // Helper method for model execution with fallback to CPU for unsupported operations fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result { // If we're not trying primary device anymore, go straight to CPU if available if !self.try_primary_device { if let Some(cpu_device) = &self.cpu_device { let cpu_input = input.to_device(cpu_device).map_err(E::msg)?; let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?; return cpu_result.to_device(&self.device).map_err(E::msg); } else { // No CPU fallback, use primary device return self.model.forward(input, start_pos).map_err(E::msg); } } // Try running on the primary device first match self.model.forward(input, start_pos) { Ok(result) => Ok(result), Err(err) => { // Convert to string to check for unsupported operation let err_string = err.to_string(); // Check if the error is about unsupported operations or shape mismatches if (err_string.contains("no metal implementation for") || err_string.contains("no cuda implementation for") || err_string.contains("shape mismatch") || err_string.contains("broadcast_add")) && self.cpu_device.is_some() { // Extract operation name for better logging let op_name = if let Some(idx) = err_string.find("for ") { &err_string[(idx + 4)..] } else if err_string.contains("shape mismatch") { "shape mismatch operation" } else { "an operation" }; // Log the fallback tracing::warn!("The primary device does not support {}. Falling back to CPU.", op_name); // Move input to CPU and try again let cpu_device = self.cpu_device.as_ref().unwrap(); let cpu_input = input.to_device(cpu_device).map_err(E::msg)?; let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?; // Don't try primary device for future operations self.try_primary_device = false; tracing::info!("Successfully executed on CPU. Will use CPU for subsequent operations."); // Move result back to original device cpu_result.to_device(&self.device).map_err(E::msg) } else { // Not an unsupported operation error or no CPU fallback Err(E::msg(err)) } } } } // Reset method to clear state between requests pub fn reset_state(&mut self) { // Reset the primary device flag so we try the primary device first for each new request if !self.device.is_cpu() { self.try_primary_device = true; } // Clear the penalty cache to avoid stale cached values from previous requests self.penalty_cache.clear(); } // Helper method to apply repeat penalty with caching for optimization pub fn apply_cached_repeat_penalty( &mut self, logits: Tensor, tokens: &[u32], ) -> Result<(Tensor, std::time::Duration)> { let repeat_start = std::time::Instant::now(); // If no penalty, return the original logits if self.repeat_penalty == 1.0 { return Ok((logits, repeat_start.elapsed())); } // Get the tokens to penalize (the last n tokens) let start_at = tokens.len().saturating_sub(self.repeat_last_n); let penalty_tokens = &tokens[start_at..]; // Extract logits to a vector for modification let mut logits_vec = logits.to_vec1::()?; let cache_hits = std::cell::Cell::new(0); // Apply penalties with caching for &token_id in penalty_tokens { let token_id = token_id as usize; if token_id < logits_vec.len() { // Check if we've already calculated this token's penalty if let Some(penalized_score) = self.penalty_cache.get(&token_id) { // Use cached value logits_vec[token_id] = *penalized_score; cache_hits.set(cache_hits.get() + 1); } else { // Calculate and cache new value let score = logits_vec[token_id]; let sign = if score < 0.0 { -1.0 } else { 1.0 }; let penalized_score = sign * score / self.repeat_penalty; logits_vec[token_id] = penalized_score; self.penalty_cache.insert(token_id, penalized_score); } } } // Log cache efficiency statistics if !penalty_tokens.is_empty() { let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0; tracing::trace!("Repeat penalty cache hits: {}/{} ({:.1}%)", cache_hits.get(), penalty_tokens.len(), cache_efficiency); } // Create a new tensor with the modified logits (single tensor creation) let device = logits.device().clone(); let shape = logits.shape().clone(); let new_logits = Tensor::new(&logits_vec[..], &device)?; let result = new_logits.reshape(shape)?; let elapsed = repeat_start.elapsed(); Ok((result, elapsed)) } // Run text generation and print to stdout pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { use std::io::Write; // Track overall performance let start_time = std::time::Instant::now(); // Keep penalty cache across generation for better repetition prevention // Only clear cache if it becomes too large to prevent memory bloat if self.penalty_cache.len() > 10000 { self.penalty_cache.clear(); tracing::debug!("Cleared penalty cache due to size limit"); } else { tracing::debug!("Maintaining penalty cache across generation for better repetition prevention"); } // Phase 1: Tokenize input let tokenize_start = std::time::Instant::now(); self.tokenizer.clear(); let mut tokens = self .tokenizer .tokenizer() .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let tokenize_time = tokenize_start.elapsed(); tracing::debug!("Tokenization completed in {:.2?}", tokenize_time); tracing::debug!("Input tokens: {}", tokens.len()); // Print tokenized prompt for &t in tokens.iter() { if let Some(t) = self.tokenizer.next_token(t)? { print!("{t}") } } std::io::stdout().flush()?; let mut generated_tokens = 0usize; let eos_token = match self.tokenizer.get_token("") { Some(token) => token, None => anyhow::bail!("cannot find the token"), }; let eot_token = match self.tokenizer.get_token("") { Some(token) => token, None => { println!( "Warning: token not found in tokenizer, using as a backup" ); eos_token } }; // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant // Both need special handling for shape compatibility let needs_special_handling = match &self.model { Model::V2(_) => true, Model::V3(_) => true, _ => false, }; // Phase 2: Text generation let start_gen = std::time::Instant::now(); // Track per-token generation timing for performance analysis let mut token_times = Vec::new(); let mut forward_times = Vec::new(); let mut repeat_penalty_times = Vec::new(); let mut sampling_times = Vec::new(); // For Model2 and Model3, we need to use a special approach for shape compatibility if needs_special_handling { // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context tracing::debug!("Using special generation approach for gemma-2/gemma-3 models"); // Initial generation with the full prompt let forward_start = std::time::Instant::now(); let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; // Use execute_with_fallback which handles both device compatibility and shape mismatches let mut logits = self.execute_with_fallback(&input, 0)?; logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); for _ in 0..sample_len { let token_start = std::time::Instant::now(); // Apply repeat penalty using optimized cached implementation let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(¤t_logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { print!("{t}"); std::io::stdout().flush()?; } // For the next iteration, just use the new token let forward_start = std::time::Instant::now(); let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; // Use execute_with_fallback for both Gemma 3 and other models logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?; logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); let token_time = token_start.elapsed(); token_times.push(token_time); } } else { // Standard approach for other models tracing::debug!("Using standard generation approach"); for index in 0..sample_len { let token_start = std::time::Instant::now(); let context_size = if index > 0 { 1 } else { tokens.len() }; let start_pos = tokens.len().saturating_sub(context_size); let ctxt = &tokens[start_pos..]; // Track tensor operations and model forward pass let forward_start = std::time::Instant::now(); let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self.execute_with_fallback(&input, start_pos)?; let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); // Apply repeat penalty using optimized cached implementation let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(&logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { print!("{t}"); std::io::stdout().flush()?; } let token_time = token_start.elapsed(); token_times.push(token_time); } } let dt = start_gen.elapsed(); // Phase 3: Final decoding and output let decode_start = std::time::Instant::now(); if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { print!("{rest}"); } let decode_time = decode_start.elapsed(); std::io::stdout().flush()?; // Calculate generation speed let tokens_per_second = generated_tokens as f64 / dt.as_secs_f64(); // Calculate average time per token and component breakdown let avg_token_time = if !token_times.is_empty() { token_times.iter().sum::() / token_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_forward_time = if !forward_times.is_empty() { forward_times.iter().sum::() / forward_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_repeat_time = if !repeat_penalty_times.is_empty() { repeat_penalty_times.iter().sum::() / repeat_penalty_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_sampling_time = if !sampling_times.is_empty() { sampling_times.iter().sum::() / sampling_times.len() as u32 } else { std::time::Duration::from_secs(0) }; // Log performance metrics println!( "\n{generated_tokens} tokens generated ({:.2} token/s)", tokens_per_second, ); // Record detailed performance metrics tracing::info!("Text generation completed in {:.2?}", dt); tracing::info!("Tokens generated: {}", generated_tokens); tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second); tracing::info!("Average time per token: {:.2?}", avg_token_time); tracing::debug!(" - Forward pass: {:.2?} ({:.1}%)", avg_forward_time, avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); tracing::debug!(" - Repeat penalty: {:.2?} ({:.1}%)", avg_repeat_time, avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); tracing::debug!(" - Sampling: {:.2?} ({:.1}%)", avg_sampling_time, avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); // Log total request time let total_time = start_time.elapsed(); tracing::info!("Total request time: {:.2?}", total_time); tracing::debug!(" - Tokenization: {:.2?} ({:.1}%)", tokenize_time, tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); tracing::debug!(" - Generation: {:.2?} ({:.1}%)", dt, dt.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); tracing::debug!(" - Final decoding: {:.2?} ({:.1}%)", decode_time, decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); Ok(()) } // Run text generation and write to a buffer pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec) -> Result<()> { use std::io::Write; // Track overall performance let start_time = std::time::Instant::now(); // Keep penalty cache across generation for better repetition prevention // Only clear cache if it becomes too large to prevent memory bloat if self.penalty_cache.len() > 10000 { self.penalty_cache.clear(); tracing::debug!("Cleared penalty cache due to size limit (API mode)"); } else { tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)"); } // Phase 1: Tokenize input let tokenize_start = std::time::Instant::now(); self.tokenizer.clear(); let mut tokens = self .tokenizer .tokenizer() .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let tokenize_time = tokenize_start.elapsed(); tracing::debug!("API Tokenization completed in {:.2?}", tokenize_time); tracing::debug!("API Input tokens: {}", tokens.len()); // Write prompt tokens to output for &t in tokens.iter() { if let Some(t) = self.tokenizer.next_token(t)? { write!(output, "{}", t)?; } } let mut generated_tokens = 0usize; let eos_token = match self.tokenizer.get_token("") { Some(token) => token, None => anyhow::bail!("cannot find the token"), }; let eot_token = match self.tokenizer.get_token("") { Some(token) => token, None => { write!(output, "Warning: token not found in tokenizer, using as a backup")?; eos_token } }; // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant // Both need special handling for shape compatibility let needs_special_handling = match &self.model { Model::V2(_) => true, Model::V3(_) => true, _ => false, }; // Check if we're specifically using a Model3 (gemma-3) for additional error handling // let is_model_v3 = matches!(&self.model, Model::V3(_)); // Track generation timing let start_gen = std::time::Instant::now(); // Track per-token generation timing for performance analysis let mut token_times = Vec::new(); let mut forward_times = Vec::new(); let mut repeat_penalty_times = Vec::new(); let mut sampling_times = Vec::new(); // For Model2 and Model3, we need to use a special approach for shape compatibility if needs_special_handling { // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context tracing::debug!("Using special generation approach for gemma-2/gemma-3 models"); // Initial generation with the full prompt let forward_start = std::time::Instant::now(); let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; // Use execute_with_fallback which handles both device compatibility and shape mismatches let mut logits = self.execute_with_fallback(&input, 0)?; logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); for _ in 0..sample_len { let token_start = std::time::Instant::now(); // Apply repeat penalty using optimized cached implementation let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(¤t_logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { write!(output, "{}", t)?; } // For the next iteration, just use the new token let forward_start = std::time::Instant::now(); let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; // Use execute_with_fallback for both Gemma 3 and other models logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?; logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); let token_time = token_start.elapsed(); token_times.push(token_time); } let dt = start_gen.elapsed(); // Calculate and log performance metrics Self::log_performance_metrics( dt, generated_tokens, &token_times, &forward_times, &repeat_penalty_times, &sampling_times, tokenize_time, std::time::Duration::from_secs(0), start_time, "API" ); return Ok(()); } // Standard approach for other models tracing::debug!("Using standard generation approach"); for index in 0..sample_len { let token_start = std::time::Instant::now(); // Use sliding window context instead of single token to preserve context and reduce repetition let context_size = if index > 0 { std::cmp::min(self.context_window_size, tokens.len()) } else { tokens.len() }; let start_pos = tokens.len().saturating_sub(context_size); let ctxt = &tokens[start_pos..]; tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})", ctxt.len(), start_pos); // Track tensor operations and model forward pass let forward_start = std::time::Instant::now(); let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self.execute_with_fallback(&input, start_pos)?; let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); // Apply repeat penalty using optimized cached implementation let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(&logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { write!(output, "{}", t)?; } let token_time = token_start.elapsed(); token_times.push(token_time); } let dt = start_gen.elapsed(); // Phase 3: Final decoding and output let decode_start = std::time::Instant::now(); // Write any remaining tokens if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { write!(output, "{}", rest)?; } let decode_time = decode_start.elapsed(); // Log performance metrics Self::log_performance_metrics( dt, generated_tokens, &token_times, &forward_times, &repeat_penalty_times, &sampling_times, tokenize_time, decode_time, start_time, "API" ); Ok(()) } // Run text generation with streaming callback for each token pub async fn run_with_streaming(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result where F: FnMut(&str) -> Result<()>, { // Track overall performance let start_time = std::time::Instant::now(); // Keep penalty cache across generation for better repetition prevention // Only clear cache if it becomes too large to prevent memory bloat if self.penalty_cache.len() > 10000 { self.penalty_cache.clear(); tracing::debug!("Cleared penalty cache due to size limit (streaming mode)"); } else { tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)"); } // Phase 1: Tokenize input let tokenize_start = std::time::Instant::now(); self.tokenizer.clear(); let mut tokens = self .tokenizer .tokenizer() .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let tokenize_time = tokenize_start.elapsed(); tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time); tracing::debug!("Streaming Input tokens: {}", tokens.len()); // Collect all output for final return let mut full_output = String::new(); let mut generated_tokens = 0usize; let eos_token = match self.tokenizer.get_token("") { Some(token) => token, None => anyhow::bail!("cannot find the token"), }; let eot_token = match self.tokenizer.get_token("") { Some(token) => token, None => { tracing::warn!("Warning: token not found in tokenizer, using as a backup"); eos_token } }; // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant let needs_special_handling = match &self.model { Model::V2(_) => true, Model::V3(_) => true, _ => false, }; // Track generation timing let start_gen = std::time::Instant::now(); // Track per-token generation timing for performance analysis let mut token_times = Vec::new(); let mut forward_times = Vec::new(); let mut repeat_penalty_times = Vec::new(); let mut sampling_times = Vec::new(); // For Model2 and Model3, we need to use a special approach for shape compatibility if needs_special_handling { tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)"); tracing::debug!("Streaming: sample_len = {}", sample_len); // Initial generation with the full prompt let forward_start = std::time::Instant::now(); let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; let mut logits = self.execute_with_fallback(&input, 0)?; logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len); for gen_index in 0..sample_len { tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len); let token_start = std::time::Instant::now(); // Apply repeat penalty using optimized cached implementation let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(¤t_logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}", next_token, next_token, eos_token, eot_token); if next_token == eos_token || next_token == eot_token { tracing::debug!("Streaming: Breaking due to end token"); break; } if let Some(token_text) = self.tokenizer.next_token(next_token)? { full_output.push_str(&token_text); // Call the streaming callback with this token token_callback(&token_text)?; } // For the next iteration, use single token to avoid shape mismatch let forward_start = std::time::Instant::now(); tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len()); // Use just the last token for subsequent iterations to avoid shape mismatch // This is required for Gemma model's attention mechanism compatibility let context_tokens = &tokens[(tokens.len()-1)..]; let start_pos = tokens.len() - 1; tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})", context_tokens.len(), start_pos); let new_input = match Tensor::new(context_tokens, &self.device) { Ok(tensor) => tensor, Err(e) => { tracing::error!("Streaming: Failed to create input tensor: {}", e); return Err(e.into()); } }; let new_input = match new_input.unsqueeze(0) { Ok(tensor) => tensor, Err(e) => { tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e); return Err(e.into()); } }; tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos); logits = match self.execute_with_fallback(&new_input, start_pos) { Ok(result) => result, Err(e) => { tracing::error!("Streaming: Forward pass failed: {}", e); return Err(e); } }; logits = match logits.squeeze(0) { Ok(result) => result, Err(e) => { tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e); return Err(e.into()); } }; logits = match logits.squeeze(0) { Ok(result) => result, Err(e) => { tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e); return Err(e.into()); } }; logits = match logits.to_dtype(DType::F32) { Ok(result) => result, Err(e) => { tracing::error!("Streaming: Failed to convert logits to F32: {}", e); return Err(e.into()); } }; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1); let token_time = token_start.elapsed(); token_times.push(token_time); // Yield to allow other async tasks to run tokio::task::yield_now().await; } } else { // Standard approach for other models tracing::debug!("Using standard generation approach (streaming)"); for index in 0..sample_len { let token_start = std::time::Instant::now(); // Use sliding window context instead of single token to preserve context and reduce repetition let context_size = if index > 0 { std::cmp::min(self.context_window_size, tokens.len()) } else { tokens.len() }; let start_pos = tokens.len().saturating_sub(context_size); let ctxt = &tokens[start_pos..]; tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})", ctxt.len(), start_pos); // Track tensor operations and model forward pass let forward_start = std::time::Instant::now(); let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self.execute_with_fallback(&input, start_pos)?; let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; let forward_time = forward_start.elapsed(); forward_times.push(forward_time); // Apply repeat penalty using optimized cached implementation let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; repeat_penalty_times.push(repeat_time); // Track token sampling let sampling_start = std::time::Instant::now(); let next_token = self.logits_processor.sample(&logits)?; let sampling_time = sampling_start.elapsed(); sampling_times.push(sampling_time); tokens.push(next_token); generated_tokens += 1; if next_token == eos_token || next_token == eot_token { break; } if let Some(token_text) = self.tokenizer.next_token(next_token)? { full_output.push_str(&token_text); // Call the streaming callback with this token token_callback(&token_text)?; } let token_time = token_start.elapsed(); token_times.push(token_time); } } let dt = start_gen.elapsed(); // Phase 3: Final decoding let decode_start = std::time::Instant::now(); // Decode any remaining tokens but don't send through callback to avoid repetition // The tokens were already streamed individually in the generation loop above if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { full_output.push_str(&rest); // Note: NOT calling token_callback(&rest) here to prevent token repetition // Individual tokens were already streamed via the callback in the generation loop } let decode_time = decode_start.elapsed(); // Log performance metrics Self::log_performance_metrics( dt, generated_tokens, &token_times, &forward_times, &repeat_penalty_times, &sampling_times, tokenize_time, decode_time, start_time, "Streaming" ); Ok(full_output) } // Helper function for logging performance metrics fn log_performance_metrics( generation_time: std::time::Duration, generated_tokens: usize, token_times: &[std::time::Duration], forward_times: &[std::time::Duration], repeat_penalty_times: &[std::time::Duration], sampling_times: &[std::time::Duration], tokenize_time: std::time::Duration, decode_time: std::time::Duration, start_time: std::time::Instant, prefix: &str, ) { // Calculate generation speed let tokens_per_second = if generation_time.as_secs_f64() > 0.0 { generated_tokens as f64 / generation_time.as_secs_f64() } else { 0.0 }; // Calculate average time per token and component breakdown let avg_token_time = if !token_times.is_empty() { token_times.iter().sum::() / token_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_forward_time = if !forward_times.is_empty() { forward_times.iter().sum::() / forward_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_repeat_time = if !repeat_penalty_times.is_empty() { repeat_penalty_times.iter().sum::() / repeat_penalty_times.len() as u32 } else { std::time::Duration::from_secs(0) }; let avg_sampling_time = if !sampling_times.is_empty() { sampling_times.iter().sum::() / sampling_times.len() as u32 } else { std::time::Duration::from_secs(0) }; // Record detailed performance metrics tracing::info!("{} Text generation completed in {:.2?}", prefix, generation_time); tracing::info!("{} Tokens generated: {}", prefix, generated_tokens); tracing::info!("{} Generation speed: {:.2} tokens/second", prefix, tokens_per_second); tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time); if !avg_token_time.is_zero() { tracing::debug!("{} - Forward pass: {:.2?} ({:.1}%)", prefix, avg_forward_time, avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); tracing::debug!("{} - Repeat penalty: {:.2?} ({:.1}%)", prefix, avg_repeat_time, avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); tracing::debug!("{} - Sampling: {:.2?} ({:.1}%)", prefix, avg_sampling_time, avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 ); } // Log total request time let total_time = start_time.elapsed(); tracing::info!("{} Total request time: {:.2?}", prefix, total_time); if !total_time.is_zero() { tracing::debug!("{} - Tokenization: {:.2?} ({:.1}%)", prefix, tokenize_time, tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); tracing::debug!("{} - Generation: {:.2?} ({:.1}%)", prefix, generation_time, generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); tracing::debug!("{} - Final decoding: {:.2?} ({:.1}%)", prefix, decode_time, decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 ); } } }