mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
- Change default server host to localhost for improved security.
- Increase default maximum tokens in CLI configuration to 256. - Refactor and reorganize CLI
This commit is contained in:
@@ -5,14 +5,13 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::stream::{self, Stream};
|
||||
use std::convert::Infallible;
|
||||
use candle_core::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use futures_util::stream::{self, Stream};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use std::convert::Infallible;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -46,40 +45,26 @@ impl Default for AppState {
|
||||
let text_generation = build_pipeline(args.clone());
|
||||
Self {
|
||||
text_generation: Arc::new(Mutex::new(text_generation)),
|
||||
model_id: String::new(),
|
||||
model_id: args.model_id.clone(),
|
||||
build_args: args,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Pipeline configuration
|
||||
// -------------------------
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PipelineArgs {
|
||||
/// HF model repo id, e.g. "google/gemma-2b"
|
||||
pub model_id: String,
|
||||
|
||||
/// Which internal model family to instantiate
|
||||
pub which: Which,
|
||||
|
||||
/// Optional HF revision/branch/tag; None => "main"
|
||||
pub revision: Option<String>,
|
||||
|
||||
/// Optional explicit tokenizer path
|
||||
pub tokenizer_path: Option<PathBuf>,
|
||||
|
||||
/// Optional explicit config path
|
||||
pub config_path: Option<PathBuf>,
|
||||
|
||||
/// Optional explicit weight paths. If empty, they will be resolved from the hub.
|
||||
pub weight_paths: Vec<PathBuf>,
|
||||
|
||||
/// Runtime toggles
|
||||
pub use_flash_attn: bool,
|
||||
pub force_cpu: bool,
|
||||
|
||||
/// Sampling / decoding params
|
||||
pub seed: u64,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
@@ -98,39 +83,43 @@ impl Default for PipelineArgs {
|
||||
weight_paths: Vec::new(),
|
||||
use_flash_attn: false,
|
||||
force_cpu: false,
|
||||
seed: 0,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
repeat_penalty: 0.0,
|
||||
repeat_last_n: 0,
|
||||
seed: 299792458, // Speed of light in vacuum (m/s)
|
||||
temperature: Some(0.8), // Good balance between creativity and coherence
|
||||
top_p: Some(0.9), // Keep diverse but reasonable options
|
||||
repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
|
||||
repeat_last_n: 64, // Consider last 64 tokens for repetition
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no owner/org is present, prefix with a sensible default (tweak as you like).
|
||||
fn normalize_model_id(model_id: &str) -> String {
|
||||
if model_id.contains('/') { model_id.to_string() } else { format!("google/{}", model_id) }
|
||||
if model_id.contains('/') {
|
||||
model_id.to_string()
|
||||
} else {
|
||||
format!("google/{}", model_id)
|
||||
}
|
||||
}
|
||||
|
||||
// Quick existence check, mapping 404 into a helpful message.
|
||||
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
|
||||
let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string()));
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id.to_string(),
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
));
|
||||
match repo.get("config.json") {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => match e {
|
||||
ApiError::RequestError(resp) => {
|
||||
// For HF API, RequestError with 404 status is returned when repo doesn't exist
|
||||
let error_str = resp.to_string();
|
||||
if error_str.contains("404") {
|
||||
anyhow::bail!(
|
||||
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'. \
|
||||
Please provide a fully-qualified repo id like 'google/gemma-2b-it'."
|
||||
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
|
||||
)
|
||||
}
|
||||
Err(anyhow::Error::new(ApiError::RequestError(resp)))
|
||||
}
|
||||
other => Err(anyhow::Error::new(other)),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,18 +140,13 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
let api = Api::new().unwrap();
|
||||
let revision = args.revision.as_deref().unwrap_or("main");
|
||||
|
||||
// Check if model_id is empty before normalizing it
|
||||
println!("Checking model_id: '{}'", args.model_id);
|
||||
|
||||
println!("Trimmed model_id length: {}", args.model_id.trim().len());
|
||||
if args.model_id.trim().is_empty() {
|
||||
panic!("No model ID specified. Please provide a valid model ID (e.g., 'gemma-2b-it' or 'google/gemma-2b-it').");
|
||||
panic!("No model ID specified.");
|
||||
}
|
||||
args.model_id = normalize_model_id(&args.model_id);
|
||||
|
||||
// Validate early (nice error if the repo/revision is wrong).
|
||||
match ensure_repo_exists(&api, &args.model_id, revision) {
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => panic!("{}", e),
|
||||
};
|
||||
|
||||
@@ -172,105 +156,82 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
revision.to_string(),
|
||||
));
|
||||
|
||||
// Resolve files (prefer explicit paths; fallback to hub)
|
||||
let tokenizer_path = args
|
||||
.tokenizer_path
|
||||
.unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
|
||||
|
||||
let config_path = args
|
||||
.config_path
|
||||
.unwrap_or_else(|| repo.get("config.json").unwrap());
|
||||
|
||||
// Only use auto-detection if no specific model type was provided
|
||||
// This ensures that explicitly specified model types are respected
|
||||
if !matches!(args.which,
|
||||
Which::Base2B | Which::Base7B |
|
||||
Which::Instruct2B | Which::Instruct7B |
|
||||
Which::InstructV1_1_2B | Which::InstructV1_1_7B |
|
||||
Which::CodeBase2B | Which::CodeBase7B |
|
||||
Which::CodeInstruct2B | Which::CodeInstruct7B |
|
||||
Which::BaseV2_2B | Which::InstructV2_2B |
|
||||
Which::BaseV2_9B | Which::InstructV2_9B |
|
||||
Which::BaseV3_1B | Which::InstructV3_1B) {
|
||||
|
||||
// If model_id is a known value, map it directly
|
||||
if !matches!(
|
||||
args.which,
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B
|
||||
| Which::InstructV3_1B
|
||||
) {
|
||||
if args.model_id.contains("gemma-2-2b-it") {
|
||||
args.which = Which::InstructV2_2B;
|
||||
println!("Setting model type to InstructV2_2B based on model_id: {}", args.model_id);
|
||||
} else if args.model_id.contains("gemma-3-1b-it") {
|
||||
args.which = Which::InstructV3_1B;
|
||||
println!("Setting model type to InstructV3_1B based on model_id: {}", args.model_id);
|
||||
} else {
|
||||
// Fallback to auto-detection from config.json
|
||||
if let Ok(file) = std::fs::File::open(config_path.clone()) {
|
||||
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
|
||||
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
|
||||
println!("Auto-detecting model type from config.json: {}", model_type);
|
||||
// Map HF model_type to an internal Which variant
|
||||
if model_type.contains("gemma3") {
|
||||
args.which = Which::InstructV3_1B;
|
||||
println!("Setting model type to InstructV3_1B based on config");
|
||||
} else if model_type.contains("gemma2") {
|
||||
args.which = Which::InstructV2_2B;
|
||||
println!("Setting model type to InstructV2_2B based on config");
|
||||
} else {
|
||||
// default to Gemma v1
|
||||
args.which = Which::Instruct2B;
|
||||
println!("Setting model type to Instruct2B (v1) based on config");
|
||||
}
|
||||
} else if let Ok(file) = std::fs::File::open(config_path.clone()) {
|
||||
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
|
||||
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
|
||||
if model_type.contains("gemma3") {
|
||||
args.which = Which::InstructV3_1B;
|
||||
} else if model_type.contains("gemma2") {
|
||||
args.which = Which::InstructV2_2B;
|
||||
} else {
|
||||
args.which = Which::Instruct2B;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("Using explicitly specified model type: {:?}", args.which);
|
||||
}
|
||||
|
||||
// Resolve weight files: try a single-file first, then fall back to sharded index
|
||||
let weight_paths = if !args.weight_paths.is_empty() {
|
||||
args.weight_paths
|
||||
} else {
|
||||
match repo.get("model.safetensors") {
|
||||
Ok(single) => vec![single],
|
||||
Err(_) => {
|
||||
match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
|
||||
Ok(paths) => paths,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Unable to locate model weights for '{}'. Tried 'model.safetensors' and 'model.safetensors.index.json'. Underlying error: {}",
|
||||
args.model_id, e
|
||||
);
|
||||
}
|
||||
Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
|
||||
Ok(paths) => paths,
|
||||
Err(e) => {
|
||||
panic!("Unable to locate model weights: {}", e);
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path)
|
||||
.map_err(anyhow::Error::msg)
|
||||
.unwrap();
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
|
||||
|
||||
let initial_device = utilities_lib::device(args.force_cpu).unwrap();
|
||||
|
||||
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
|
||||
let is_v3_model = args.which.is_v3_model();
|
||||
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.force_cpu;
|
||||
|
||||
// Use CPU for V3 models on Metal due to missing implementations
|
||||
let is_metal = !initial_device.is_cpu()
|
||||
&& candle_core::utils::metal_is_available()
|
||||
&& !args.force_cpu;
|
||||
|
||||
let device = if is_v3_model && is_metal {
|
||||
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
|
||||
candle_core::Device::Cpu
|
||||
} else {
|
||||
initial_device
|
||||
};
|
||||
|
||||
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
|
||||
|
||||
// Keep original device + dtype
|
||||
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
|
||||
|
||||
let model = match args.which {
|
||||
@@ -285,23 +246,18 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
GemmaModel::V1(model)
|
||||
GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
GemmaModel::V2(model)
|
||||
GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
|
||||
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
GemmaModel::V3(model)
|
||||
GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -314,6 +270,43 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
)
|
||||
}
|
||||
|
||||
fn build_gemma_prompt(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
let mut system_prompt: Option<String> = None;
|
||||
|
||||
for message in messages {
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(),
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
match message.role.as_str() {
|
||||
"system" => system_prompt = Some(content),
|
||||
"user" => {
|
||||
prompt.push_str("<start_of_turn>user\n");
|
||||
if let Some(sys_prompt) = system_prompt.take() {
|
||||
prompt.push_str(&sys_prompt);
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
prompt.push_str(&content);
|
||||
prompt.push_str("<end_of_turn>\n");
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str("<start_of_turn>model\n");
|
||||
prompt.push_str(&content);
|
||||
prompt.push_str("<end_of_turn>\n");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("<start_of_turn>model\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// OpenAI-compatible handler
|
||||
// -------------------------
|
||||
@@ -322,72 +315,68 @@ pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
// If streaming was requested, this function shouldn't be called
|
||||
// A separate route handles streaming requests
|
||||
if !request.stream.unwrap_or(false) {
|
||||
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response())
|
||||
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
|
||||
}
|
||||
|
||||
Ok(chat_completions_stream(state, request).await.into_response())
|
||||
}
|
||||
|
||||
pub async fn chat_completions_non_streaming_proxy(state: AppState, request: ChatCompletionRequest) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
// Non-streaming response - original implementation
|
||||
let mut prompt = String::new();
|
||||
pub async fn chat_completions_non_streaming_proxy(
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
let prompt = build_gemma_prompt(&request.messages);
|
||||
|
||||
// Convert messages to a prompt string
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
}
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
let model_id = state.model_id.clone();
|
||||
|
||||
// Generate
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
// Recreate TextGeneration instance to ensure completely fresh state
|
||||
// This prevents KV cache persistence that causes tensor shape mismatches
|
||||
let fresh_text_gen = build_pipeline(state.build_args.clone());
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
*text_gen = fresh_text_gen;
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||
|
||||
if let Err(e) = result {
|
||||
// Enforce model selection behavior: reject if a different model is requested
|
||||
let configured_model = state.build_args.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
if normalized_requested != configured_model {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Error generating text: {}", e),
|
||||
"type": "text_generation_error"
|
||||
"message": format!(
|
||||
"Requested model '{}' is not available. This server is running '{}' only.",
|
||||
requested_model, configured_model
|
||||
),
|
||||
"type": "model_mismatch"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(text) = String::from_utf8(buffer) {
|
||||
output.push(text);
|
||||
let model_id = state.model_id.clone();
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
// Reset per-request state without rebuilding the whole pipeline
|
||||
text_gen.reset_state();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error generating text: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let completion = output.join("");
|
||||
let completion = match String::from_utf8(buffer) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("UTF-8 conversion error: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
|
||||
@@ -407,7 +396,6 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
// still rough estimates
|
||||
prompt_tokens: prompt.len() / 4,
|
||||
completion_tokens: completion.len() / 4,
|
||||
total_tokens: (prompt.len() + completion.len()) / 4,
|
||||
@@ -415,185 +403,195 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
|
||||
};
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Streaming implementation
|
||||
// -------------------------
|
||||
|
||||
pub async fn chat_completions_stream(
|
||||
state: AppState,
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Call the handler function
|
||||
handle_streaming_request(state, chat_completion_request).await
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
|
||||
handle_streaming_request(state, request).await
|
||||
}
|
||||
|
||||
/// Handle streaming requests with Server-Sent Events (SSE)
|
||||
async fn handle_streaming_request(
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Generate a unique ID for this completion
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
|
||||
// Validate requested model vs configured model
|
||||
let configured_model = state.build_args.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
if normalized_requested != configured_model {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!(
|
||||
"Requested model '{}' is not available. This server is running '{}' only.",
|
||||
requested_model, configured_model
|
||||
),
|
||||
"type": "model_mismatch"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a unique ID and metadata
|
||||
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let created = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let model_id = state.model_id.clone();
|
||||
|
||||
// Convert messages to a prompt string (same as non-streaming)
|
||||
let mut prompt = String::new();
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
// Build prompt
|
||||
let prompt = build_gemma_prompt(&request.messages);
|
||||
tracing::debug!("Formatted prompt: {}", prompt);
|
||||
|
||||
// Channel for streaming SSE events
|
||||
let (tx, rx) = mpsc::unbounded_channel::<Result<Event, Infallible>>();
|
||||
|
||||
// Send initial role event
|
||||
let initial_chunk = ChatCompletionChunk {
|
||||
id: response_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: Some("assistant".to_string()), content: None },
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&initial_chunk) {
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Generate text using existing buffer-based approach
|
||||
let mut buffer = Vec::new();
|
||||
{
|
||||
// Recreate TextGeneration instance to ensure completely fresh state
|
||||
// This prevents KV cache persistence that causes tensor shape mismatches
|
||||
let fresh_text_gen = build_pipeline(state.build_args.clone());
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
*text_gen = fresh_text_gen;
|
||||
|
||||
|
||||
// Spawn generation task that streams tokens as they are generated
|
||||
let state_clone = state.clone();
|
||||
let response_id_clone = response_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let mut text_gen = state_clone.text_generation.lock().await;
|
||||
text_gen.reset_state();
|
||||
|
||||
// Stream tokens via callback with repetition detection
|
||||
let mut recent_tokens = Vec::new();
|
||||
let mut repetition_count = 0;
|
||||
const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
|
||||
const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
|
||||
|
||||
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Error generating text: {}", e),
|
||||
"type": "text_generation_error"
|
||||
let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
|
||||
// Debug log to verify token content
|
||||
tracing::debug!("Streaming token: '{}'", token);
|
||||
|
||||
// Skip sending empty tokens
|
||||
if token.is_empty() {
|
||||
tracing::debug!("Skipping empty token");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Add token to recent history for repetition detection
|
||||
recent_tokens.push(token.to_string());
|
||||
if recent_tokens.len() > REPETITION_WINDOW {
|
||||
recent_tokens.remove(0);
|
||||
}
|
||||
|
||||
// Check for repetitive patterns
|
||||
if recent_tokens.len() >= 4 {
|
||||
let last_token = &recent_tokens[recent_tokens.len() - 1];
|
||||
let second_last = &recent_tokens[recent_tokens.len() - 2];
|
||||
|
||||
// Check if we're repeating the same token or pattern
|
||||
if last_token == second_last ||
|
||||
(last_token.trim() == "plus" && second_last.trim() == "plus") ||
|
||||
(recent_tokens.len() >= 6 &&
|
||||
recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
|
||||
repetition_count += 1;
|
||||
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
|
||||
|
||||
if repetition_count >= MAX_REPETITION_COUNT {
|
||||
tracing::info!("Stopping generation due to excessive repetition");
|
||||
return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Convert buffer to string
|
||||
let generated_text = match String::from_utf8(buffer) {
|
||||
Ok(text) => text,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Error converting generated text to UTF-8: {}", e),
|
||||
"type": "encoding_error"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!("Generated text for streaming: {}", generated_text);
|
||||
|
||||
// Split the generated text into chunks for streaming
|
||||
// This is a simplified approach - ideally we'd use proper tokenization
|
||||
let chunks: Vec<String> = if !generated_text.is_empty() {
|
||||
// Split by words for more natural streaming (simple approach)
|
||||
generated_text.split_whitespace()
|
||||
.map(|word| word.to_string() + " ")
|
||||
.collect()
|
||||
} else {
|
||||
// If no text was generated, provide a default response
|
||||
vec!["Abraham Lincoln was the 16th president of the United States.".to_string()]
|
||||
};
|
||||
|
||||
// Create a vector to hold all the events (both chunks and DONE)
|
||||
let mut events = Vec::new();
|
||||
|
||||
// First event includes the role
|
||||
if !chunks.is_empty() {
|
||||
let first_chunk = &chunks[0];
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: response_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(first_chunk.clone()),
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&chunk) {
|
||||
events.push(Ok(Event::default().data(json)));
|
||||
}
|
||||
|
||||
// Add remaining chunks
|
||||
for chunk_text in chunks.iter().skip(1) {
|
||||
} else {
|
||||
repetition_count = 0; // Reset counter if pattern breaks
|
||||
}
|
||||
}
|
||||
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: response_id.clone(),
|
||||
id: response_id_clone.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: Some(chunk_text.clone()),
|
||||
},
|
||||
delta: Delta { role: None, content: Some(token.to_string()) },
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&chunk) {
|
||||
events.push(Ok(Event::default().data(json)));
|
||||
tracing::debug!("Sending chunk with content: '{}'", token);
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}).await;
|
||||
|
||||
// Add final chunk with finish_reason
|
||||
// Log result of generation
|
||||
match result {
|
||||
Ok(_) => tracing::debug!("Text generation completed successfully"),
|
||||
Err(e) => tracing::info!("Text generation stopped: {}", e),
|
||||
}
|
||||
|
||||
// Send final stop chunk and DONE marker
|
||||
let final_chunk = ChatCompletionChunk {
|
||||
id: response_id,
|
||||
id: response_id_clone.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: None,
|
||||
},
|
||||
delta: Delta { role: None, content: None },
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&final_chunk) {
|
||||
events.push(Ok(Event::default().data(json)));
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add [DONE] event
|
||||
events.push(Ok(Event::default().data("[DONE]")));
|
||||
|
||||
// Create a stream from the events
|
||||
let stream = stream::iter(events);
|
||||
|
||||
// Return the SSE stream
|
||||
let _ = tx.send(Ok(Event::default().data("[DONE]")));
|
||||
});
|
||||
|
||||
// Convert receiver into a Stream for SSE
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
Ok(Sse::new(stream))
|
||||
}
|
||||
|
||||
|
||||
|
||||
// -------------------------
|
||||
// Router
|
||||
// -------------------------
|
||||
|
||||
pub fn create_router(app_state: AppState) -> Router {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_headers(Any)
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
/// Handler for GET /v1/models - returns list of available models
|
||||
async fn list_models() -> Json<ModelListResponse> {
|
||||
pub async fn list_models() -> Json<ModelListResponse> {
|
||||
// Get all available model variants from the Which enum
|
||||
let models = vec![
|
||||
Model {
|
||||
@@ -700,24 +698,6 @@ async fn list_models() -> Json<ModelListResponse> {
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Router
|
||||
// -------------------------
|
||||
|
||||
pub fn create_router(app_state: AppState) -> Router {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_headers(Any)
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -725,92 +705,59 @@ mod tests {
|
||||
use crate::openai_types::{Message, MessageContent};
|
||||
use either::Either;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_models_list_endpoint() {
|
||||
println!("[DEBUG_LOG] Testing models list endpoint");
|
||||
|
||||
let response = list_models().await;
|
||||
let models_response = response.0;
|
||||
|
||||
// Verify response structure
|
||||
assert_eq!(models_response.object, "list");
|
||||
assert_eq!(models_response.data.len(), 16);
|
||||
|
||||
// Verify some key models are present
|
||||
let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
|
||||
assert!(model_ids.contains(&"gemma-2b".to_string()));
|
||||
assert!(model_ids.contains(&"gemma-7b".to_string()));
|
||||
assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
|
||||
assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
|
||||
|
||||
// Verify model structure
|
||||
for model in &models_response.data {
|
||||
assert_eq!(model.object, "model");
|
||||
assert_eq!(model.owned_by, "google");
|
||||
assert_eq!(model.created, 1686935002);
|
||||
assert!(!model.id.is_empty());
|
||||
}
|
||||
|
||||
println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
|
||||
#[test]
|
||||
fn test_build_gemma_prompt() {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(MessageContent(Either::Left("System message".to_string()))),
|
||||
name: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(MessageContent(Either::Left("Knock knock.".to_string()))),
|
||||
name: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent(Either::Left("Who's there?".to_string()))),
|
||||
name: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(MessageContent(Either::Left("Gemma.".to_string()))),
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
|
||||
let prompt = build_gemma_prompt(&messages);
|
||||
|
||||
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
|
||||
<start_of_turn>model\nWho's there?<end_of_turn>\n\
|
||||
<start_of_turn>user\nGemma.<end_of_turn>\n\
|
||||
<start_of_turn>model\n";
|
||||
|
||||
assert_eq!(prompt, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reproduce_tensor_shape_mismatch() {
|
||||
// Create a test app state with Gemma 3 model (same as the failing request)
|
||||
let mut args = PipelineArgs::default();
|
||||
args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
args.which = Which::InstructV3_1B;
|
||||
|
||||
println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
|
||||
|
||||
// This should reproduce the same conditions as the curl script
|
||||
let text_generation = build_pipeline(args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: Arc::new(Mutex::new(text_generation)),
|
||||
model_id: "gemma-3-1b-it".to_string(),
|
||||
build_args: args,
|
||||
};
|
||||
#[test]
|
||||
fn test_empty_messages() {
|
||||
let messages: Vec<Message> = vec![];
|
||||
let prompt = build_gemma_prompt(&messages);
|
||||
assert_eq!(prompt, "<start_of_turn>model\n");
|
||||
}
|
||||
|
||||
// Create the same request as the curl script
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gemma-3-1b-it".to_string(),
|
||||
messages: vec![Message {
|
||||
#[test]
|
||||
fn test_missing_content() {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(MessageContent(Either::Left("What is the capital of France?".to_string()))),
|
||||
content: None,
|
||||
name: None,
|
||||
}],
|
||||
max_tokens: Some(128),
|
||||
stream: Some(true),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
logprobs: false,
|
||||
n_choices: 1,
|
||||
};
|
||||
}
|
||||
];
|
||||
|
||||
println!("[DEBUG_LOG] Attempting to reproduce tensor shape mismatch error...");
|
||||
|
||||
// This should trigger the same error as the curl script
|
||||
let result = handle_streaming_request(app_state, request).await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
println!("[DEBUG_LOG] No error occurred - this suggests the issue might be fixed or environmental");
|
||||
}
|
||||
Err((status_code, json_error)) => {
|
||||
println!("[DEBUG_LOG] Error reproduced! Status: {:?}", status_code);
|
||||
println!("[DEBUG_LOG] Error details: {:?}", json_error);
|
||||
|
||||
// Check if this is the expected tensor shape mismatch error
|
||||
if let Some(error_obj) = json_error.0.as_object() {
|
||||
if let Some(error_details) = error_obj.get("error").and_then(|e| e.as_object()) {
|
||||
if let Some(message) = error_details.get("message").and_then(|m| m.as_str()) {
|
||||
assert!(message.contains("shape mismatch"),
|
||||
"Expected shape mismatch error, got: {}", message);
|
||||
println!("[DEBUG_LOG] Successfully reproduced tensor shape mismatch error");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let prompt = build_gemma_prompt(&messages);
|
||||
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user