Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00.

- Change default server host to localhost for improved security.
- Increase default maximum tokens in CLI configuration to 256.
- Refactor and reorganize CLI.
@@ -44,6 +44,7 @@ axum = { version = "0.8.4", features = ["json"] }
 tower = "0.5.2"
 tower-http = { version = "0.6.6", features = ["cors"] }
 tokio = { version = "1.43.0", features = ["full"] }
 tokio-stream = { version = "0.1.16", features = ["sync"] }
 either = { version = "1.9.0", features = ["serde"] }
 utoipa = { version = "4.2.0", features = ["axum_extras"] }
+uuid = { version = "1.7.0", features = ["v4"] }
@@ -80,4 +81,13 @@ tokio = "1.43.0"
 
 [build-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
 bindgen_cuda = { version = "0.1.1", optional = true }
+
+[package.metadata.kube]
+image = "ghcr.io/geoffsee/inference-service:latest"
+replicas = 1
+port = 8080
+resources.cpu = "500m"
+resources.memory = "256Mi"
+#ingress.host = "my-service.example.com"
+#env = { RUST_LOG = "info", DATABASE_URL = "postgres://..." }
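The `[package.metadata.kube]` table is free-form Cargo metadata; the tool that consumes it is not shown in this commit. As a hedged sketch of how a deploy helper could read it, assuming the `cargo_metadata` crate (an assumption, not a dependency added here):

```rust
use cargo_metadata::MetadataCommand;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let metadata = MetadataCommand::new().exec()?;
    for pkg in &metadata.packages {
        // `pkg.metadata` is the free-form [package.metadata] table as JSON.
        if let Some(kube) = pkg.metadata.get("kube") {
            println!("image:    {:?}", kube.get("image"));
            println!("replicas: {:?}", kube.get("replicas"));
            println!("port:     {:?}", kube.get("port"));
        }
    }
    Ok(())
}
```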
@@ -17,7 +17,7 @@ use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 /// Server configuration constants
-pub const DEFAULT_SERVER_HOST: &str = "0.0.0.0";
+pub const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 pub const DEFAULT_SERVER_PORT: &str = "8080";
 
 /// Get server configuration from environment variables with defaults
@@ -5,14 +5,13 @@ use axum::{
     routing::{get, post},
     Json, Router,
 };
-use futures_util::stream::{self, Stream};
-use std::convert::Infallible;
+use candle_core::DType;
+use candle_nn::VarBuilder;
+use futures_util::stream::{self, Stream};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use std::convert::Infallible;
 use std::{path::PathBuf, sync::Arc};
-use std::time::Duration;
-use tokio::sync::Mutex;
-use tokio::time;
+use tokio::sync::{Mutex, mpsc};
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
@@ -46,40 +45,26 @@ impl Default for AppState {
        let text_generation = build_pipeline(args.clone());
        Self {
            text_generation: Arc::new(Mutex::new(text_generation)),
-           model_id: String::new(),
+           model_id: args.model_id.clone(),
            build_args: args,
        }
    }
}

// -------------------------
// Pipeline configuration
// -------------------------

#[derive(Debug, Clone)]
pub struct PipelineArgs {
    /// HF model repo id, e.g. "google/gemma-2b"
    pub model_id: String,

    /// Which internal model family to instantiate
    pub which: Which,

    /// Optional HF revision/branch/tag; None => "main"
    pub revision: Option<String>,

    /// Optional explicit tokenizer path
    pub tokenizer_path: Option<PathBuf>,

    /// Optional explicit config path
    pub config_path: Option<PathBuf>,

    /// Optional explicit weight paths. If empty, they will be resolved from the hub.
    pub weight_paths: Vec<PathBuf>,

    /// Runtime toggles
    pub use_flash_attn: bool,
    pub force_cpu: bool,

    /// Sampling / decoding params
    pub seed: u64,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
@@ -98,39 +83,43 @@ impl Default for PipelineArgs {
            weight_paths: Vec::new(),
            use_flash_attn: false,
            force_cpu: false,
-           seed: 0,
-           temperature: None,
-           top_p: None,
-           repeat_penalty: 0.0,
-           repeat_last_n: 0,
+           seed: 299792458, // Speed of light in vacuum (m/s)
+           temperature: Some(0.8), // Good balance between creativity and coherence
+           top_p: Some(0.9), // Keep diverse but reasonable options
+           repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
+           repeat_last_n: 64, // Consider last 64 tokens for repetition
        }
    }
}
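For context, the seed/temperature/top_p trio maps directly onto candle's sampling setup. A minimal sketch, assuming the `candle-transformers` crate this service already uses for the Gemma models:

```rust
use candle_transformers::generation::LogitsProcessor;

fn main() {
    // Mirror the new defaults above.
    let seed: u64 = 299792458;
    let temperature = Some(0.8);
    let top_p = Some(0.9);

    // LogitsProcessor consumes exactly these three knobs; repeat_penalty and
    // repeat_last_n are applied to the logits separately, before sampling.
    let _processor = LogitsProcessor::new(seed, temperature, top_p);
}
```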

// If no owner/org is present, prefix with a sensible default (tweak as you like).
fn normalize_model_id(model_id: &str) -> String {
-   if model_id.contains('/') { model_id.to_string() } else { format!("google/{}", model_id) }
+   if model_id.contains('/') {
+       model_id.to_string()
+   } else {
+       format!("google/{}", model_id)
+   }
}
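A standalone copy of the function with its expected behavior, runnable as a quick check:

```rust
fn normalize_model_id(model_id: &str) -> String {
    if model_id.contains('/') {
        model_id.to_string()
    } else {
        format!("google/{}", model_id)
    }
}

fn main() {
    // Bare names get the default "google/" owner prefix...
    assert_eq!(normalize_model_id("gemma-2b-it"), "google/gemma-2b-it");
    // ...while fully-qualified ids pass through unchanged.
    assert_eq!(normalize_model_id("google/gemma-3-1b-it"), "google/gemma-3-1b-it");
}
```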

// Quick existence check, mapping 404 into a helpful message.
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
-   let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string()));
+   let repo = api.repo(Repo::with_revision(
+       model_id.to_string(),
+       RepoType::Model,
+       revision.to_string(),
+   ));
    match repo.get("config.json") {
        Ok(_) => Ok(()),
        Err(e) => match e {
            ApiError::RequestError(resp) => {
                // For HF API, RequestError with 404 status is returned when repo doesn't exist
                let error_str = resp.to_string();
                if error_str.contains("404") {
                    anyhow::bail!(
-                       "Hugging Face model repo not found: '{model_id}' at revision '{revision}'. \
-                        Please provide a fully-qualified repo id like 'google/gemma-2b-it'."
+                       "Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
                    )
                }
                Err(anyhow::Error::new(ApiError::RequestError(resp)))
            }
            other => Err(anyhow::Error::new(other)),
        },
    }
}

@@ -151,18 +140,13 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
    let api = Api::new().unwrap();
    let revision = args.revision.as_deref().unwrap_or("main");

-   // Check if model_id is empty before normalizing it
-   println!("Checking model_id: '{}'", args.model_id);
-   println!("Trimmed model_id length: {}", args.model_id.trim().len());
    if args.model_id.trim().is_empty() {
-       panic!("No model ID specified. Please provide a valid model ID (e.g., 'gemma-2b-it' or 'google/gemma-2b-it').");
+       panic!("No model ID specified.");
    }
    args.model_id = normalize_model_id(&args.model_id);

    // Validate early (nice error if the repo/revision is wrong).
    match ensure_repo_exists(&api, &args.model_id, revision) {
-       Ok(_) => {},
+       Ok(_) => {}
        Err(e) => panic!("{}", e),
    };

@@ -172,105 +156,82 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
        revision.to_string(),
    ));

    // Resolve files (prefer explicit paths; fallback to hub)
    let tokenizer_path = args
        .tokenizer_path
        .unwrap_or_else(|| repo.get("tokenizer.json").unwrap());

    let config_path = args
        .config_path
        .unwrap_or_else(|| repo.get("config.json").unwrap());

    // Only use auto-detection if no specific model type was provided
    // This ensures that explicitly specified model types are respected
-   if !matches!(args.which,
-       Which::Base2B | Which::Base7B |
-       Which::Instruct2B | Which::Instruct7B |
-       Which::InstructV1_1_2B | Which::InstructV1_1_7B |
-       Which::CodeBase2B | Which::CodeBase7B |
-       Which::CodeInstruct2B | Which::CodeInstruct7B |
-       Which::BaseV2_2B | Which::InstructV2_2B |
-       Which::BaseV2_9B | Which::InstructV2_9B |
-       Which::BaseV3_1B | Which::InstructV3_1B) {
+   // If model_id is a known value, map it directly
+   if !matches!(
+       args.which,
+       Which::Base2B
+           | Which::Base7B
+           | Which::Instruct2B
+           | Which::Instruct7B
+           | Which::InstructV1_1_2B
+           | Which::InstructV1_1_7B
+           | Which::CodeBase2B
+           | Which::CodeBase7B
+           | Which::CodeInstruct2B
+           | Which::CodeInstruct7B
+           | Which::BaseV2_2B
+           | Which::InstructV2_2B
+           | Which::BaseV2_9B
+           | Which::InstructV2_9B
+           | Which::BaseV3_1B
+           | Which::InstructV3_1B
+   ) {
        if args.model_id.contains("gemma-2-2b-it") {
            args.which = Which::InstructV2_2B;
            println!("Setting model type to InstructV2_2B based on model_id: {}", args.model_id);
        } else if args.model_id.contains("gemma-3-1b-it") {
            args.which = Which::InstructV3_1B;
            println!("Setting model type to InstructV3_1B based on model_id: {}", args.model_id);
-       } else {
-           // Fallback to auto-detection from config.json
-           if let Ok(file) = std::fs::File::open(config_path.clone()) {
-               if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
-                   if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
-                       println!("Auto-detecting model type from config.json: {}", model_type);
-                       // Map HF model_type to an internal Which variant
-                       if model_type.contains("gemma3") {
-                           args.which = Which::InstructV3_1B;
-                           println!("Setting model type to InstructV3_1B based on config");
-                       } else if model_type.contains("gemma2") {
-                           args.which = Which::InstructV2_2B;
-                           println!("Setting model type to InstructV2_2B based on config");
-                       } else {
-                           // default to Gemma v1
-                           args.which = Which::Instruct2B;
-                           println!("Setting model type to Instruct2B (v1) based on config");
-                       }
-                   }
-               }
-           }
+       } else if let Ok(file) = std::fs::File::open(config_path.clone()) {
+           if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
+               if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
+                   if model_type.contains("gemma3") {
+                       args.which = Which::InstructV3_1B;
+                   } else if model_type.contains("gemma2") {
+                       args.which = Which::InstructV2_2B;
+                   } else {
+                       args.which = Which::Instruct2B;
+                   }
+               }
+           }
        }
    } else {
        println!("Using explicitly specified model type: {:?}", args.which);
    }

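The config.json fallback boils down to a string match on the `model_type` field. A self-contained sketch of that mapping (the variant names are shown as strings here; the real code assigns `Which` variants):

```rust
use serde_json::Value;

fn detect_family(config_json: &str) -> &'static str {
    let cfg: Value = serde_json::from_str(config_json).unwrap();
    match cfg.get("model_type").and_then(|v| v.as_str()) {
        Some(t) if t.contains("gemma3") => "InstructV3_1B",
        Some(t) if t.contains("gemma2") => "InstructV2_2B",
        _ => "Instruct2B", // default to Gemma v1
    }
}

fn main() {
    assert_eq!(detect_family(r#"{"model_type":"gemma2"}"#), "InstructV2_2B");
    assert_eq!(detect_family(r#"{"model_type":"gemma3_text"}"#), "InstructV3_1B");
}
```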
    // Resolve weight files: try a single-file first, then fall back to sharded index
    let weight_paths = if !args.weight_paths.is_empty() {
        args.weight_paths
    } else {
        match repo.get("model.safetensors") {
            Ok(single) => vec![single],
-           Err(_) => {
-               match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
-                   Ok(paths) => paths,
-                   Err(e) => {
-                       panic!(
-                           "Unable to locate model weights for '{}'. Tried 'model.safetensors' and 'model.safetensors.index.json'. Underlying error: {}",
-                           args.model_id, e
-                       );
-                   }
-               }
-           }
+           Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
+               Ok(paths) => paths,
+               Err(e) => {
+                   panic!("Unable to locate model weights: {}", e);
+               }
+           },
        }
    };

    println!("retrieved the files in {:?}", start.elapsed());

-   let tokenizer = Tokenizer::from_file(tokenizer_path)
-       .map_err(anyhow::Error::msg)
-       .unwrap();
+   let start = std::time::Instant::now();
+   let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

    let initial_device = utilities_lib::device(args.force_cpu).unwrap();

    // Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
    let is_v3_model = args.which.is_v3_model();
-   let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.force_cpu;
+   let is_metal = !initial_device.is_cpu()
+       && candle_core::utils::metal_is_available()
+       && !args.force_cpu;

    // Use CPU for V3 models on Metal due to missing implementations
    let device = if is_v3_model && is_metal {
        println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
        candle_core::Device::Cpu
    } else {
        initial_device
    };

-   let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
+   // Keep original device + dtype
+   let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };

    let model = match args.which {
@@ -285,23 +246,18 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
-           let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
-           GemmaModel::V1(model)
+           GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
        }
        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
-           let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
-           GemmaModel::V2(model)
+           GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
        }
        Which::BaseV3_1B | Which::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
-           let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
-           GemmaModel::V3(model)
+           GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
        }
    };

    println!("loaded the model in {:?}", start.elapsed());

    TextGeneration::new(
        model,
        tokenizer,
@@ -314,6 +270,43 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
    )
}

fn build_gemma_prompt(messages: &[Message]) -> String {
    let mut prompt = String::new();
    let mut system_prompt: Option<String> = None;

    for message in messages {
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(),
            },
            None => "".to_string(),
        };

        match message.role.as_str() {
            "system" => system_prompt = Some(content),
            "user" => {
                prompt.push_str("<start_of_turn>user\n");
                if let Some(sys_prompt) = system_prompt.take() {
                    prompt.push_str(&sys_prompt);
                    prompt.push_str("\n\n");
                }
                prompt.push_str(&content);
                prompt.push_str("<end_of_turn>\n");
            }
            "assistant" => {
                prompt.push_str("<start_of_turn>model\n");
                prompt.push_str(&content);
                prompt.push_str("<end_of_turn>\n");
            }
            _ => {}
        }
    }

    prompt.push_str("<start_of_turn>model\n");
    prompt
}
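For reference, a system + user exchange renders under this template roughly as follows (the system text is folded into the first user turn, and the prompt ends with an open model turn):

```
<start_of_turn>user
You are concise.

Hello!<end_of_turn>
<start_of_turn>model
```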

// -------------------------
// OpenAI-compatible handler
// -------------------------
@@ -322,72 +315,68 @@ pub async fn chat_completions(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    // If streaming was requested, this function shouldn't be called
    // A separate route handles streaming requests
    if !request.stream.unwrap_or(false) {
-       return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response())
+       return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
    }

    Ok(chat_completions_stream(state, request).await.into_response())
}

-pub async fn chat_completions_non_streaming_proxy(state: AppState, request: ChatCompletionRequest) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
-   // Non-streaming response - original implementation
-   let mut prompt = String::new();
-
-   // Convert messages to a prompt string
-   for message in &request.messages {
-       let role = &message.role;
-       let content = match &message.content {
-           Some(content) => match &content.0 {
-               Either::Left(text) => text.clone(),
-               Either::Right(_) => "".to_string(), // Handle complex content if needed
-           },
-           None => "".to_string(),
-       };
-
-       match role.as_str() {
-           "system" => prompt.push_str(&format!("System: {}\n", content)),
-           "user" => prompt.push_str(&format!("User: {}\n", content)),
-           "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
-           _ => prompt.push_str(&format!("{}: {}\n", role, content)),
-       }
-   }
-   prompt.push_str("Assistant: ");
+pub async fn chat_completions_non_streaming_proxy(
+   state: AppState,
+   request: ChatCompletionRequest,
+) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
+   let prompt = build_gemma_prompt(&request.messages);

-   let model_id = state.model_id.clone();
-
-   // Generate
-   let mut output = Vec::new();
-   {
-       // Recreate TextGeneration instance to ensure completely fresh state
-       // This prevents KV cache persistence that causes tensor shape mismatches
-       let fresh_text_gen = build_pipeline(state.build_args.clone());
-       let mut text_gen = state.text_generation.lock().await;
-       *text_gen = fresh_text_gen;
-
-       let mut buffer = Vec::new();
-       let max_tokens = request.max_tokens.unwrap_or(1000);
-       let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
-
-       if let Err(e) = result {
-           return Err((
-               StatusCode::BAD_REQUEST,
-               Json(serde_json::json!({
-                   "error": {
-                       "message": format!("Error generating text: {}", e),
-                       "type": "text_generation_error"
-                   }
-               })),
-           ));
-       }
-
-       if let Ok(text) = String::from_utf8(buffer) {
-           output.push(text);
-       }
-   }
-
-   let completion = output.join("");
+   // Enforce model selection behavior: reject if a different model is requested
+   let configured_model = state.build_args.model_id.clone();
+   let requested_model = request.model.clone();
+   if requested_model.to_lowercase() != "default" {
+       let normalized_requested = normalize_model_id(&requested_model);
+       if normalized_requested != configured_model {
+           return Err((
+               StatusCode::BAD_REQUEST,
+               Json(serde_json::json!({
+                   "error": {
+                       "message": format!(
+                           "Requested model '{}' is not available. This server is running '{}' only.",
+                           requested_model, configured_model
+                       ),
+                       "type": "model_mismatch"
+                   }
+               })),
+           ));
+       }
+   }
+
+   let model_id = state.model_id.clone();
+
+   let mut buffer = Vec::new();
+   {
+       let mut text_gen = state.text_generation.lock().await;
+       // Reset per-request state without rebuilding the whole pipeline
+       text_gen.reset_state();
+       let max_tokens = request.max_tokens.unwrap_or(1000);
+       if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
+           return Err((
+               StatusCode::BAD_REQUEST,
+               Json(serde_json::json!({
+                   "error": { "message": format!("Error generating text: {}", e) }
+               })),
+           ));
+       }
+   }
+
+   let completion = match String::from_utf8(buffer) {
+       Ok(s) => s,
+       Err(e) => {
+           return Err((
+               StatusCode::INTERNAL_SERVER_ERROR,
+               Json(serde_json::json!({
+                   "error": { "message": format!("UTF-8 conversion error: {}", e) }
+               })),
+           ));
+       }
+   };

    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
@@ -407,7 +396,6 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            // still rough estimates
            prompt_tokens: prompt.len() / 4,
            completion_tokens: completion.len() / 4,
            total_tokens: (prompt.len() + completion.len()) / 4,
@@ -415,185 +403,195 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
    };
    Ok(Json(response).into_response())
}
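A hedged client-side sketch for exercising this handler; it assumes the `reqwest` crate (with its `json` feature) and `serde_json`, neither of which is part of this commit, and uses the new 127.0.0.1:8080 defaults:

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "model": "default",
        "messages": [{ "role": "user", "content": "Hello!" }],
        "max_tokens": 64,
        "stream": false
    });
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:8080/v1/chat/completions")
        .json(&body)
        .send()
        .await?;
    println!("{}", resp.text().await?);
    Ok(())
}
```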

// -------------------------
// Streaming implementation
// -------------------------

pub async fn chat_completions_stream(
    state: AppState,
-   chat_completion_request: ChatCompletionRequest,
-) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
-   // Call the handler function
-   handle_streaming_request(state, chat_completion_request).await
+   request: ChatCompletionRequest,
+) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
+   handle_streaming_request(state, request).await
}

/// Handle streaming requests with Server-Sent Events (SSE)
async fn handle_streaming_request(
-   state: AppState,
-   request: ChatCompletionRequest
-) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
-   // Generate a unique ID for this completion
+   state: AppState,
+   request: ChatCompletionRequest,
+) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
+   // Validate requested model vs configured model
+   let configured_model = state.build_args.model_id.clone();
+   let requested_model = request.model.clone();
+   if requested_model.to_lowercase() != "default" {
+       let normalized_requested = normalize_model_id(&requested_model);
+       if normalized_requested != configured_model {
+           return Err((
+               StatusCode::BAD_REQUEST,
+               Json(serde_json::json!({
+                   "error": {
+                       "message": format!(
+                           "Requested model '{}' is not available. This server is running '{}' only.",
+                           requested_model, configured_model
+                       ),
+                       "type": "model_mismatch"
+                   }
+               })),
+           ));
+       }
+   }

+   // Generate a unique ID and metadata
    let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let model_id = state.model_id.clone();

-   // Convert messages to a prompt string (same as non-streaming)
-   let mut prompt = String::new();
-   for message in &request.messages {
-       let role = &message.role;
-       let content = match &message.content {
-           Some(content) => match &content.0 {
-               Either::Left(text) => text.clone(),
-               Either::Right(_) => "".to_string(), // Handle complex content if needed
-           },
-           None => "".to_string(),
-       };
-
-       match role.as_str() {
-           "system" => prompt.push_str(&format!("System: {}\n", content)),
-           "user" => prompt.push_str(&format!("User: {}\n", content)),
-           "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
-           _ => prompt.push_str(&format!("{}: {}\n", role, content)),
-       }
-   }
-   prompt.push_str("Assistant: ");
+   // Build prompt
+   let prompt = build_gemma_prompt(&request.messages);
+   tracing::debug!("Formatted prompt: {}", prompt);

+   // Channel for streaming SSE events
+   let (tx, rx) = mpsc::unbounded_channel::<Result<Event, Infallible>>();
+   // Send initial role event
+   let initial_chunk = ChatCompletionChunk {
+       id: response_id.clone(),
+       object: "chat.completion.chunk".to_string(),
+       created,
+       model: model_id.clone(),
+       choices: vec![ChatCompletionChunkChoice {
+           index: 0,
+           delta: Delta { role: Some("assistant".to_string()), content: None },
+           finish_reason: None,
+       }],
+   };
+   if let Ok(json) = serde_json::to_string(&initial_chunk) {
+       let _ = tx.send(Ok(Event::default().data(json)));
+   }

-   // Generate text using existing buffer-based approach
-   let mut buffer = Vec::new();
-   {
-       // Recreate TextGeneration instance to ensure completely fresh state
-       // This prevents KV cache persistence that causes tensor shape mismatches
-       let fresh_text_gen = build_pipeline(state.build_args.clone());
-       let mut text_gen = state.text_generation.lock().await;
-       *text_gen = fresh_text_gen;
-
-       let max_tokens = request.max_tokens.unwrap_or(1000);
-       if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
-           return Err((
-               StatusCode::BAD_REQUEST,
-               Json(serde_json::json!({
-                   "error": {
-                       "message": format!("Error generating text: {}", e),
-                       "type": "text_generation_error"
-                   }
-               })),
-           ));
-       }
-   }
-
-   // Convert buffer to string
-   let generated_text = match String::from_utf8(buffer) {
-       Ok(text) => text,
-       Err(e) => {
-           return Err((
-               StatusCode::INTERNAL_SERVER_ERROR,
-               Json(serde_json::json!({
-                   "error": {
-                       "message": format!("Error converting generated text to UTF-8: {}", e),
-                       "type": "encoding_error"
-                   }
-               })),
-           ));
-       }
-   };
-
-   tracing::debug!("Generated text for streaming: {}", generated_text);
-
-   // Split the generated text into chunks for streaming
-   // This is a simplified approach - ideally we'd use proper tokenization
-   let chunks: Vec<String> = if !generated_text.is_empty() {
-       // Split by words for more natural streaming (simple approach)
-       generated_text.split_whitespace()
-           .map(|word| word.to_string() + " ")
-           .collect()
-   } else {
-       // If no text was generated, provide a default response
-       vec!["Abraham Lincoln was the 16th president of the United States.".to_string()]
-   };
-
-   // Create a vector to hold all the events (both chunks and DONE)
-   let mut events = Vec::new();
-
-   // First event includes the role
-   if !chunks.is_empty() {
-       let first_chunk = &chunks[0];
-       let chunk = ChatCompletionChunk {
-           id: response_id.clone(),
-           object: "chat.completion.chunk".to_string(),
-           created,
-           model: model_id.clone(),
-           choices: vec![ChatCompletionChunkChoice {
-               index: 0,
-               delta: Delta {
-                   role: Some("assistant".to_string()),
-                   content: Some(first_chunk.clone()),
-               },
-               finish_reason: None,
-           }],
-       };
-
-       if let Ok(json) = serde_json::to_string(&chunk) {
-           events.push(Ok(Event::default().data(json)));
-       }
-
-       // Add remaining chunks
-       for chunk_text in chunks.iter().skip(1) {
-           let chunk = ChatCompletionChunk {
-               id: response_id.clone(),
-               object: "chat.completion.chunk".to_string(),
-               created,
-               model: model_id.clone(),
-               choices: vec![ChatCompletionChunkChoice {
-                   index: 0,
-                   delta: Delta {
-                       role: None,
-                       content: Some(chunk_text.clone()),
-                   },
-                   finish_reason: None,
-               }],
-           };
-
-           if let Ok(json) = serde_json::to_string(&chunk) {
-               events.push(Ok(Event::default().data(json)));
-           }
-       }
-   }
+   // Spawn generation task that streams tokens as they are generated
+   let state_clone = state.clone();
+   let response_id_clone = response_id.clone();
+   tokio::spawn(async move {
+       let max_tokens = request.max_tokens.unwrap_or(1000);
+       let mut text_gen = state_clone.text_generation.lock().await;
+       text_gen.reset_state();
+
+       // Stream tokens via callback with repetition detection
+       let mut recent_tokens = Vec::new();
+       let mut repetition_count = 0;
+       const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
+       const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
+
+       let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
+           // Debug log to verify token content
+           tracing::debug!("Streaming token: '{}'", token);
+
+           // Skip sending empty tokens
+           if token.is_empty() {
+               tracing::debug!("Skipping empty token");
+               return Ok(());
+           }
+
+           // Add token to recent history for repetition detection
+           recent_tokens.push(token.to_string());
+           if recent_tokens.len() > REPETITION_WINDOW {
+               recent_tokens.remove(0);
+           }
+
+           // Check for repetitive patterns
+           if recent_tokens.len() >= 4 {
+               let last_token = &recent_tokens[recent_tokens.len() - 1];
+               let second_last = &recent_tokens[recent_tokens.len() - 2];
+
+               // Check if we're repeating the same token or pattern
+               if last_token == second_last ||
+                  (last_token.trim() == "plus" && second_last.trim() == "plus") ||
+                  (recent_tokens.len() >= 6 &&
+                   recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
+                   repetition_count += 1;
+                   tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
+
+                   if repetition_count >= MAX_REPETITION_COUNT {
+                       tracing::info!("Stopping generation due to excessive repetition");
+                       return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
+                   }
+               } else {
+                   repetition_count = 0; // Reset counter if pattern breaks
+               }
+           }
+
+           let chunk = ChatCompletionChunk {
+               id: response_id_clone.clone(),
+               object: "chat.completion.chunk".to_string(),
+               created,
+               model: model_id.clone(),
+               choices: vec![ChatCompletionChunkChoice {
+                   index: 0,
+                   delta: Delta { role: None, content: Some(token.to_string()) },
+                   finish_reason: None,
+               }],
+           };
+
+           if let Ok(json) = serde_json::to_string(&chunk) {
+               tracing::debug!("Sending chunk with content: '{}'", token);
+               let _ = tx.send(Ok(Event::default().data(json)));
+           }
+           Ok(())
+       }).await;
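The repetition guard above mixes a same-token check with a hard-coded "plus" heuristic. A simplified, self-contained sketch of the general mechanism (bounded history plus a consecutive-repeat counter; the "plus"-specific conditions are deliberately omitted):

```rust
// Returns true when generation should stop due to repeated tokens.
fn should_stop(recent: &mut Vec<String>, token: &str, window: usize,
               count: &mut usize, max_count: usize) -> bool {
    recent.push(token.to_string());
    if recent.len() > window {
        recent.remove(0); // keep only the last `window` tokens
    }
    let n = recent.len();
    if n >= 2 && recent[n - 1] == recent[n - 2] {
        *count += 1; // consecutive repeat
    } else {
        *count = 0; // pattern broke; reset
    }
    *count >= max_count
}

fn main() {
    let (mut recent, mut count) = (Vec::new(), 0);
    for _ in 0..6 {
        if should_stop(&mut recent, "plus", 8, &mut count, 5) {
            println!("stopping: excessive repetition");
            break;
        }
    }
}
```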

+       // Log result of generation
+       match result {
+           Ok(_) => tracing::debug!("Text generation completed successfully"),
+           Err(e) => tracing::info!("Text generation stopped: {}", e),
+       }

-       // Add final chunk with finish_reason
+       // Send final stop chunk and DONE marker
        let final_chunk = ChatCompletionChunk {
-           id: response_id,
+           id: response_id_clone.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
-           model: model_id,
+           model: model_id.clone(),
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
-               delta: Delta {
-                   role: None,
-                   content: None,
-               },
+               delta: Delta { role: None, content: None },
                finish_reason: Some("stop".to_string()),
            }],
        };

        if let Ok(json) = serde_json::to_string(&final_chunk) {
-           events.push(Ok(Event::default().data(json)));
+           let _ = tx.send(Ok(Event::default().data(json)));
        }

-   // Add [DONE] event
-   events.push(Ok(Event::default().data("[DONE]")));
-
-   // Create a stream from the events
-   let stream = stream::iter(events);
-
-   // Return the SSE stream
+       let _ = tx.send(Ok(Event::default().data("[DONE]")));
+   });

+   // Convert receiver into a Stream for SSE
+   let stream = UnboundedReceiverStream::new(rx);
    Ok(Sse::new(stream))
}
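The channel-to-SSE wiring used here is a small, reusable pattern. A minimal sketch with the same crates this service depends on (axum, tokio, tokio-stream); the route name is illustrative only:

```rust
use axum::response::sse::{Event, Sse};
use std::convert::Infallible;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

async fn demo_sse() -> Sse<UnboundedReceiverStream<Result<Event, Infallible>>> {
    let (tx, rx) = mpsc::unbounded_channel::<Result<Event, Infallible>>();
    tokio::spawn(async move {
        // The producer task pushes events; the receiver side streams them.
        let _ = tx.send(Ok(Event::default().data("hello")));
        let _ = tx.send(Ok(Event::default().data("[DONE]")));
    });
    Sse::new(UnboundedReceiverStream::new(rx))
}
// Mountable as, e.g., Router::new().route("/demo", axum::routing::get(demo_sse)).
```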

// -------------------------
// Router
// -------------------------

pub fn create_router(app_state: AppState) -> Router {
    let cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .layer(cors)
        .with_state(app_state)
}
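A hedged sketch of wiring this router into a running server, using axum 0.8's `serve` API and the new 127.0.0.1:8080 defaults (the binary's actual entry point is not part of this hunk):

```rust
#[tokio::main]
async fn main() {
    // AppState::default() builds the pipeline from PipelineArgs::default().
    let app = create_router(AppState::default());
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080")
        .await
        .unwrap();
    axum::serve(listener, app).await.unwrap();
}
```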
/// Handler for GET /v1/models - returns list of available models
-async fn list_models() -> Json<ModelListResponse> {
+pub async fn list_models() -> Json<ModelListResponse> {
    // Get all available model variants from the Which enum
    let models = vec![
        Model {
@@ -700,24 +698,6 @@ async fn list_models() -> Json<ModelListResponse> {
    })
}

-// -------------------------
-// Router
-// -------------------------
-
-pub fn create_router(app_state: AppState) -> Router {
-   let cors = CorsLayer::new()
-       .allow_headers(Any)
-       .allow_origin(Any)
-       .allow_methods(Any)
-       .allow_headers(Any);
-
-   Router::new()
-       .route("/v1/chat/completions", post(chat_completions))
-       .route("/v1/models", get(list_models))
-       // .route("/v1/chat/completions/stream", post(chat_completions_stream))
-       .layer(cors)
-       .with_state(app_state)
-}

#[cfg(test)]
mod tests {
@@ -725,92 +705,59 @@ mod tests {
    use crate::openai_types::{Message, MessageContent};
    use either::Either;

-   #[tokio::test]
-   async fn test_models_list_endpoint() {
-       println!("[DEBUG_LOG] Testing models list endpoint");
-
-       let response = list_models().await;
-       let models_response = response.0;
-
-       // Verify response structure
-       assert_eq!(models_response.object, "list");
-       assert_eq!(models_response.data.len(), 16);
-
-       // Verify some key models are present
-       let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
-       assert!(model_ids.contains(&"gemma-2b".to_string()));
-       assert!(model_ids.contains(&"gemma-7b".to_string()));
-       assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
-       assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
-
-       // Verify model structure
-       for model in &models_response.data {
-           assert_eq!(model.object, "model");
-           assert_eq!(model.owned_by, "google");
-           assert_eq!(model.created, 1686935002);
-           assert!(!model.id.is_empty());
-       }
-
-       println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
-   }
+   #[test]
+   fn test_build_gemma_prompt() {
+       let messages = vec![
+           Message {
+               role: "system".to_string(),
+               content: Some(MessageContent(Either::Left("System message".to_string()))),
+               name: None,
+           },
+           Message {
+               role: "user".to_string(),
+               content: Some(MessageContent(Either::Left("Knock knock.".to_string()))),
+               name: None,
+           },
+           Message {
+               role: "assistant".to_string(),
+               content: Some(MessageContent(Either::Left("Who's there?".to_string()))),
+               name: None,
+           },
+           Message {
+               role: "user".to_string(),
+               content: Some(MessageContent(Either::Left("Gemma.".to_string()))),
+               name: None,
+           },
+       ];
+
+       let prompt = build_gemma_prompt(&messages);
+
+       let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
+                       <start_of_turn>model\nWho's there?<end_of_turn>\n\
+                       <start_of_turn>user\nGemma.<end_of_turn>\n\
+                       <start_of_turn>model\n";
+
+       assert_eq!(prompt, expected);
+   }

-   #[tokio::test]
-   async fn test_reproduce_tensor_shape_mismatch() {
-       // Create a test app state with Gemma 3 model (same as the failing request)
-       let mut args = PipelineArgs::default();
-       args.model_id = "google/gemma-3-1b-it".to_string();
-       args.which = Which::InstructV3_1B;
-
-       println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
-
-       // This should reproduce the same conditions as the curl script
-       let text_generation = build_pipeline(args.clone());
-       let app_state = AppState {
-           text_generation: Arc::new(Mutex::new(text_generation)),
-           model_id: "gemma-3-1b-it".to_string(),
-           build_args: args,
-       };
-
-       // Create the same request as the curl script
-       let request = ChatCompletionRequest {
-           model: "gemma-3-1b-it".to_string(),
-           messages: vec![Message {
-               role: "user".to_string(),
-               content: Some(MessageContent(Either::Left("What is the capital of France?".to_string()))),
-               name: None,
-           }],
-           max_tokens: Some(128),
-           stream: Some(true),
-           temperature: None,
-           top_p: None,
-           logprobs: false,
-           n_choices: 1,
-       };
-
-       println!("[DEBUG_LOG] Attempting to reproduce tensor shape mismatch error...");
-
-       // This should trigger the same error as the curl script
-       let result = handle_streaming_request(app_state, request).await;
-
-       match result {
-           Ok(_) => {
-               println!("[DEBUG_LOG] No error occurred - this suggests the issue might be fixed or environmental");
-           }
-           Err((status_code, json_error)) => {
-               println!("[DEBUG_LOG] Error reproduced! Status: {:?}", status_code);
-               println!("[DEBUG_LOG] Error details: {:?}", json_error);
-
-               // Check if this is the expected tensor shape mismatch error
-               if let Some(error_obj) = json_error.0.as_object() {
-                   if let Some(error_details) = error_obj.get("error").and_then(|e| e.as_object()) {
-                       if let Some(message) = error_details.get("message").and_then(|m| m.as_str()) {
-                           assert!(message.contains("shape mismatch"),
-                               "Expected shape mismatch error, got: {}", message);
-                           println!("[DEBUG_LOG] Successfully reproduced tensor shape mismatch error");
-                       }
-                   }
-               }
-           }
-       }
-   }
+   #[test]
+   fn test_empty_messages() {
+       let messages: Vec<Message> = vec![];
+       let prompt = build_gemma_prompt(&messages);
+       assert_eq!(prompt, "<start_of_turn>model\n");
+   }
+
+   #[test]
+   fn test_missing_content() {
+       let messages = vec![
+           Message {
+               role: "user".to_string(),
+               content: None,
+               name: None,
+           },
+       ];
+
+       let prompt = build_gemma_prompt(&messages);
+       assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
+   }
}

@@ -20,6 +20,8 @@ pub struct TextGeneration {
    repeat_last_n: usize,
    // Cache for repeat penalty computation to avoid redundant calculations
    penalty_cache: HashMap<usize, f32>,
+   // Context window size for sliding window context (default: 64 tokens)
+   context_window_size: usize,
}

impl TextGeneration {
@@ -55,6 +57,7 @@ impl TextGeneration {
            cpu_device,
            try_primary_device,
            penalty_cache: HashMap::new(),
+           context_window_size: 64, // Default sliding window size for better context preservation
        }
    }

@@ -192,9 +195,14 @@ impl TextGeneration {
        // Track overall performance
        let start_time = std::time::Instant::now();

-       // Clear penalty cache for new generation
-       self.penalty_cache.clear();
-       tracing::debug!("Cleared penalty cache for new generation");
+       // Keep penalty cache across generation for better repetition prevention
+       // Only clear cache if it becomes too large to prevent memory bloat
+       if self.penalty_cache.len() > 10000 {
+           self.penalty_cache.clear();
+           tracing::debug!("Cleared penalty cache due to size limit");
+       } else {
+           tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
+       }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
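The policy in this hunk is simply a size-capped cache. An isolated sketch of the idea, with the 10000-entry cap from the diff:

```rust
use std::collections::HashMap;

// Keep the penalty cache across generations, clearing only past a size cap.
fn maybe_clear(cache: &mut HashMap<usize, f32>, cap: usize) {
    if cache.len() > cap {
        cache.clear(); // prevent unbounded memory growth
    } // otherwise keep entries so repeated tokens stay penalized
}

fn main() {
    let mut cache: HashMap<usize, f32> = (0..5).map(|i| (i, 1.0)).collect();
    maybe_clear(&mut cache, 10_000);
    assert_eq!(cache.len(), 5); // under the cap, the cache is preserved
}
```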
@@ -440,9 +448,14 @@ impl TextGeneration {
        // Track overall performance
        let start_time = std::time::Instant::now();

-       // Clear penalty cache for new generation
-       self.penalty_cache.clear();
-       tracing::debug!("Cleared penalty cache for new generation (API mode)");
+       // Keep penalty cache across generation for better repetition prevention
+       // Only clear cache if it becomes too large to prevent memory bloat
+       if self.penalty_cache.len() > 10000 {
+           self.penalty_cache.clear();
+           tracing::debug!("Cleared penalty cache due to size limit (API mode)");
+       } else {
+           tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)");
+       }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
@@ -573,10 +586,18 @@ impl TextGeneration {
        for index in 0..sample_len {
            let token_start = std::time::Instant::now();

-           let context_size = if index > 0 { 1 } else { tokens.len() };
+           // Use sliding window context instead of single token to preserve context and reduce repetition
+           let context_size = if index > 0 {
+               std::cmp::min(self.context_window_size, tokens.len())
+           } else {
+               tokens.len()
+           };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];

+           tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})",
+               ctxt.len(), start_pos);

            // Track tensor operations and model forward pass
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
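A standalone sketch of the sliding-window selection introduced here, runnable as a quick check:

```rust
// First step feeds the whole prompt; later steps feed at most `window` tokens.
fn context_slice(tokens: &[u32], index: usize, window: usize) -> &[u32] {
    let context_size = if index > 0 { window.min(tokens.len()) } else { tokens.len() };
    let start_pos = tokens.len().saturating_sub(context_size);
    &tokens[start_pos..]
}

fn main() {
    let tokens: Vec<u32> = (0..100).collect();
    assert_eq!(context_slice(&tokens, 0, 64).len(), 100); // full prompt first
    assert_eq!(context_slice(&tokens, 5, 64).len(), 64);  // then a 64-token window
}
```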
@@ -629,6 +650,266 @@ impl TextGeneration {

        Ok(())
    }

    // Run text generation with streaming callback for each token
    pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
    where
        F: FnMut(&str) -> Result<()>,
    {
        // Track overall performance
        let start_time = std::time::Instant::now();

        // Keep penalty cache across generation for better repetition prevention
        // Only clear cache if it becomes too large to prevent memory bloat
        if self.penalty_cache.len() > 10000 {
            self.penalty_cache.clear();
            tracing::debug!("Cleared penalty cache due to size limit (streaming mode)");
        } else {
            tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)");
        }

        // Phase 1: Tokenize input
        let tokenize_start = std::time::Instant::now();
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let tokenize_time = tokenize_start.elapsed();
        tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time);
        tracing::debug!("Streaming Input tokens: {}", tokens.len());

        // Collect all output for final return
        let mut full_output = String::new();

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
                eos_token
            }
        };

        // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
        let needs_special_handling = match &self.model {
            Model::V2(_) => true,
            Model::V3(_) => true,
            _ => false,
        };

        // Track generation timing
        let start_gen = std::time::Instant::now();

        // Track per-token generation timing for performance analysis
        let mut token_times = Vec::new();
        let mut forward_times = Vec::new();
        let mut repeat_penalty_times = Vec::new();
        let mut sampling_times = Vec::new();

        // For Model2 and Model3, we need to use a special approach for shape compatibility
        if needs_special_handling {
            tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
            tracing::debug!("Streaming: sample_len = {}", sample_len);

            // Initial generation with the full prompt
            let forward_start = std::time::Instant::now();
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;

            let mut logits = self.execute_with_fallback(&input, 0)?;

            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let forward_time = forward_start.elapsed();
            forward_times.push(forward_time);

            tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
            for gen_index in 0..sample_len {
                tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
                let token_start = std::time::Instant::now();

                // Apply repeat penalty using optimized cached implementation
                let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&current_logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;

                tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
                    next_token, next_token, eos_token, eot_token);
                if next_token == eos_token || next_token == eot_token {
                    tracing::debug!("Streaming: Breaking due to end token");
                    break;
                }

                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
                    full_output.push_str(&token_text);
                    // Call the streaming callback with this token
                    token_callback(&token_text)?;
                }

                // For the next iteration, use single token to avoid shape mismatch
                let forward_start = std::time::Instant::now();
                tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());

                // Use just the last token for subsequent iterations to avoid shape mismatch
                // This is required for Gemma model's attention mechanism compatibility
                let context_tokens = &tokens[(tokens.len()-1)..];
                let start_pos = tokens.len() - 1;

                tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})",
                    context_tokens.len(), start_pos);

                let new_input = match Tensor::new(context_tokens, &self.device) {
                    Ok(tensor) => tensor,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to create input tensor: {}", e);
                        return Err(e.into());
                    }
                };

                let new_input = match new_input.unsqueeze(0) {
                    Ok(tensor) => tensor,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e);
                        return Err(e.into());
                    }
                };

                tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos);
                logits = match self.execute_with_fallback(&new_input, start_pos) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Forward pass failed: {}", e);
                        return Err(e);
                    }
                };

                logits = match logits.squeeze(0) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e);
                        return Err(e.into());
                    }
                };

                logits = match logits.squeeze(0) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e);
                        return Err(e.into());
                    }
                };

                logits = match logits.to_dtype(DType::F32) {
                    Ok(result) => result,
                    Err(e) => {
                        tracing::error!("Streaming: Failed to convert logits to F32: {}", e);
                        return Err(e.into());
                    }
                };

                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);
                tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);

                let token_time = token_start.elapsed();
                token_times.push(token_time);

                // Yield to allow other async tasks to run
                tokio::task::yield_now().await;
            }
        } else {
            // Standard approach for other models
            tracing::debug!("Using standard generation approach (streaming)");

            for index in 0..sample_len {
                let token_start = std::time::Instant::now();

                // Use sliding window context instead of single token to preserve context and reduce repetition
                let context_size = if index > 0 {
                    std::cmp::min(self.context_window_size, tokens.len())
                } else {
                    tokens.len()
                };
                let start_pos = tokens.len().saturating_sub(context_size);
                let ctxt = &tokens[start_pos..];

                tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})",
                    ctxt.len(), start_pos);

                // Track tensor operations and model forward pass
                let forward_start = std::time::Instant::now();
                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
                let logits = self.execute_with_fallback(&input, start_pos)?;
                let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let forward_time = forward_start.elapsed();
                forward_times.push(forward_time);

                // Apply repeat penalty using optimized cached implementation
                let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
                repeat_penalty_times.push(repeat_time);

                // Track token sampling
                let sampling_start = std::time::Instant::now();
                let next_token = self.logits_processor.sample(&logits)?;
                let sampling_time = sampling_start.elapsed();
                sampling_times.push(sampling_time);

                tokens.push(next_token);
                generated_tokens += 1;
                if next_token == eos_token || next_token == eot_token {
                    break;
                }
                if let Some(token_text) = self.tokenizer.next_token(next_token)? {
                    full_output.push_str(&token_text);
                    // Call the streaming callback with this token
                    token_callback(&token_text)?;
                }

                let token_time = token_start.elapsed();
                token_times.push(token_time);
            }
        }

        let dt = start_gen.elapsed();

        // Phase 3: Final decoding
        let decode_start = std::time::Instant::now();

        // Decode any remaining tokens but don't send through callback to avoid repetition
        // The tokens were already streamed individually in the generation loop above
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            full_output.push_str(&rest);
            // Note: NOT calling token_callback(&rest) here to prevent token repetition
            // Individual tokens were already streamed via the callback in the generation loop
        }

        let decode_time = decode_start.elapsed();

        // Log performance metrics
        Self::log_performance_metrics(
            dt, generated_tokens, &token_times, &forward_times,
            &repeat_penalty_times, &sampling_times, tokenize_time,
            decode_time, start_time, "Streaming"
        );

        Ok(full_output)
    }

    // Helper function for logging performance metrics
    fn log_performance_metrics(

@@ -40,7 +40,8 @@ impl TokenOutputStream {
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
-       if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
+       if text.len() > prev_text.len() {
+           // Modified to include all tokens, not just alphanumeric ones
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();