Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

- Change default server host to localhost for improved security.
- Increase default maximum tokens in CLI configuration to 256.
- Refactor and reorganize CLI
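The central behavioral change in this diff is the move from a single-token context to a sliding-window context during incremental decoding. As a rough illustration (a minimal sketch, not the crate's code; `select_context`, `tokens`, and `window_size` are illustrative names), the selection logic the hunks below introduce looks like this:

fn select_context(tokens: &[u32], index: usize, window_size: usize) -> (usize, &[u32]) {
    // First iteration: feed the whole prompt. Later iterations: feed at most
    // `window_size` trailing tokens instead of only the last token.
    let context_size = if index > 0 {
        std::cmp::min(window_size, tokens.len())
    } else {
        tokens.len()
    };
    let start_pos = tokens.len().saturating_sub(context_size);
    (start_pos, &tokens[start_pos..])
}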
@@ -20,6 +20,8 @@ pub struct TextGeneration {
     repeat_last_n: usize,
     // Cache for repeat penalty computation to avoid redundant calculations
     penalty_cache: HashMap<usize, f32>,
+    // Context window size for sliding window context (default: 64 tokens)
+    context_window_size: usize,
 }
 
 impl TextGeneration {
@@ -55,6 +57,7 @@ impl TextGeneration {
             cpu_device,
             try_primary_device,
             penalty_cache: HashMap::new(),
+            context_window_size: 64, // Default sliding window size for better context preservation
         }
     }
@@ -192,9 +195,14 @@ impl TextGeneration {
         // Track overall performance
         let start_time = std::time::Instant::now();
 
-        // Clear penalty cache for new generation
-        self.penalty_cache.clear();
-        tracing::debug!("Cleared penalty cache for new generation");
+        // Keep penalty cache across generation for better repetition prevention
+        // Only clear cache if it becomes too large to prevent memory bloat
+        if self.penalty_cache.len() > 10000 {
+            self.penalty_cache.clear();
+            tracing::debug!("Cleared penalty cache due to size limit");
+        } else {
+            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
+        }
 
         // Phase 1: Tokenize input
         let tokenize_start = std::time::Instant::now();
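The same retention policy recurs in the API and streaming paths further down. Factored out, it might look like the following sketch (the helper name and the threshold parameter are illustrative, not part of the crate):

use std::collections::HashMap;

// Sketch: keep the penalty cache across generations and clear it only once it
// grows past `max_entries`, mirroring the inline blocks in this diff.
fn trim_penalty_cache(cache: &mut HashMap<usize, f32>, max_entries: usize) {
    if cache.len() > max_entries {
        cache.clear();
        tracing::debug!("Cleared penalty cache due to size limit");
    } else {
        tracing::debug!("Maintaining penalty cache across generation");
    }
}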
@@ -440,9 +448,14 @@ impl TextGeneration {
         // Track overall performance
         let start_time = std::time::Instant::now();
 
-        // Clear penalty cache for new generation
-        self.penalty_cache.clear();
-        tracing::debug!("Cleared penalty cache for new generation (API mode)");
+        // Keep penalty cache across generation for better repetition prevention
+        // Only clear cache if it becomes too large to prevent memory bloat
+        if self.penalty_cache.len() > 10000 {
+            self.penalty_cache.clear();
+            tracing::debug!("Cleared penalty cache due to size limit (API mode)");
+        } else {
+            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)");
+        }
 
         // Phase 1: Tokenize input
         let tokenize_start = std::time::Instant::now();
@@ -573,10 +586,18 @@ impl TextGeneration {
             for index in 0..sample_len {
                 let token_start = std::time::Instant::now();
 
-                let context_size = if index > 0 { 1 } else { tokens.len() };
+                // Use sliding window context instead of single token to preserve context and reduce repetition
+                let context_size = if index > 0 {
+                    std::cmp::min(self.context_window_size, tokens.len())
+                } else {
+                    tokens.len()
+                };
                 let start_pos = tokens.len().saturating_sub(context_size);
                 let ctxt = &tokens[start_pos..];
+
+                tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})",
+                    ctxt.len(), start_pos);
 
                 // Track tensor operations and model forward pass
                 let forward_start = std::time::Instant::now();
                 let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
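A quick numeric check of the window arithmetic in the hunk above, using assumed values (100 tokens already in the buffer, the default window of 64):

fn main() {
    let tokens_len: usize = 100;
    let context_window_size: usize = 64;
    let context_size = std::cmp::min(context_window_size, tokens_len); // 64
    let start_pos = tokens_len.saturating_sub(context_size); // 100 - 64 = 36
    assert_eq!((context_size, start_pos), (64, 36));

    // Early in generation, when fewer tokens exist than the window size,
    // the window simply covers everything and starts at 0.
    let short_len: usize = 10;
    let short_ctx = std::cmp::min(context_window_size, short_len); // 10
    assert_eq!(short_len.saturating_sub(short_ctx), 0);
}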
@@ -629,6 +650,266 @@ impl TextGeneration {
 
         Ok(())
     }
+
+    // Run text generation with streaming callback for each token
+    pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
+    where
+        F: FnMut(&str) -> Result<()>,
+    {
+        // Track overall performance
+        let start_time = std::time::Instant::now();
+
+        // Keep penalty cache across generation for better repetition prevention
+        // Only clear cache if it becomes too large to prevent memory bloat
+        if self.penalty_cache.len() > 10000 {
+            self.penalty_cache.clear();
+            tracing::debug!("Cleared penalty cache due to size limit (streaming mode)");
+        } else {
+            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)");
+        }
+
+        // Phase 1: Tokenize input
+        let tokenize_start = std::time::Instant::now();
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let tokenize_time = tokenize_start.elapsed();
+        tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time);
+        tracing::debug!("Streaming Input tokens: {}", tokens.len());
+
+        // Collect all output for final return
+        let mut full_output = String::new();
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
+                eos_token
+            }
+        };
+
+        // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
+        let needs_special_handling = match &self.model {
+            Model::V2(_) => true,
+            Model::V3(_) => true,
+            _ => false,
+        };
+
+        // Track generation timing
+        let start_gen = std::time::Instant::now();
+
+        // Track per-token generation timing for performance analysis
+        let mut token_times = Vec::new();
+        let mut forward_times = Vec::new();
+        let mut repeat_penalty_times = Vec::new();
+        let mut sampling_times = Vec::new();
+
+        // For Model2 and Model3, we need to use a special approach for shape compatibility
+        if needs_special_handling {
+            tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
+            tracing::debug!("Streaming: sample_len = {}", sample_len);
+
+            // Initial generation with the full prompt
+            let forward_start = std::time::Instant::now();
+            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+
+            let mut logits = self.execute_with_fallback(&input, 0)?;
+
+            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let forward_time = forward_start.elapsed();
+            forward_times.push(forward_time);
+
+            tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
+            for gen_index in 0..sample_len {
+                tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
+                let token_start = std::time::Instant::now();
+
+                // Apply repeat penalty using optimized cached implementation
+                let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
+                repeat_penalty_times.push(repeat_time);
+
+                // Track token sampling
+                let sampling_start = std::time::Instant::now();
+                let next_token = self.logits_processor.sample(&current_logits)?;
+                let sampling_time = sampling_start.elapsed();
+                sampling_times.push(sampling_time);
+
+                tokens.push(next_token);
+                generated_tokens += 1;
+
+                tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
+                    next_token, next_token, eos_token, eot_token);
+                if next_token == eos_token || next_token == eot_token {
+                    tracing::debug!("Streaming: Breaking due to end token");
+                    break;
+                }
+
+                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
+                    full_output.push_str(&token_text);
+                    // Call the streaming callback with this token
+                    token_callback(&token_text)?;
+                }
+
+                // For the next iteration, use single token to avoid shape mismatch
+                let forward_start = std::time::Instant::now();
+                tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());
+
+                // Use just the last token for subsequent iterations to avoid shape mismatch
+                // This is required for Gemma model's attention mechanism compatibility
+                let context_tokens = &tokens[(tokens.len()-1)..];
+                let start_pos = tokens.len() - 1;
+
+                tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})",
+                    context_tokens.len(), start_pos);
+
+                let new_input = match Tensor::new(context_tokens, &self.device) {
+                    Ok(tensor) => tensor,
+                    Err(e) => {
+                        tracing::error!("Streaming: Failed to create input tensor: {}", e);
+                        return Err(e.into());
+                    }
+                };
+
+                let new_input = match new_input.unsqueeze(0) {
+                    Ok(tensor) => tensor,
+                    Err(e) => {
+                        tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e);
+                        return Err(e.into());
+                    }
+                };
+
+                tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos);
+                logits = match self.execute_with_fallback(&new_input, start_pos) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        tracing::error!("Streaming: Forward pass failed: {}", e);
+                        return Err(e);
+                    }
+                };
+
+                logits = match logits.squeeze(0) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e);
+                        return Err(e.into());
+                    }
+                };
+
+                logits = match logits.squeeze(0) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e);
+                        return Err(e.into());
+                    }
+                };
+
+                logits = match logits.to_dtype(DType::F32) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        tracing::error!("Streaming: Failed to convert logits to F32: {}", e);
+                        return Err(e.into());
+                    }
+                };
+
+                let forward_time = forward_start.elapsed();
+                forward_times.push(forward_time);
+                tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);
+
+                let token_time = token_start.elapsed();
+                token_times.push(token_time);
+
+                // Yield to allow other async tasks to run
+                tokio::task::yield_now().await;
+            }
+        } else {
+            // Standard approach for other models
+            tracing::debug!("Using standard generation approach (streaming)");
+
+            for index in 0..sample_len {
+                let token_start = std::time::Instant::now();
+
+                // Use sliding window context instead of single token to preserve context and reduce repetition
+                let context_size = if index > 0 {
+                    std::cmp::min(self.context_window_size, tokens.len())
+                } else {
+                    tokens.len()
+                };
+                let start_pos = tokens.len().saturating_sub(context_size);
+                let ctxt = &tokens[start_pos..];
+
+                tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})",
+                    ctxt.len(), start_pos);
+
+                // Track tensor operations and model forward pass
+                let forward_start = std::time::Instant::now();
+                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+                let logits = self.execute_with_fallback(&input, start_pos)?;
+                let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+                let forward_time = forward_start.elapsed();
+                forward_times.push(forward_time);
+
+                // Apply repeat penalty using optimized cached implementation
+                let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
+                repeat_penalty_times.push(repeat_time);
+
+                // Track token sampling
+                let sampling_start = std::time::Instant::now();
+                let next_token = self.logits_processor.sample(&logits)?;
+                let sampling_time = sampling_start.elapsed();
+                sampling_times.push(sampling_time);
+
+                tokens.push(next_token);
+                generated_tokens += 1;
+                if next_token == eos_token || next_token == eot_token {
+                    break;
+                }
+                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
+                    full_output.push_str(&token_text);
+                    // Call the streaming callback with this token
+                    token_callback(&token_text)?;
+                }
+
+                let token_time = token_start.elapsed();
+                token_times.push(token_time);
+            }
+        }
+
+        let dt = start_gen.elapsed();
+
+        // Phase 3: Final decoding
+        let decode_start = std::time::Instant::now();
+
+        // Decode any remaining tokens but don't send through callback to avoid repetition
+        // The tokens were already streamed individually in the generation loop above
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            full_output.push_str(&rest);
+            // Note: NOT calling token_callback(&rest) here to prevent token repetition
+            // Individual tokens were already streamed via the callback in the generation loop
+        }
+
+        let decode_time = decode_start.elapsed();
+
+        // Log performance metrics
+        Self::log_performance_metrics(
+            dt, generated_tokens, &token_times, &forward_times,
+            &repeat_penalty_times, &sampling_times, tokenize_time,
+            decode_time, start_time, "Streaming"
+        );
+
+        Ok(full_output)
+    }
 
     // Helper function for logging performance metrics
     fn log_performance_metrics(
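Both generation paths above call self.apply_cached_repeat_penalty, whose body is not part of this hunk. The following is only a sketch of how a HashMap-cached repeat penalty over plain f32 logits could work; it is not the crate's implementation, and it reuses the first penalized value computed for a token id, trading exactness for speed in the same spirit as the cache-retention comments above:

use std::collections::HashMap;
use std::time::{Duration, Instant};

// Sketch: penalize logits of already-seen tokens, memoizing per-token results
// so repeated occurrences of a token id are not recomputed.
fn cached_repeat_penalty(
    logits: &mut [f32],
    recent_tokens: &[u32],
    penalty: f32,
    cache: &mut HashMap<usize, f32>,
) -> Duration {
    let start = Instant::now();
    for &tok in recent_tokens {
        let idx = tok as usize;
        if idx >= logits.len() {
            continue;
        }
        let penalized = *cache.entry(idx).or_insert_with(|| {
            let score = logits[idx];
            // Common repetition-penalty rule: shrink positive logits,
            // push negative logits further down.
            if score < 0.0 { score * penalty } else { score / penalty }
        });
        logits[idx] = penalized;
    }
    start.elapsed()
}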
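For completeness, a hypothetical call site for the new streaming API added in this commit. It assumes a fully constructed TextGeneration value named text_gen (how one is built is outside this diff); the sample_len of 256 echoes the new CLI default mentioned in the commit message:

use std::io::Write;

async fn stream_demo(mut text_gen: TextGeneration) -> anyhow::Result<()> {
    let full = text_gen
        .run_with_streaming("Write a haiku about Rust.", 256, |token| {
            // Each decoded token arrives here as soon as it is sampled.
            print!("{token}");
            std::io::stdout().flush()?;
            Ok(())
        })
        .await?;
    println!("\n--- complete output ({} chars) ---\n{}", full.len(), full);
    Ok(())
}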