use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
use std::collections::HashMap;

use crate::model::Model;
use crate::token_output_stream::TokenOutputStream;

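/// Token-by-token text generation on top of a Candle `Model`, with CPU
/// fallback, cached repeat penalties, and performance instrumentation.
///
/// A minimal usage sketch (hypothetical; assumes a `Model` and `Tokenizer`
/// have already been loaded elsewhere):
///
/// ```ignore
/// let mut generator = TextGeneration::new(
///     model,          // crate::model::Model
///     tokenizer,      // tokenizers::Tokenizer
///     42,             // RNG seed
///     Some(0.8),      // temperature (None for greedy sampling)
///     Some(0.95),     // top-p
///     1.1,            // repeat penalty
///     64,             // penalize the last 64 tokens
///     &Device::Cpu,
/// );
/// generator.run("Tell me a story about", 256)?;
/// generator.reset_state(); // clear per-request state before reusing
/// ```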
pub struct TextGeneration {
    model: Model,
    device: Device,
    // CPU device for fallback when operations are unsupported on primary device
    cpu_device: Option<Device>,
    // Flag to indicate if we should try to use the primary device first
    try_primary_device: bool,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    // Cache for repeat penalty computation to avoid redundant calculations
    penalty_cache: HashMap<usize, f32>,
    // Context window size for sliding window context (default: 64 tokens)
    context_window_size: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);

        // Initialize CPU device only if the primary device is not already CPU
        let (cpu_device, try_primary_device) = if device.is_cpu() {
            // If already on CPU, no need for a fallback device
            (None, false)
        } else {
            // Store CPU device for fallback and set flag to try primary device first
            (Some(Device::Cpu), true)
        };

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
            cpu_device,
            try_primary_device,
            penalty_cache: HashMap::new(),
            context_window_size: 64, // Default sliding window size for better context preservation
        }
    }

    // Helper method for model execution with fallback to CPU for unsupported operations
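    /// Runs a forward pass, preferring the configured primary device.
    ///
    /// If the primary device reports an unsupported operation (e.g. a missing
    /// Metal/CUDA kernel) or a shape mismatch, and a CPU fallback device is
    /// available, the input is moved to the CPU, the pass is retried there,
    /// and `try_primary_device` is latched to `false` so later calls skip the
    /// failing device. The result is moved back to `self.device` so callers
    /// always receive tensors on the primary device.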
    fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
        // If we're not trying primary device anymore, go straight to CPU if available
        if !self.try_primary_device {
            if let Some(cpu_device) = &self.cpu_device {
                let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
                let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;
                return cpu_result.to_device(&self.device).map_err(E::msg);
            } else {
                // No CPU fallback, use primary device
                return self.model.forward(input, start_pos).map_err(E::msg);
            }
        }

        // Try running on the primary device first
        match self.model.forward(input, start_pos) {
            Ok(result) => Ok(result),
            Err(err) => {
                // Convert to string to check for unsupported operation
                let err_string = err.to_string();

                // Check if the error is about unsupported operations or shape mismatches
                if (err_string.contains("no metal implementation for")
                    || err_string.contains("no cuda implementation for")
                    || err_string.contains("shape mismatch")
                    || err_string.contains("broadcast_add"))
                    && self.cpu_device.is_some()
                {
                    // Extract operation name for better logging
                    let op_name = if let Some(idx) = err_string.find("for ") {
                        &err_string[(idx + 4)..]
                    } else if err_string.contains("shape mismatch") {
                        "shape mismatch operation"
                    } else {
                        "an operation"
                    };

                    // Log the fallback
                    tracing::warn!(
                        "The primary device does not support {}. Falling back to CPU.",
                        op_name
                    );

                    // Move input to CPU and try again
                    let cpu_device = self.cpu_device.as_ref().unwrap();
                    let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
                    let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;

                    // Don't try primary device for future operations
                    self.try_primary_device = false;
                    tracing::info!(
                        "Successfully executed on CPU. Will use CPU for subsequent operations."
                    );

                    // Move result back to original device
                    cpu_result.to_device(&self.device).map_err(E::msg)
                } else {
                    // Not an unsupported operation error or no CPU fallback
                    Err(E::msg(err))
                }
            }
        }
    }

    // Reset method to clear state between requests
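    /// Clears per-request state: re-enables the primary device (when it is not
    /// the CPU) and empties the repeat-penalty cache so one request's cached
    /// scores cannot leak into the next.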
    pub fn reset_state(&mut self) {
        // Reset the primary device flag so we try the primary device first for each new request
        if !self.device.is_cpu() {
            self.try_primary_device = true;
        }
        // Clear the penalty cache to avoid stale cached values from previous requests
        self.penalty_cache.clear();
    }

    // Helper method to apply repeat penalty with caching for optimization
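    /// Applies the repeat penalty to `logits` for the last `repeat_last_n`
    /// tokens, caching each token's penalized score so repeated occurrences
    /// within a generation are only computed once. Returns the adjusted
    /// logits together with the time spent.
    ///
    /// A rough sketch of the intended effect (illustrative values only): with
    /// `repeat_penalty = 1.1`, a logit of `2.0` for a recently seen token
    /// becomes `2.0 / 1.1 ≈ 1.82`, and a logit of `-1.0` becomes
    /// `-1.0 * 1.1 = -1.1`, so already-generated tokens are always made
    /// less likely.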
    pub fn apply_cached_repeat_penalty(
        &mut self,
        logits: Tensor,
        tokens: &[u32],
    ) -> Result<(Tensor, std::time::Duration)> {
        let repeat_start = std::time::Instant::now();

        // If no penalty, return the original logits
        if self.repeat_penalty == 1.0 {
            return Ok((logits, repeat_start.elapsed()));
        }

        // Get the tokens to penalize (the last n tokens)
        let start_at = tokens.len().saturating_sub(self.repeat_last_n);
        let penalty_tokens = &tokens[start_at..];

        // Extract logits to a vector for modification
        let mut logits_vec = logits.to_vec1::<f32>()?;
        let cache_hits = std::cell::Cell::new(0);

        // Apply penalties with caching
        for &token_id in penalty_tokens {
            let token_id = token_id as usize;
            if token_id < logits_vec.len() {
                // Check if we've already calculated this token's penalty
                if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                    // Use cached value
                    logits_vec[token_id] = *penalized_score;
                    cache_hits.set(cache_hits.get() + 1);
                } else {
                    // Calculate and cache new value, following the conventional
                    // repeat-penalty rule: positive logits are divided by the
                    // penalty and negative logits are multiplied by it, so a
                    // penalized token always becomes less likely.
                    let score = logits_vec[token_id];
                    let penalized_score = if score < 0.0 {
                        score * self.repeat_penalty
                    } else {
                        score / self.repeat_penalty
                    };
                    logits_vec[token_id] = penalized_score;
                    self.penalty_cache.insert(token_id, penalized_score);
                }
            }
        }

        // Log cache efficiency statistics
        if !penalty_tokens.is_empty() {
            let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0;
            tracing::trace!(
                "Repeat penalty cache hits: {}/{} ({:.1}%)",
                cache_hits.get(),
                penalty_tokens.len(),
                cache_efficiency
            );
        }

        // Create a new tensor with the modified logits (single tensor creation)
        let device = logits.device().clone();
        let shape = logits.shape().clone();
        let new_logits = Tensor::new(&logits_vec[..], &device)?;
        let result = new_logits.reshape(shape)?;

        let elapsed = repeat_start.elapsed();
        Ok((result, elapsed))
    }

    // Run text generation and print to stdout
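    /// Generates up to `sample_len` tokens from `prompt`, streaming the
    /// decoded text to stdout and logging timing breakdowns via `tracing`.
    ///
    /// A minimal sketch (hypothetical CLI-style call):
    ///
    /// ```ignore
    /// generator.run("The quick brown fox", 128)?;
    /// ```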
    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;

        // Track overall performance
        let start_time = std::time::Instant::now();

        // Keep penalty cache across generation for better repetition prevention
        // Only clear cache if it becomes too large to prevent memory bloat
        if self.penalty_cache.len() > 10000 {
            self.penalty_cache.clear();
            tracing::debug!("Cleared penalty cache due to size limit");
        } else {
            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
        }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let tokenize_time = tokenize_start.elapsed();
        tracing::debug!("Tokenization completed in {:.2?}", tokenize_time);
        tracing::debug!("Input tokens: {}", tokens.len());

        // Print tokenized prompt
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                println!(
                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
                );
                eos_token
            }
        };

        // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
        // Both need special handling for shape compatibility
        let needs_special_handling = match &self.model {
            Model::V2(_) => true,
            Model::V3(_) => true,
            _ => false,
        };

        // Phase 2: Text generation
        let start_gen = std::time::Instant::now();

        // Track per-token generation timing for performance analysis
        let mut token_times = Vec::new();
        let mut forward_times = Vec::new();
        let mut repeat_penalty_times = Vec::new();
        let mut sampling_times = Vec::new();

        // For Model2 and Model3, we need to use a special approach for shape compatibility
        if needs_special_handling {
            // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context
            tracing::debug!("Using special generation approach for gemma-2/gemma-3 models");

            // Initial generation with the full prompt
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

            // Use execute_with_fallback which handles both device compatibility and shape mismatches
            let mut logits = self.execute_with_fallback(&input, 0)?;

            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let forward_time = forward_start.elapsed();
            forward_times.push(forward_time);

            for _ in 0..sample_len {
                let token_start = std::time::Instant::now();

                // Apply repeat penalty using optimized cached implementation
                let (current_logits, repeat_time) =
                    self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&current_logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;

                if next_token == eos_token || next_token == eot_token {
                    break;
                }

                if let Some(t) = self.tokenizer.next_token(next_token)? {
                    print!("{t}");
                    std::io::stdout().flush()?;
                }

                // For the next iteration, just use the new token
                let forward_start = std::time::Instant::now();
                let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;

                // Use execute_with_fallback for both Gemma 3 and other models
                logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;

                logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);

                let token_time = token_start.elapsed();
                token_times.push(token_time);
            }
        } else {
            // Standard approach for other models
            tracing::debug!("Using standard generation approach");

            for index in 0..sample_len {
                let token_start = std::time::Instant::now();

                let context_size = if index > 0 { 1 } else { tokens.len() };
                let start_pos = tokens.len().saturating_sub(context_size);
                let ctxt = &tokens[start_pos..];

                // Track tensor operations and model forward pass
                let forward_start = std::time::Instant::now();
                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
                let logits = self.execute_with_fallback(&input, start_pos)?;
                let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);

                // Apply repeat penalty using optimized cached implementation
                let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;
                if next_token == eos_token || next_token == eot_token {
                    break;
                }
                if let Some(t) = self.tokenizer.next_token(next_token)? {
                    print!("{t}");
                    std::io::stdout().flush()?;
                }

                let token_time = token_start.elapsed();
                token_times.push(token_time);
            }
        }

        let dt = start_gen.elapsed();

        // Phase 3: Final decoding and output
        let decode_start = std::time::Instant::now();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        let decode_time = decode_start.elapsed();

        std::io::stdout().flush()?;

        // Calculate generation speed
        let tokens_per_second = generated_tokens as f64 / dt.as_secs_f64();

        // Calculate average time per token and component breakdown
        let avg_token_time = if !token_times.is_empty() {
            token_times.iter().sum::<std::time::Duration>() / token_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_forward_time = if !forward_times.is_empty() {
            forward_times.iter().sum::<std::time::Duration>() / forward_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_repeat_time = if !repeat_penalty_times.is_empty() {
            repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_sampling_time = if !sampling_times.is_empty() {
            sampling_times.iter().sum::<std::time::Duration>() / sampling_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        // Log performance metrics
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            tokens_per_second,
        );

        // Record detailed performance metrics
        tracing::info!("Text generation completed in {:.2?}", dt);
        tracing::info!("Tokens generated: {}", generated_tokens);
        tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second);
        tracing::info!("Average time per token: {:.2?}", avg_token_time);
        tracing::debug!(
            " - Forward pass: {:.2?} ({:.1}%)",
            avg_forward_time,
            avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
        );
        tracing::debug!(
            " - Repeat penalty: {:.2?} ({:.1}%)",
            avg_repeat_time,
            avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
        );
        tracing::debug!(
            " - Sampling: {:.2?} ({:.1}%)",
            avg_sampling_time,
            avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
        );

        // Log total request time
        let total_time = start_time.elapsed();
        tracing::info!("Total request time: {:.2?}", total_time);
        tracing::debug!(
            " - Tokenization: {:.2?} ({:.1}%)",
            tokenize_time,
            tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        tracing::debug!(
            " - Generation: {:.2?} ({:.1}%)",
            dt,
            dt.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        tracing::debug!(
            " - Final decoding: {:.2?} ({:.1}%)",
            decode_time,
            decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );

        Ok(())
    }

    // Run text generation and write to a buffer
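    /// Same generation loop as [`run`](Self::run), but writes the decoded
    /// text into `output` instead of stdout (logged as "API" mode).
    ///
    /// A minimal sketch (hypothetical caller):
    ///
    /// ```ignore
    /// let mut buf = Vec::new();
    /// generator.run_with_output("Summarize this:", 256, &mut buf)?;
    /// let text = String::from_utf8_lossy(&buf);
    /// ```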
    pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
        use std::io::Write;

        // Track overall performance
        let start_time = std::time::Instant::now();

        // Keep penalty cache across generation for better repetition prevention
        // Only clear cache if it becomes too large to prevent memory bloat
        if self.penalty_cache.len() > 10000 {
            self.penalty_cache.clear();
            tracing::debug!("Cleared penalty cache due to size limit (API mode)");
        } else {
            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)");
        }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let tokenize_time = tokenize_start.elapsed();
        tracing::debug!("API Tokenization completed in {:.2?}", tokenize_time);
        tracing::debug!("API Input tokens: {}", tokens.len());

        // Write prompt tokens to output
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                write!(output, "{}", t)?;
            }
        }

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
                eos_token
            }
        };

        // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
        // Both need special handling for shape compatibility
        let needs_special_handling = match &self.model {
            Model::V2(_) => true,
            Model::V3(_) => true,
            _ => false,
        };

        // Check if we're specifically using a Model3 (gemma-3) for additional error handling
        // let is_model_v3 = matches!(&self.model, Model::V3(_));

        // Track generation timing
        let start_gen = std::time::Instant::now();

        // Track per-token generation timing for performance analysis
        let mut token_times = Vec::new();
        let mut forward_times = Vec::new();
        let mut repeat_penalty_times = Vec::new();
        let mut sampling_times = Vec::new();

        // For Model2 and Model3, we need to use a special approach for shape compatibility
        if needs_special_handling {
            // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context
            tracing::debug!("Using special generation approach for gemma-2/gemma-3 models");

            // Initial generation with the full prompt
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

            // Use execute_with_fallback which handles both device compatibility and shape mismatches
            let mut logits = self.execute_with_fallback(&input, 0)?;

            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let forward_time = forward_start.elapsed();
            forward_times.push(forward_time);

            for _ in 0..sample_len {
                let token_start = std::time::Instant::now();

                // Apply repeat penalty using optimized cached implementation
                let (current_logits, repeat_time) =
                    self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&current_logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;

                if next_token == eos_token || next_token == eot_token {
                    break;
                }

                if let Some(t) = self.tokenizer.next_token(next_token)? {
                    write!(output, "{}", t)?;
                }

                // For the next iteration, just use the new token
                let forward_start = std::time::Instant::now();
                let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;

                // Use execute_with_fallback for both Gemma 3 and other models
                logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;

                logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);

                let token_time = token_start.elapsed();
                token_times.push(token_time);
            }

            let dt = start_gen.elapsed();

            // Calculate and log performance metrics
            Self::log_performance_metrics(
                dt, generated_tokens, &token_times, &forward_times,
                &repeat_penalty_times, &sampling_times, tokenize_time,
                std::time::Duration::from_secs(0), start_time, "API"
            );

            return Ok(());
        }

        // Standard approach for other models
        tracing::debug!("Using standard generation approach");

        for index in 0..sample_len {
            let token_start = std::time::Instant::now();

            // Use sliding window context instead of single token to preserve context and reduce repetition
            let context_size = if index > 0 {
                std::cmp::min(self.context_window_size, tokens.len())
            } else {
                tokens.len()
            };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];

            tracing::debug!(
                "API standard model: Using sliding window context: {} tokens (from position {})",
                ctxt.len(),
                start_pos
            );

            // Track tensor operations and model forward pass
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.execute_with_fallback(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let forward_time = forward_start.elapsed();
            forward_times.push(forward_time);

            // Apply repeat penalty using optimized cached implementation
            let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
            repeat_penalty_times.push(repeat_time);

            // Track token sampling
            let sampling_start = std::time::Instant::now();
            let next_token = self.logits_processor.sample(&logits)?;
            let sampling_time = sampling_start.elapsed();
            sampling_times.push(sampling_time);

            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                write!(output, "{}", t)?;
            }

            let token_time = token_start.elapsed();
            token_times.push(token_time);
        }

        let dt = start_gen.elapsed();

        // Phase 3: Final decoding and output
        let decode_start = std::time::Instant::now();

        // Write any remaining tokens
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            write!(output, "{}", rest)?;
        }

        let decode_time = decode_start.elapsed();

        // Log performance metrics
        Self::log_performance_metrics(
            dt, generated_tokens, &token_times, &forward_times,
            &repeat_penalty_times, &sampling_times, tokenize_time,
            decode_time, start_time, "API"
        );

        Ok(())
    }

    // Run text generation with streaming callback for each token
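    /// Async variant of the generation loop that invokes `token_callback`
    /// with each decoded token as soon as it is produced, yielding to the
    /// tokio runtime between tokens, and returns the full generated string.
    ///
    /// A minimal sketch (hypothetical; assumes a tokio runtime and an
    /// SSE-like sender named `tx`):
    ///
    /// ```ignore
    /// let full_text = generator
    ///     .run_with_streaming("Hello", 256, |token| {
    ///         tx.send(token.to_string()).map_err(anyhow::Error::msg)
    ///     })
    ///     .await?;
    /// ```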
    pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
    where
        F: FnMut(&str) -> Result<()>,
    {
        // Track overall performance
        let start_time = std::time::Instant::now();

        // Keep penalty cache across generation for better repetition prevention
        // Only clear cache if it becomes too large to prevent memory bloat
        if self.penalty_cache.len() > 10000 {
            self.penalty_cache.clear();
            tracing::debug!("Cleared penalty cache due to size limit (streaming mode)");
        } else {
            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)");
        }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let tokenize_time = tokenize_start.elapsed();
        tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time);
        tracing::debug!("Streaming Input tokens: {}", tokens.len());

        // Collect all output for final return
        let mut full_output = String::new();

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
                eos_token
            }
        };

        // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
        let needs_special_handling = match &self.model {
            Model::V2(_) => true,
            Model::V3(_) => true,
            _ => false,
        };

        // Track generation timing
        let start_gen = std::time::Instant::now();

        // Track per-token generation timing for performance analysis
        let mut token_times = Vec::new();
        let mut forward_times = Vec::new();
        let mut repeat_penalty_times = Vec::new();
        let mut sampling_times = Vec::new();

        // For Model2 and Model3, we need to use a special approach for shape compatibility
        if needs_special_handling {
            tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
            tracing::debug!("Streaming: sample_len = {}", sample_len);

            // Initial generation with the full prompt
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

            let mut logits = self.execute_with_fallback(&input, 0)?;

            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let forward_time = forward_start.elapsed();
            forward_times.push(forward_time);

            tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
            for gen_index in 0..sample_len {
                tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
                let token_start = std::time::Instant::now();

                // Apply repeat penalty using optimized cached implementation
                let (current_logits, repeat_time) =
                    self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&current_logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;

                tracing::debug!(
                    "Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
                    next_token, next_token, eos_token, eot_token
                );
                if next_token == eos_token || next_token == eot_token {
                    tracing::debug!("Streaming: Breaking due to end token");
                    break;
                }

                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
                    full_output.push_str(&token_text);
                    // Call the streaming callback with this token
                    token_callback(&token_text)?;
                }

                // For the next iteration, use single token to avoid shape mismatch
                let forward_start = std::time::Instant::now();
                tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());

                // Use just the last token for subsequent iterations to avoid shape mismatch
                // This is required for Gemma model's attention mechanism compatibility
                let context_tokens = &tokens[(tokens.len() - 1)..];
                let start_pos = tokens.len() - 1;

                tracing::debug!(
                    "Streaming: Using single token context for Gemma: {} tokens (from position {})",
                    context_tokens.len(),
                    start_pos
                );

                let new_input = match Tensor::new(context_tokens, &self.device) {
                    Ok(tensor) => tensor,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to create input tensor: {}", e);
                        return Err(e.into());
                    }
                };

                let new_input = match new_input.unsqueeze(0) {
                    Ok(tensor) => tensor,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e);
                        return Err(e.into());
                    }
                };

                tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos);
                logits = match self.execute_with_fallback(&new_input, start_pos) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Forward pass failed: {}", e);
                        return Err(e);
                    }
                };

                logits = match logits.squeeze(0) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e);
                        return Err(e.into());
                    }
                };

                logits = match logits.squeeze(0) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e);
                        return Err(e.into());
                    }
                };

                logits = match logits.to_dtype(DType::F32) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to convert logits to F32: {}", e);
                        return Err(e.into());
                    }
                };

                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);
                tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);

                let token_time = token_start.elapsed();
                token_times.push(token_time);

                // Yield to allow other async tasks to run
                tokio::task::yield_now().await;
            }
        } else {
            // Standard approach for other models
            tracing::debug!("Using standard generation approach (streaming)");

            for index in 0..sample_len {
                let token_start = std::time::Instant::now();

                // Use sliding window context instead of single token to preserve context and reduce repetition
                let context_size = if index > 0 {
                    std::cmp::min(self.context_window_size, tokens.len())
                } else {
                    tokens.len()
                };
                let start_pos = tokens.len().saturating_sub(context_size);
                let ctxt = &tokens[start_pos..];

                tracing::debug!(
                    "Standard model: Using sliding window context: {} tokens (from position {})",
                    ctxt.len(),
                    start_pos
                );

                // Track tensor operations and model forward pass
                let forward_start = std::time::Instant::now();
                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
                let logits = self.execute_with_fallback(&input, start_pos)?;
                let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);

                // Apply repeat penalty using optimized cached implementation
                let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;
                if next_token == eos_token || next_token == eot_token {
                    break;
                }
                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
                    full_output.push_str(&token_text);
                    // Call the streaming callback with this token
                    token_callback(&token_text)?;
                }

                let token_time = token_start.elapsed();
                token_times.push(token_time);
            }
        }

        let dt = start_gen.elapsed();

        // Phase 3: Final decoding
        let decode_start = std::time::Instant::now();

        // Decode any remaining tokens but don't send through callback to avoid repetition
        // The tokens were already streamed individually in the generation loop above
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            full_output.push_str(&rest);
            // Note: NOT calling token_callback(&rest) here to prevent token repetition
            // Individual tokens were already streamed via the callback in the generation loop
        }

        let decode_time = decode_start.elapsed();

        // Log performance metrics
        Self::log_performance_metrics(
            dt, generated_tokens, &token_times, &forward_times,
            &repeat_penalty_times, &sampling_times, tokenize_time,
            decode_time, start_time, "Streaming"
        );

        Ok(full_output)
    }

    // Helper function for logging performance metrics
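    /// Logs the per-phase timing breakdown (tokenization, forward pass,
    /// repeat penalty, sampling, decoding) via `tracing`, prefixed with the
    /// calling mode (e.g. "API" or "Streaming"), and skips the percentage
    /// breakdowns when the corresponding totals are zero.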
    fn log_performance_metrics(
        generation_time: std::time::Duration,
        generated_tokens: usize,
        token_times: &[std::time::Duration],
        forward_times: &[std::time::Duration],
        repeat_penalty_times: &[std::time::Duration],
        sampling_times: &[std::time::Duration],
        tokenize_time: std::time::Duration,
        decode_time: std::time::Duration,
        start_time: std::time::Instant,
        prefix: &str,
    ) {
        // Calculate generation speed
        let tokens_per_second = if generation_time.as_secs_f64() > 0.0 {
            generated_tokens as f64 / generation_time.as_secs_f64()
        } else {
            0.0
        };

        // Calculate average time per token and component breakdown
        let avg_token_time = if !token_times.is_empty() {
            token_times.iter().sum::<std::time::Duration>() / token_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_forward_time = if !forward_times.is_empty() {
            forward_times.iter().sum::<std::time::Duration>() / forward_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_repeat_time = if !repeat_penalty_times.is_empty() {
            repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        let avg_sampling_time = if !sampling_times.is_empty() {
            sampling_times.iter().sum::<std::time::Duration>() / sampling_times.len() as u32
        } else {
            std::time::Duration::from_secs(0)
        };

        // Record detailed performance metrics
        tracing::info!("{} Text generation completed in {:.2?}", prefix, generation_time);
        tracing::info!("{} Tokens generated: {}", prefix, generated_tokens);
        tracing::info!("{} Generation speed: {:.2} tokens/second", prefix, tokens_per_second);
        tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time);

        if !avg_token_time.is_zero() {
            tracing::debug!(
                "{} - Forward pass: {:.2?} ({:.1}%)",
                prefix,
                avg_forward_time,
                avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
            );
            tracing::debug!(
                "{} - Repeat penalty: {:.2?} ({:.1}%)",
                prefix,
                avg_repeat_time,
                avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
            );
            tracing::debug!(
                "{} - Sampling: {:.2?} ({:.1}%)",
                prefix,
                avg_sampling_time,
                avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
            );
        }

        // Log total request time
        let total_time = start_time.elapsed();
        tracing::info!("{} Total request time: {:.2?}", prefix, total_time);

        if !total_time.is_zero() {
            tracing::debug!(
                "{} - Tokenization: {:.2?} ({:.1}%)",
                prefix,
                tokenize_time,
                tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
            );
            tracing::debug!(
                "{} - Generation: {:.2?} ({:.1}%)",
                prefix,
                generation_time,
                generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
            );
            tracing::debug!(
                "{} - Final decoding: {:.2?} ({:.1}%)",
                prefix,
                decode_time,
                decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
            );
        }
    }
}