- Change the default server host to localhost for improved security.
- Increase the default maximum tokens in the CLI configuration to 256 (a hedged sketch follows below).
- Refactor and reorganize the CLI.
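For orientation, a minimal sketch of what the two changed defaults might look like in a clap-based CLI; the struct, field, and flag names are hypothetical, not taken from this repository:

```rust
use clap::Parser;

/// Hypothetical argument struct; names are illustrative only.
#[derive(Parser, Debug)]
struct ServerArgs {
    /// Bind host; the new default is the loopback host instead of 0.0.0.0.
    #[arg(long, default_value = "localhost")]
    host: String,

    /// Maximum tokens to generate per request; the new default is 256.
    #[arg(long, default_value_t = 256)]
    max_tokens: usize,
}

fn main() {
    let args = ServerArgs::parse();
    println!("binding to {} with max_tokens={}", args.host, args.max_tokens);
}
```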
geoffsee committed 2025-08-27 21:47:24 -04:00
parent 766d41af78
commit 719beb3791
20 changed files with 1703 additions and 490 deletions

@@ -20,6 +20,8 @@ pub struct TextGeneration {
repeat_last_n: usize,
// Cache for repeat penalty computation to avoid redundant calculations
penalty_cache: HashMap<usize, f32>,
+// Context window size for sliding window context (default: 64 tokens)
+context_window_size: usize,
}
impl TextGeneration {
@@ -55,6 +57,7 @@ impl TextGeneration {
cpu_device,
try_primary_device,
penalty_cache: HashMap::new(),
+context_window_size: 64, // Default sliding window size for better context preservation
}
}
@@ -192,9 +195,14 @@ impl TextGeneration {
// Track overall performance
let start_time = std::time::Instant::now();
-// Clear penalty cache for new generation
-self.penalty_cache.clear();
-tracing::debug!("Cleared penalty cache for new generation");
+// Keep the penalty cache across generations for better repetition prevention.
+// Only clear the cache if it grows too large, to prevent memory bloat.
+if self.penalty_cache.len() > 10000 {
+self.penalty_cache.clear();
+tracing::debug!("Cleared penalty cache due to size limit");
+} else {
+tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
+}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
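The hunk above replaces the unconditional clear with a size-capped policy. As a standalone sketch of that pattern (the 10000 threshold mirrors the diff; the function name is illustrative):

```rust
use std::collections::HashMap;

/// Size-capped cache policy: keep entries across generations and clear
/// only when the map grows past a fixed threshold, bounding memory use.
const PENALTY_CACHE_LIMIT: usize = 10_000;

fn maybe_clear_penalty_cache(cache: &mut HashMap<usize, f32>) {
    if cache.len() > PENALTY_CACHE_LIMIT {
        cache.clear(); // entries are recomputed on demand after a clear
    }
}
```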
@@ -440,9 +448,14 @@ impl TextGeneration {
// Track overall performance
let start_time = std::time::Instant::now();
-// Clear penalty cache for new generation
-self.penalty_cache.clear();
-tracing::debug!("Cleared penalty cache for new generation (API mode)");
+// Keep the penalty cache across generations for better repetition prevention.
+// Only clear the cache if it grows too large, to prevent memory bloat.
+if self.penalty_cache.len() > 10000 {
+self.penalty_cache.clear();
+tracing::debug!("Cleared penalty cache due to size limit (API mode)");
+} else {
+tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)");
+}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
@@ -573,10 +586,18 @@ impl TextGeneration {
for index in 0..sample_len {
let token_start = std::time::Instant::now();
-let context_size = if index > 0 { 1 } else { tokens.len() };
+// Use sliding window context instead of single token to preserve context and reduce repetition
+let context_size = if index > 0 {
+std::cmp::min(self.context_window_size, tokens.len())
+} else {
+tokens.len()
+};
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})",
ctxt.len(), start_pos);
// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
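After the first step, the selection above reduces to taking at most `context_window_size` trailing tokens (64 by default, per the constructor change). A self-contained sketch of the same logic, with illustrative names:

```rust
/// Sliding-window context selection: the full prompt on the first step,
/// then at most `window` trailing tokens on subsequent steps.
fn context_slice(tokens: &[u32], index: usize, window: usize) -> &[u32] {
    let context_size = if index > 0 {
        window.min(tokens.len())
    } else {
        tokens.len()
    };
    let start_pos = tokens.len().saturating_sub(context_size);
    &tokens[start_pos..]
}
```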
@@ -629,6 +650,266 @@ impl TextGeneration {
Ok(())
}
// Run text generation with streaming callback for each token
pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
where
F: FnMut(&str) -> Result<()>,
{
// Track overall performance
let start_time = std::time::Instant::now();
// Keep the penalty cache across generations for better repetition prevention.
// Only clear the cache if it grows too large, to prevent memory bloat.
if self.penalty_cache.len() > 10000 {
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache due to size limit (streaming mode)");
} else {
tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)");
}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokenize_time = tokenize_start.elapsed();
tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time);
tracing::debug!("Streaming Input tokens: {}", tokens.len());
// Collect all output for final return
let mut full_output = String::new();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
eos_token
}
};
// Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
let needs_special_handling = match &self.model {
Model::V2(_) => true,
Model::V3(_) => true,
_ => false,
};
// Track generation timing
let start_gen = std::time::Instant::now();
// Track per-token generation timing for performance analysis
let mut token_times = Vec::new();
let mut forward_times = Vec::new();
let mut repeat_penalty_times = Vec::new();
let mut sampling_times = Vec::new();
// For Model2 and Model3, we need to use a special approach for shape compatibility
if needs_special_handling {
tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
tracing::debug!("Streaming: sample_len = {}", sample_len);
// Initial generation with the full prompt
let forward_start = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.execute_with_fallback(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
for gen_index in 0..sample_len {
tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
let token_start = std::time::Instant::now();
// Apply repeat penalty using optimized cached implementation
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
repeat_penalty_times.push(repeat_time);
// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&current_logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);
tokens.push(next_token);
generated_tokens += 1;
tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
next_token, next_token, eos_token, eot_token);
if next_token == eos_token || next_token == eot_token {
tracing::debug!("Streaming: Breaking due to end token");
break;
}
if let Some(token_text) = self.tokenizer.next_token(next_token)? {
full_output.push_str(&token_text);
// Call the streaming callback with this token
token_callback(&token_text)?;
}
// For the next iteration, use single token to avoid shape mismatch
let forward_start = std::time::Instant::now();
tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());
// Use just the last token for subsequent iterations to avoid shape mismatch
// This is required for Gemma model's attention mechanism compatibility
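// Note: this relies on the model's internal KV cache retaining earlier positions;
// start_pos tells the forward pass where the new token sits in the sequence.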
let context_tokens = &tokens[(tokens.len()-1)..];
let start_pos = tokens.len() - 1;
tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})",
context_tokens.len(), start_pos);
let new_input = match Tensor::new(context_tokens, &self.device) {
Ok(tensor) => tensor,
Err(e) => {
tracing::error!("Streaming: Failed to create input tensor: {}", e);
return Err(e.into());
}
};
let new_input = match new_input.unsqueeze(0) {
Ok(tensor) => tensor,
Err(e) => {
tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e);
return Err(e.into());
}
};
tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos);
logits = match self.execute_with_fallback(&new_input, start_pos) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Forward pass failed: {}", e);
return Err(e);
}
};
logits = match logits.squeeze(0) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e);
return Err(e.into());
}
};
logits = match logits.squeeze(0) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e);
return Err(e.into());
}
};
logits = match logits.to_dtype(DType::F32) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to convert logits to F32: {}", e);
return Err(e.into());
}
};
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);
let token_time = token_start.elapsed();
token_times.push(token_time);
// Yield to allow other async tasks to run
tokio::task::yield_now().await;
}
} else {
// Standard approach for other models
tracing::debug!("Using standard generation approach (streaming)");
for index in 0..sample_len {
let token_start = std::time::Instant::now();
// Use sliding window context instead of single token to preserve context and reduce repetition
let context_size = if index > 0 {
std::cmp::min(self.context_window_size, tokens.len())
} else {
tokens.len()
};
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})",
ctxt.len(), start_pos);
// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
// Apply repeat penalty using optimized cached implementation
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
repeat_penalty_times.push(repeat_time);
// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(token_text) = self.tokenizer.next_token(next_token)? {
full_output.push_str(&token_text);
// Call the streaming callback with this token
token_callback(&token_text)?;
}
let token_time = token_start.elapsed();
token_times.push(token_time);
}
}
let dt = start_gen.elapsed();
// Phase 3: Final decoding
let decode_start = std::time::Instant::now();
// Decode any remaining tokens but don't send through callback to avoid repetition
// The tokens were already streamed individually in the generation loop above
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
full_output.push_str(&rest);
// Note: NOT calling token_callback(&rest) here to prevent token repetition
// Individual tokens were already streamed via the callback in the generation loop
}
let decode_time = decode_start.elapsed();
// Log performance metrics
Self::log_performance_metrics(
dt, generated_tokens, &token_times, &forward_times,
&repeat_penalty_times, &sampling_times, tokenize_time,
decode_time, start_time, "Streaming"
);
Ok(full_output)
}
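A hedged usage sketch of the new streaming entry point; it assumes a `TextGeneration` value built elsewhere (setup elided) and streams tokens to stdout as they arrive:

```rust
use std::io::Write;

// Illustrative only: `text_gen` construction is elided.
async fn stream_demo(text_gen: &mut TextGeneration) -> anyhow::Result<()> {
    let full_output = text_gen
        .run_with_streaming("Explain borrowing in Rust.", 256, |token| {
            print!("{token}");
            std::io::stdout().flush()?; // show each token as soon as it lands
            Ok(())
        })
        .await?;
    println!("\n--- done: {} chars total ---", full_output.len());
    Ok(())
}
```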
// Helper function for logging performance metrics
fn log_performance_metrics(