Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
Refactor `apply_cached_repeat_penalty` for optimized caching and reuse, add extensive unit tests, and integrate special handling for Gemma-specific models.

- Remove `test_request.sh`, deprecated functionality, and unused imports; introduce a new CLI tool (`cli.ts`) for testing the inference engine, and adjust handling of non-streaming/streaming chat completions.
- Add CPU fallback support for text generation when the primary device is unsupported (a small error-classification sketch follows below).
- Introduce an `execute_with_fallback` method to handle device-compatibility and shape-mismatch errors.
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations.
- Increase HTTP timeout limits in the `curl_chat_stream.sh` script for reliable API testing.
- Chat completion endpoint functions with gemma3 (no streaming).
- Add a benchmarking guide with HTML reporting, a Leptos chat crate, and middleware for metrics tracking.
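The caching refactor named in the title is easiest to see in isolation. The following is a minimal sketch of the idea, not code from this commit: the free function `cached_repeat_penalty` and its parameters are illustrative, and it uses the conventional repeat-penalty formula (divide positive scores, multiply negative ones), whereas `apply_cached_repeat_penalty` in the diff below keeps its cache in `self.penalty_cache` and also returns the time spent.

use std::collections::HashMap;

use anyhow::Result;
use candle_core::Tensor;

// Sketch only: penalized scores are cached per token id so repeated tokens in the
// penalty window are not recomputed on every generation step.
fn cached_repeat_penalty(
    logits: &Tensor,
    tokens: &[u32],
    repeat_penalty: f32,
    repeat_last_n: usize,
    cache: &mut HashMap<usize, f32>,
) -> Result<Tensor> {
    if repeat_penalty == 1.0 {
        // No penalty requested: return the logits unchanged.
        return Ok(logits.clone());
    }
    let start_at = tokens.len().saturating_sub(repeat_last_n);
    let mut values = logits.to_vec1::<f32>()?;
    for &token_id in &tokens[start_at..] {
        let token_id = token_id as usize;
        if token_id < values.len() {
            // Cache hit: reuse the previously penalized score for this token id.
            let penalized = *cache.entry(token_id).or_insert_with(|| {
                let score = values[token_id];
                // Conventional formula: shrink positive scores, push negative ones further down.
                if score < 0.0 { score * repeat_penalty } else { score / repeat_penalty }
            });
            values[token_id] = penalized;
        }
    }
    // Rebuild a tensor with the same shape and device as the input logits.
    let device = logits.device().clone();
    let shape = logits.shape().clone();
    Ok(Tensor::new(&values[..], &device)?.reshape(shape)?)
}

In a generation loop the same cache would live across steps and be cleared at the start of each request, which is what the diff does with `self.penalty_cache.clear()`.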
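The CPU-fallback items above hinge on classifying errors by their message text. Here is a minimal sketch of that check, factored into a free function so it can be unit-tested without a model or an accelerator; the function and test names are illustrative and not part of this commit, while the matched substrings are the ones used by `execute_with_fallback` in the diff below.

// Sketch: decide whether an error from the primary device warrants retrying on CPU.
fn should_fall_back_to_cpu(err_string: &str, has_cpu_fallback: bool) -> bool {
    has_cpu_fallback
        && (err_string.contains("no metal implementation for")
            || err_string.contains("no cuda implementation for")
            || err_string.contains("shape mismatch")
            || err_string.contains("broadcast_add"))
}

#[cfg(test)]
mod fallback_sketch_tests {
    use super::*;

    #[test]
    fn unsupported_metal_op_triggers_fallback() {
        // The operation name is made up; only the prefix is matched.
        assert!(should_fall_back_to_cpu("no metal implementation for some-op", true));
    }

    #[test]
    fn other_errors_or_missing_cpu_device_do_not_trigger_fallback() {
        assert!(!should_fall_back_to_cpu("no metal implementation for some-op", false));
        assert!(!should_fall_back_to_cpu("out of memory", true));
    }
}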
@@ -2,7 +2,7 @@ use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
use std::io::Write;
use std::collections::HashMap;

use crate::model::Model;
use crate::token_output_stream::TokenOutputStream;
@@ -10,10 +10,16 @@ use crate::token_output_stream::TokenOutputStream;
pub struct TextGeneration {
model: Model,
device: Device,
// CPU device for fallback when operations are unsupported on primary device
cpu_device: Option<Device>,
// Flag to indicate if we should try to use the primary device first
try_primary_device: bool,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
// Cache for repeat penalty computation to avoid redundant calculations
penalty_cache: HashMap<usize, f32>,
}

impl TextGeneration {
@@ -29,6 +35,16 @@ impl TextGeneration {
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);

// Initialize CPU device only if the primary device is not already CPU
let (cpu_device, try_primary_device) = if device.is_cpu() {
// If already on CPU, no need for a fallback device
(None, false)
} else {
// Store CPU device for fallback and set flag to try primary device first
(Some(Device::Cpu), true)
};

Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
@@ -36,12 +52,142 @@ impl TextGeneration {
repeat_penalty,
repeat_last_n,
device: device.clone(),
cpu_device,
try_primary_device,
penalty_cache: HashMap::new(),
}
}

// Helper method for model execution with fallback to CPU for unsupported operations
fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
// If we're not trying primary device anymore, go straight to CPU if available
if !self.try_primary_device {
if let Some(cpu_device) = &self.cpu_device {
let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;
return cpu_result.to_device(&self.device).map_err(E::msg);
} else {
// No CPU fallback, use primary device
return self.model.forward(input, start_pos).map_err(E::msg);
}
}

// Try running on the primary device first
match self.model.forward(input, start_pos) {
Ok(result) => Ok(result),
Err(err) => {
// Convert to string to check for unsupported operation
let err_string = err.to_string();

// Check if the error is about unsupported operations or shape mismatches
if (err_string.contains("no metal implementation for") ||
err_string.contains("no cuda implementation for") ||
err_string.contains("shape mismatch") ||
err_string.contains("broadcast_add")) &&
self.cpu_device.is_some() {

// Extract operation name for better logging
let op_name = if let Some(idx) = err_string.find("for ") {
&err_string[(idx + 4)..]
} else if err_string.contains("shape mismatch") {
"shape mismatch operation"
} else {
"an operation"
};

// Log the fallback
tracing::warn!("The primary device does not support {}. Falling back to CPU.", op_name);

// Move input to CPU and try again
let cpu_device = self.cpu_device.as_ref().unwrap();
let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;

// Don't try primary device for future operations
self.try_primary_device = false;
tracing::info!("Successfully executed on CPU. Will use CPU for subsequent operations.");

// Move result back to original device
cpu_result.to_device(&self.device).map_err(E::msg)
} else {
// Not an unsupported operation error or no CPU fallback
Err(E::msg(err))
}
}
}
}

// Helper method to apply repeat penalty with caching for optimization
pub fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();

// If no penalty, return the original logits
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}

// Get the tokens to penalize (the last n tokens)
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];

// Extract logits to a vector for modification
let mut logits_vec = logits.to_vec1::<f32>()?;
let cache_hits = std::cell::Cell::new(0);

// Apply penalties with caching
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
// Check if we've already calculated this token's penalty
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
// Use cached value
logits_vec[token_id] = *penalized_score;
cache_hits.set(cache_hits.get() + 1);
} else {
// Calculate and cache new value
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}

// Log cache efficiency statistics
if !penalty_tokens.is_empty() {
let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0;
tracing::trace!("Repeat penalty cache hits: {}/{} ({:.1}%)",
cache_hits.get(), penalty_tokens.len(), cache_efficiency);
}

// Create a new tensor with the modified logits (single tensor creation)
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;

let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}

// Run text generation and print to stdout
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;

// Track overall performance
let start_time = std::time::Instant::now();

// Clear penalty cache for new generation
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache for new generation");

// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
self.tokenizer.clear();
let mut tokens = self
.tokenizer
@@ -50,6 +196,12 @@ impl TextGeneration {
.map_err(E::msg)?
.get_ids()
.to_vec();

let tokenize_time = tokenize_start.elapsed();
tracing::debug!("Tokenization completed in {:.2?}", tokenize_time);
tracing::debug!("Input tokens: {}", tokens.len());

// Print tokenized prompt
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
@@ -73,39 +225,107 @@
}
};

// Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
// Both need special handling for shape compatibility
let needs_special_handling = match &self.model {
Model::V2(_) => true,
Model::V3(_) => true,
_ => false,
};

// Phase 2: Text generation
let start_gen = std::time::Instant::now();
for index in 0..sample_len {

// Track per-token generation timing for performance analysis
let mut token_times = Vec::new();
let mut forward_times = Vec::new();
let mut repeat_penalty_times = Vec::new();
let mut sampling_times = Vec::new();

// For Model2 and Model3, we need to use a special approach for shape compatibility
if needs_special_handling {
// For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context
tracing::debug!("Using special generation approach for gemma-2/gemma-3 models");

// Initial generation with the full prompt
let forward_start = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

// Use execute_with_fallback which handles both device compatibility and shape mismatches
let mut logits = self.execute_with_fallback(&input, 0)?;

logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

for _ in 0..sample_len {
let token_start = std::time::Instant::now();

// Apply repeat penalty using optimized cached implementation
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
repeat_penalty_times.push(repeat_time);

// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&current_logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);

tokens.push(next_token);
generated_tokens += 1;

if next_token == eos_token || next_token == eot_token {
break;
}

if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}

// For the next iteration, just use the new token
let forward_start = std::time::Instant::now();
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;

// Use execute_with_fallback for both Gemma 3 and other models
logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;

logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

let token_time = token_start.elapsed();
token_times.push(token_time);
}
} else {
// Standard approach for other models
tracing::debug!("Using standard generation approach");

for index in 0..sample_len {
let token_start = std::time::Instant::now();

let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];

// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);

// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;

for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}

// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

// Apply repeat penalty using optimized cached implementation
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
repeat_penalty_times.push(repeat_time);

// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);

tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
@@ -115,21 +335,107 @@ impl TextGeneration {
print!("{t}");
std::io::stdout().flush()?;
}

let token_time = token_start.elapsed();
token_times.push(token_time);
}
}

let dt = start_gen.elapsed();

// Phase 3: Final decoding and output
let decode_start = std::time::Instant::now();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let decode_time = decode_start.elapsed();

std::io::stdout().flush()?;

// Calculate generation speed
let tokens_per_second = generated_tokens as f64 / dt.as_secs_f64();

// Calculate average time per token and component breakdown
let avg_token_time = if !token_times.is_empty() {
token_times.iter().sum::<std::time::Duration>() / token_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_forward_time = if !forward_times.is_empty() {
forward_times.iter().sum::<std::time::Duration>() / forward_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_repeat_time = if !repeat_penalty_times.is_empty() {
repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_sampling_time = if !sampling_times.is_empty() {
sampling_times.iter().sum::<std::time::Duration>() / sampling_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

// Log performance metrics
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
tokens_per_second,
);

// Record detailed performance metrics
tracing::info!("Text generation completed in {:.2?}", dt);
tracing::info!("Tokens generated: {}", generated_tokens);
tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second);
tracing::info!("Average time per token: {:.2?}", avg_token_time);
tracing::debug!(" - Forward pass: {:.2?} ({:.1}%)",
avg_forward_time,
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);
tracing::debug!(" - Repeat penalty: {:.2?} ({:.1}%)",
avg_repeat_time,
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);
tracing::debug!(" - Sampling: {:.2?} ({:.1}%)",
avg_sampling_time,
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);

// Log total request time
let total_time = start_time.elapsed();
tracing::info!("Total request time: {:.2?}", total_time);
tracing::debug!(" - Tokenization: {:.2?} ({:.1}%)",
tokenize_time,
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
);
tracing::debug!(" - Generation: {:.2?} ({:.1}%)",
dt,
dt.as_secs_f64() / total_time.as_secs_f64() * 100.0
);
tracing::debug!(" - Final decoding: {:.2?} ({:.1}%)",
decode_time,
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
);

Ok(())
}

// Run text generation and write to a buffer
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
use std::io::Write;

// Track overall performance
let start_time = std::time::Instant::now();

// Clear penalty cache for new generation
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache for new generation (API mode)");

// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
self.tokenizer.clear();
let mut tokens = self
.tokenizer
@@ -138,6 +444,10 @@ impl TextGeneration {
.map_err(E::msg)?
.get_ids()
.to_vec();

let tokenize_time = tokenize_start.elapsed();
tracing::debug!("API Tokenization completed in {:.2?}", tokenize_time);
tracing::debug!("API Input tokens: {}", tokens.len());

// Write prompt tokens to output
for &t in tokens.iter() {
@@ -160,49 +470,55 @@ impl TextGeneration {
}
};

// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
// Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
// Both need special handling for shape compatibility
let needs_special_handling = match &self.model {
Model::V2(_) => true,
Model::V3(_) => true,
_ => false,
};

// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Check if we're specifically using a Model3 (gemma-3) for additional error handling
// let is_model_v3 = matches!(&self.model, Model::V3(_));

// Track generation timing
let start_gen = std::time::Instant::now();

// Track per-token generation timing for performance analysis
let mut token_times = Vec::new();
let mut forward_times = Vec::new();
let mut repeat_penalty_times = Vec::new();
let mut sampling_times = Vec::new();

// For Model2 and Model3, we need to use a special approach for shape compatibility
if needs_special_handling {
// For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context
tracing::debug!("Using special generation approach for gemma-2/gemma-3 models");

// Initial generation with the full prompt
let forward_start = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;

// Use execute_with_fallback which handles both device compatibility and shape mismatches
let mut logits = self.execute_with_fallback(&input, 0)?;

logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);

// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;

for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}

// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let token_start = std::time::Instant::now();

// Apply repeat penalty using optimized cached implementation
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
repeat_penalty_times.push(repeat_time);

// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&current_logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);

tokens.push(next_token);
generated_tokens += 1;

@@ -215,48 +531,60 @@ impl TextGeneration {
}

// For the next iteration, just use the new token
let forward_start = std::time::Instant::now();
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;

// Use execute_with_fallback for both Gemma 3 and other models
logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;

logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

let token_time = token_start.elapsed();
token_times.push(token_time);
}

let dt = start_gen.elapsed();

// Calculate and log performance metrics
Self::log_performance_metrics(
dt, generated_tokens, &token_times, &forward_times,
&repeat_penalty_times, &sampling_times, tokenize_time,
std::time::Duration::from_secs(0), start_time, "API"
);

return Ok(());
}

// Standard approach for other models
let start_gen = std::time::Instant::now();
tracing::debug!("Using standard generation approach");

for index in 0..sample_len {
let token_start = std::time::Instant::now();

let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];

// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);

// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;

for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}

// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);

// Apply repeat penalty using optimized cached implementation
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
repeat_penalty_times.push(repeat_time);

// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);

tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
@@ -265,13 +593,122 @@ impl TextGeneration {
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}

let token_time = token_start.elapsed();
token_times.push(token_time);
}


let dt = start_gen.elapsed();

// Phase 3: Final decoding and output
let decode_start = std::time::Instant::now();

// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}


let decode_time = decode_start.elapsed();

// Log performance metrics
Self::log_performance_metrics(
dt, generated_tokens, &token_times, &forward_times,
&repeat_penalty_times, &sampling_times, tokenize_time,
decode_time, start_time, "API"
);

Ok(())
}

// Helper function for logging performance metrics
fn log_performance_metrics(
generation_time: std::time::Duration,
generated_tokens: usize,
token_times: &[std::time::Duration],
forward_times: &[std::time::Duration],
repeat_penalty_times: &[std::time::Duration],
sampling_times: &[std::time::Duration],
tokenize_time: std::time::Duration,
decode_time: std::time::Duration,
start_time: std::time::Instant,
prefix: &str,
) {
// Calculate generation speed
let tokens_per_second = if generation_time.as_secs_f64() > 0.0 {
generated_tokens as f64 / generation_time.as_secs_f64()
} else {
0.0
};

// Calculate average time per token and component breakdown
let avg_token_time = if !token_times.is_empty() {
token_times.iter().sum::<std::time::Duration>() / token_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_forward_time = if !forward_times.is_empty() {
forward_times.iter().sum::<std::time::Duration>() / forward_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_repeat_time = if !repeat_penalty_times.is_empty() {
repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

let avg_sampling_time = if !sampling_times.is_empty() {
sampling_times.iter().sum::<std::time::Duration>() / sampling_times.len() as u32
} else {
std::time::Duration::from_secs(0)
};

// Record detailed performance metrics
tracing::info!("{} Text generation completed in {:.2?}", prefix, generation_time);
tracing::info!("{} Tokens generated: {}", prefix, generated_tokens);
tracing::info!("{} Generation speed: {:.2} tokens/second", prefix, tokens_per_second);
tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time);

if !avg_token_time.is_zero() {
tracing::debug!("{} - Forward pass: {:.2?} ({:.1}%)",
prefix,
avg_forward_time,
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);
tracing::debug!("{} - Repeat penalty: {:.2?} ({:.1}%)",
prefix,
avg_repeat_time,
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);
tracing::debug!("{} - Sampling: {:.2?} ({:.1}%)",
prefix,
avg_sampling_time,
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
);
}

// Log total request time
let total_time = start_time.elapsed();
tracing::info!("{} Total request time: {:.2?}", prefix, total_time);

if !total_time.is_zero() {
tracing::debug!("{} - Tokenization: {:.2?} ({:.1}%)",
prefix,
tokenize_time,
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
);
tracing::debug!("{} - Generation: {:.2?} ({:.1}%)",
prefix,
generation_time,
generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
);
tracing::debug!("{} - Final decoding: {:.2?} ({:.1}%)",
prefix,
decode_time,
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
);
}
}
}
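As a closing illustration, both penalty paths in the diff above rely on the same tensor round-trip: read the logits with `to_vec1`, modify them in place, then rebuild with `Tensor::new` and `reshape`. The test sketch below exercises that round-trip on `Device::Cpu`; it is not one of the commit's own unit tests, and the module and test names are made up for illustration.

#[cfg(test)]
mod logits_roundtrip_sketch {
    use candle_core::{Device, Tensor};

    #[test]
    fn penalized_logits_keep_their_shape() -> anyhow::Result<()> {
        let device = Device::Cpu;
        let logits = Tensor::new(&[1.0f32, -2.0, 3.0, 4.0], &device)?;

        // Pull the values out, penalize one entry, and rebuild the tensor,
        // mirroring the pattern used by apply_cached_repeat_penalty.
        let mut values = logits.to_vec1::<f32>()?;
        values[2] /= 1.1;
        let rebuilt = Tensor::new(&values[..], &device)?.reshape(logits.shape().clone())?;

        assert_eq!(rebuilt.shape(), logits.shape());
        assert_eq!(rebuilt.dims(), logits.dims());
        Ok(())
    }
}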