- Change the default server host to localhost (127.0.0.1) for improved security.
- Increase the default maximum tokens in the CLI configuration to 256.
- Refactor and reorganize the CLI.
geoffsee
2025-08-27 21:47:24 -04:00
parent 766d41af78
commit 719beb3791
20 changed files with 1703 additions and 490 deletions

View File

@@ -24,3 +24,12 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
rand = "0.8.5"
async-openai = "0.28.3"
once_cell = "1.19.0"
[package.metadata.kube]
image = "ghcr.io/geoffsee/embeddings-service:latest"
replicas = 1
port = 8080
resources.cpu = "500m"
resources.memory = "256Mi"
#ingress.host = "my-service.example.com"
#env = { RUST_LOG = "info", DATABASE_URL = "postgres://..." }

View File

@@ -10,7 +10,7 @@ use std::env;
use tower_http::trace::TraceLayer;
use tracing;
const DEFAULT_SERVER_HOST: &str = "0.0.0.0";
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
async fn root() -> &'static str {

View File

@@ -44,6 +44,7 @@ axum = { version = "0.8.4", features = ["json"] }
tower = "0.5.2"
tower-http = { version = "0.6.6", features = ["cors"] }
tokio = { version = "1.43.0", features = ["full"] }
tokio-stream = { version = "0.1.16", features = ["sync"] }
either = { version = "1.9.0", features = ["serde"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
@@ -80,4 +81,13 @@ tokio = "1.43.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = { version = "0.1.1", optional = true }
bindgen_cuda = { version = "0.1.1", optional = true }
[package.metadata.kube]
image = "ghcr.io/geoffsee/inference-service:latest"
replicas = 1
port = 8080
resources.cpu = "500m"
resources.memory = "256Mi"
#ingress.host = "my-service.example.com"
#env = { RUST_LOG = "info", DATABASE_URL = "postgres://..." }

View File

@@ -17,7 +17,7 @@ use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Server configuration constants
pub const DEFAULT_SERVER_HOST: &str = "0.0.0.0";
pub const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
pub const DEFAULT_SERVER_PORT: &str = "8080";
/// Get server configuration from environment variables with defaults
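Both hunks above flip the default bind address from 0.0.0.0 (all interfaces) to 127.0.0.1 (loopback only). A minimal sketch of how these constants are typically combined with environment overrides; the variable names SERVER_HOST and SERVER_PORT are assumptions for illustration, not confirmed by this diff:

```rust
use std::env;

const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";

/// Sketch: resolve host/port from the environment, falling back to the new
/// loopback-only defaults. SERVER_HOST/SERVER_PORT are assumed variable names.
fn server_config() -> (String, String) {
    let host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
    let port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
    (host, port)
}

fn main() {
    let (host, port) = server_config();
    // Prints "binding to 127.0.0.1:8080" unless overridden via the environment.
    println!("binding to {host}:{port}");
}
```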

View File

@@ -5,14 +5,13 @@ use axum::{
routing::{get, post},
Json, Router,
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use candle_core::DType;
use candle_nn::VarBuilder;
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::{path::PathBuf, sync::Arc};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time;
use tokio::sync::{Mutex, mpsc};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
@@ -46,40 +45,26 @@ impl Default for AppState {
let text_generation = build_pipeline(args.clone());
Self {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: String::new(),
model_id: args.model_id.clone(),
build_args: args,
}
}
}
// -------------------------
// Pipeline configuration
// -------------------------
#[derive(Debug, Clone)]
pub struct PipelineArgs {
/// HF model repo id, e.g. "google/gemma-2b"
pub model_id: String,
/// Which internal model family to instantiate
pub which: Which,
/// Optional HF revision/branch/tag; None => "main"
pub revision: Option<String>,
/// Optional explicit tokenizer path
pub tokenizer_path: Option<PathBuf>,
/// Optional explicit config path
pub config_path: Option<PathBuf>,
/// Optional explicit weight paths. If empty, they will be resolved from the hub.
pub weight_paths: Vec<PathBuf>,
/// Runtime toggles
pub use_flash_attn: bool,
pub force_cpu: bool,
/// Sampling / decoding params
pub seed: u64,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
@@ -98,39 +83,43 @@ impl Default for PipelineArgs {
weight_paths: Vec::new(),
use_flash_attn: false,
force_cpu: false,
seed: 0,
temperature: None,
top_p: None,
repeat_penalty: 0.0,
repeat_last_n: 0,
seed: 299792458, // Speed of light in vacuum (m/s)
temperature: Some(0.8), // Good balance between creativity and coherence
top_p: Some(0.9), // Keep diverse but reasonable options
repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
repeat_last_n: 64, // Consider last 64 tokens for repetition
}
}
}
// If no owner/org is present, prefix with a sensible default (tweak as you like).
fn normalize_model_id(model_id: &str) -> String {
if model_id.contains('/') { model_id.to_string() } else { format!("google/{}", model_id) }
if model_id.contains('/') {
model_id.to_string()
} else {
format!("google/{}", model_id)
}
}
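For reference, the normalization above maps bare model names onto the google/ org while leaving fully qualified repo ids untouched. A standalone sketch of the same behavior:

```rust
fn normalize_model_id(model_id: &str) -> String {
    if model_id.contains('/') {
        model_id.to_string()
    } else {
        format!("google/{}", model_id)
    }
}

fn main() {
    assert_eq!(normalize_model_id("gemma-2b-it"), "google/gemma-2b-it");
    assert_eq!(normalize_model_id("google/gemma-3-1b-it"), "google/gemma-3-1b-it");
}
```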
// Quick existence check, mapping 404 into a helpful message.
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string()));
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
match repo.get("config.json") {
Ok(_) => Ok(()),
Err(e) => match e {
ApiError::RequestError(resp) => {
// For HF API, RequestError with 404 status is returned when repo doesn't exist
let error_str = resp.to_string();
if error_str.contains("404") {
anyhow::bail!(
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'. \
Please provide a fully-qualified repo id like 'google/gemma-2b-it'."
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
)
}
Err(anyhow::Error::new(ApiError::RequestError(resp)))
}
other => Err(anyhow::Error::new(other)),
}
},
}
}
@@ -151,18 +140,13 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
let api = Api::new().unwrap();
let revision = args.revision.as_deref().unwrap_or("main");
// Check if model_id is empty before normalizing it
println!("Checking model_id: '{}'", args.model_id);
println!("Trimmed model_id length: {}", args.model_id.trim().len());
if args.model_id.trim().is_empty() {
panic!("No model ID specified. Please provide a valid model ID (e.g., 'gemma-2b-it' or 'google/gemma-2b-it').");
panic!("No model ID specified.");
}
args.model_id = normalize_model_id(&args.model_id);
// Validate early (nice error if the repo/revision is wrong).
match ensure_repo_exists(&api, &args.model_id, revision) {
Ok(_) => {},
Ok(_) => {}
Err(e) => panic!("{}", e),
};
@@ -172,105 +156,82 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
revision.to_string(),
));
// Resolve files (prefer explicit paths; fallback to hub)
let tokenizer_path = args
.tokenizer_path
.unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
let config_path = args
.config_path
.unwrap_or_else(|| repo.get("config.json").unwrap());
// Only use auto-detection if no specific model type was provided
// This ensures that explicitly specified model types are respected
if !matches!(args.which,
Which::Base2B | Which::Base7B |
Which::Instruct2B | Which::Instruct7B |
Which::InstructV1_1_2B | Which::InstructV1_1_7B |
Which::CodeBase2B | Which::CodeBase7B |
Which::CodeInstruct2B | Which::CodeInstruct7B |
Which::BaseV2_2B | Which::InstructV2_2B |
Which::BaseV2_9B | Which::InstructV2_9B |
Which::BaseV3_1B | Which::InstructV3_1B) {
// If model_id is a known value, map it directly
if !matches!(
args.which,
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B
| Which::InstructV3_1B
) {
if args.model_id.contains("gemma-2-2b-it") {
args.which = Which::InstructV2_2B;
println!("Setting model type to InstructV2_2B based on model_id: {}", args.model_id);
} else if args.model_id.contains("gemma-3-1b-it") {
args.which = Which::InstructV3_1B;
println!("Setting model type to InstructV3_1B based on model_id: {}", args.model_id);
} else {
// Fallback to auto-detection from config.json
if let Ok(file) = std::fs::File::open(config_path.clone()) {
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
println!("Auto-detecting model type from config.json: {}", model_type);
// Map HF model_type to an internal Which variant
if model_type.contains("gemma3") {
args.which = Which::InstructV3_1B;
println!("Setting model type to InstructV3_1B based on config");
} else if model_type.contains("gemma2") {
args.which = Which::InstructV2_2B;
println!("Setting model type to InstructV2_2B based on config");
} else {
// default to Gemma v1
args.which = Which::Instruct2B;
println!("Setting model type to Instruct2B (v1) based on config");
}
} else if let Ok(file) = std::fs::File::open(config_path.clone()) {
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
if model_type.contains("gemma3") {
args.which = Which::InstructV3_1B;
} else if model_type.contains("gemma2") {
args.which = Which::InstructV2_2B;
} else {
args.which = Which::Instruct2B;
}
}
}
}
} else {
println!("Using explicitly specified model type: {:?}", args.which);
}
// Resolve weight files: try a single-file first, then fall back to sharded index
let weight_paths = if !args.weight_paths.is_empty() {
args.weight_paths
} else {
match repo.get("model.safetensors") {
Ok(single) => vec![single],
Err(_) => {
match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
Ok(paths) => paths,
Err(e) => {
panic!(
"Unable to locate model weights for '{}'. Tried 'model.safetensors' and 'model.safetensors.index.json'. Underlying error: {}",
args.model_id, e
);
}
Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
Ok(paths) => paths,
Err(e) => {
panic!("Unable to locate model weights: {}", e);
}
}
},
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(anyhow::Error::msg)
.unwrap();
let start = std::time::Instant::now();
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
let initial_device = utilities_lib::device(args.force_cpu).unwrap();
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
let is_v3_model = args.which.is_v3_model();
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.force_cpu;
// Use CPU for V3 models on Metal due to missing implementations
let is_metal = !initial_device.is_cpu()
&& candle_core::utils::metal_is_available()
&& !args.force_cpu;
let device = if is_v3_model && is_metal {
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
candle_core::Device::Cpu
} else {
initial_device
};
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
// Keep original device + dtype
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
let model = match args.which {
@@ -285,23 +246,18 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
GemmaModel::V1(model)
GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
GemmaModel::V2(model)
GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
GemmaModel::V3(model)
GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
}
};
println!("loaded the model in {:?}", start.elapsed());
TextGeneration::new(
model,
tokenizer,
@@ -314,6 +270,43 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
)
}
fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new();
let mut system_prompt: Option<String> = None;
for message in messages {
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(),
},
None => "".to_string(),
};
match message.role.as_str() {
"system" => system_prompt = Some(content),
"user" => {
prompt.push_str("<start_of_turn>user\n");
if let Some(sys_prompt) = system_prompt.take() {
prompt.push_str(&sys_prompt);
prompt.push_str("\n\n");
}
prompt.push_str(&content);
prompt.push_str("<end_of_turn>\n");
}
"assistant" => {
prompt.push_str("<start_of_turn>model\n");
prompt.push_str(&content);
prompt.push_str("<end_of_turn>\n");
}
_ => {}
}
}
prompt.push_str("<start_of_turn>model\n");
prompt
}
// -------------------------
// OpenAI-compatible handler
// -------------------------
@@ -322,72 +315,68 @@ pub async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// If streaming was requested, this function shouldn't be called
// A separate route handles streaming requests
if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response())
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
}
Ok(chat_completions_stream(state, request).await.into_response())
}
pub async fn chat_completions_non_streaming_proxy(state: AppState, request: ChatCompletionRequest) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
// Non-streaming response - original implementation
let mut prompt = String::new();
pub async fn chat_completions_non_streaming_proxy(
state: AppState,
request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
let prompt = build_gemma_prompt(&request.messages);
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
prompt.push_str("Assistant: ");
let model_id = state.model_id.clone();
// Generate
let mut output = Vec::new();
{
// Recreate TextGeneration instance to ensure completely fresh state
// This prevents KV cache persistence that causes tensor shape mismatches
let fresh_text_gen = build_pipeline(state.build_args.clone());
let mut text_gen = state.text_generation.lock().await;
*text_gen = fresh_text_gen;
let mut buffer = Vec::new();
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
// Enforce model selection behavior: reject if a different model is requested
let configured_model = state.build_args.model_id.clone();
let requested_model = request.model.clone();
if requested_model.to_lowercase() != "default" {
let normalized_requested = normalize_model_id(&requested_model);
if normalized_requested != configured_model {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!("Error generating text: {}", e),
"type": "text_generation_error"
"message": format!(
"Requested model '{}' is not available. This server is running '{}' only.",
requested_model, configured_model
),
"type": "model_mismatch"
}
})),
));
}
}
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
let model_id = state.model_id.clone();
let mut buffer = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Reset per-request state without rebuilding the whole pipeline
text_gen.reset_state();
let max_tokens = request.max_tokens.unwrap_or(1000);
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Error generating text: {}", e) }
})),
));
}
}
let completion = output.join("");
let completion = match String::from_utf8(buffer) {
Ok(s) => s,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("UTF-8 conversion error: {}", e) }
})),
));
}
};
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
@@ -407,7 +396,6 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
finish_reason: "stop".to_string(),
}],
usage: Usage {
// still rough estimates
prompt_tokens: prompt.len() / 4,
completion_tokens: completion.len() / 4,
total_tokens: (prompt.len() + completion.len()) / 4,
@@ -415,185 +403,195 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
};
Ok(Json(response).into_response())
}
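A minimal client sketch (not part of the commit) that exercises the non-streaming handler above. It assumes the server listens on the new default 127.0.0.1:8080 and that reqwest (with the json feature), tokio, and serde_json are available; the request fields mirror ChatCompletionRequest as used in this file.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        // Anything other than "default" must match the server's configured model,
        // otherwise the handler above returns a model_mismatch error.
        "model": "default",
        "messages": [{ "role": "user", "content": "What is the capital of France?" }],
        "max_tokens": 128,
        "stream": false
    });
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:8080/v1/chat/completions")
        .json(&body)
        .send()
        .await?;
    println!("status: {}", resp.status());
    println!("body: {}", resp.text().await?);
    Ok(())
}
```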
// -------------------------
// Streaming implementation
// -------------------------
pub async fn chat_completions_stream(
state: AppState,
chat_completion_request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
// Call the handler function
handle_streaming_request(state, chat_completion_request).await
request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
handle_streaming_request(state, request).await
}
/// Handle streaming requests with Server-Sent Events (SSE)
async fn handle_streaming_request(
state: AppState,
request: ChatCompletionRequest
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
// Generate a unique ID for this completion
state: AppState,
request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
// Validate requested model vs configured model
let configured_model = state.build_args.model_id.clone();
let requested_model = request.model.clone();
if requested_model.to_lowercase() != "default" {
let normalized_requested = normalize_model_id(&requested_model);
if normalized_requested != configured_model {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!(
"Requested model '{}' is not available. This server is running '{}' only.",
requested_model, configured_model
),
"type": "model_mismatch"
}
})),
));
}
}
// Generate a unique ID and metadata
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let model_id = state.model_id.clone();
// Convert messages to a prompt string (same as non-streaming)
let mut prompt = String::new();
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
// Build prompt
let prompt = build_gemma_prompt(&request.messages);
tracing::debug!("Formatted prompt: {}", prompt);
// Channel for streaming SSE events
let (tx, rx) = mpsc::unbounded_channel::<Result<Event, Infallible>>();
// Send initial role event
let initial_chunk = ChatCompletionChunk {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None },
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&initial_chunk) {
let _ = tx.send(Ok(Event::default().data(json)));
}
prompt.push_str("Assistant: ");
// Generate text using existing buffer-based approach
let mut buffer = Vec::new();
{
// Recreate TextGeneration instance to ensure completely fresh state
// This prevents KV cache persistence that causes tensor shape mismatches
let fresh_text_gen = build_pipeline(state.build_args.clone());
let mut text_gen = state.text_generation.lock().await;
*text_gen = fresh_text_gen;
// Spawn generation task that streams tokens as they are generated
let state_clone = state.clone();
let response_id_clone = response_id.clone();
tokio::spawn(async move {
let max_tokens = request.max_tokens.unwrap_or(1000);
let mut text_gen = state_clone.text_generation.lock().await;
text_gen.reset_state();
// Stream tokens via callback with repetition detection
let mut recent_tokens = Vec::new();
let mut repetition_count = 0;
const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!("Error generating text: {}", e),
"type": "text_generation_error"
let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
// Debug log to verify token content
tracing::debug!("Streaming token: '{}'", token);
// Skip sending empty tokens
if token.is_empty() {
tracing::debug!("Skipping empty token");
return Ok(());
}
// Add token to recent history for repetition detection
recent_tokens.push(token.to_string());
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}
// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];
// Check if we're repeating the same token or pattern
if last_token == second_last ||
(last_token.trim() == "plus" && second_last.trim() == "plus") ||
(recent_tokens.len() >= 6 &&
recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
}
})),
));
}
}
// Convert buffer to string
let generated_text = match String::from_utf8(buffer) {
Ok(text) => text,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": {
"message": format!("Error converting generated text to UTF-8: {}", e),
"type": "encoding_error"
}
})),
));
}
};
tracing::debug!("Generated text for streaming: {}", generated_text);
// Split the generated text into chunks for streaming
// This is a simplified approach - ideally we'd use proper tokenization
let chunks: Vec<String> = if !generated_text.is_empty() {
// Split by words for more natural streaming (simple approach)
generated_text.split_whitespace()
.map(|word| word.to_string() + " ")
.collect()
} else {
// If no text was generated, provide a default response
vec!["Abraham Lincoln was the 16th president of the United States.".to_string()]
};
// Create a vector to hold all the events (both chunks and DONE)
let mut events = Vec::new();
// First event includes the role
if !chunks.is_empty() {
let first_chunk = &chunks[0];
let chunk = ChatCompletionChunk {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: Some(first_chunk.clone()),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
events.push(Ok(Event::default().data(json)));
}
// Add remaining chunks
for chunk_text in chunks.iter().skip(1) {
} else {
repetition_count = 0; // Reset counter if pattern breaks
}
}
let chunk = ChatCompletionChunk {
id: response_id.clone(),
id: response_id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: None,
content: Some(chunk_text.clone()),
},
delta: Delta { role: None, content: Some(token.to_string()) },
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
events.push(Ok(Event::default().data(json)));
tracing::debug!("Sending chunk with content: '{}'", token);
let _ = tx.send(Ok(Event::default().data(json)));
}
}
Ok(())
}).await;
// Add final chunk with finish_reason
// Log result of generation
match result {
Ok(_) => tracing::debug!("Text generation completed successfully"),
Err(e) => tracing::info!("Text generation stopped: {}", e),
}
// Send final stop chunk and DONE marker
let final_chunk = ChatCompletionChunk {
id: response_id,
id: response_id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
delta: Delta { role: None, content: None },
finish_reason: Some("stop".to_string()),
}],
};
if let Ok(json) = serde_json::to_string(&final_chunk) {
events.push(Ok(Event::default().data(json)));
let _ = tx.send(Ok(Event::default().data(json)));
}
}
// Add [DONE] event
events.push(Ok(Event::default().data("[DONE]")));
// Create a stream from the events
let stream = stream::iter(events);
// Return the SSE stream
let _ = tx.send(Ok(Event::default().data("[DONE]")));
});
// Convert receiver into a Stream for SSE
let stream = UnboundedReceiverStream::new(rx);
Ok(Sse::new(stream))
}
// -------------------------
// Router
// -------------------------
pub fn create_router(app_state: AppState) -> Router {
let cors = CorsLayer::new()
.allow_headers(Any)
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.layer(cors)
.with_state(app_state)
}
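A sketch of how the router above would typically be served; the binding code is not shown in this hunk and is assumed here (axum 0.8 with tokio):

```rust
// Assumed wiring, not part of this diff: bind the router to the default
// loopback address. AppState::default() builds the pipeline as shown above.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let app = create_router(AppState::default());
    // The new default binds to loopback only; use 0.0.0.0 to expose the service.
    let listener = tokio::net::TcpListener::bind(("127.0.0.1", 8080)).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
```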
/// Handler for GET /v1/models - returns list of available models
async fn list_models() -> Json<ModelListResponse> {
pub async fn list_models() -> Json<ModelListResponse> {
// Get all available model variants from the Which enum
let models = vec![
Model {
@@ -700,24 +698,6 @@ async fn list_models() -> Json<ModelListResponse> {
})
}
// -------------------------
// Router
// -------------------------
pub fn create_router(app_state: AppState) -> Router {
let cors = CorsLayer::new()
.allow_headers(Any)
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
.layer(cors)
.with_state(app_state)
}
#[cfg(test)]
mod tests {
@@ -725,92 +705,59 @@ mod tests {
use crate::openai_types::{Message, MessageContent};
use either::Either;
#[tokio::test]
async fn test_models_list_endpoint() {
println!("[DEBUG_LOG] Testing models list endpoint");
let response = list_models().await;
let models_response = response.0;
// Verify response structure
assert_eq!(models_response.object, "list");
assert_eq!(models_response.data.len(), 16);
// Verify some key models are present
let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
assert!(model_ids.contains(&"gemma-2b".to_string()));
assert!(model_ids.contains(&"gemma-7b".to_string()));
assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
// Verify model structure
for model in &models_response.data {
assert_eq!(model.object, "model");
assert_eq!(model.owned_by, "google");
assert_eq!(model.created, 1686935002);
assert!(!model.id.is_empty());
}
println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
#[test]
fn test_build_gemma_prompt() {
let messages = vec![
Message {
role: "system".to_string(),
content: Some(MessageContent(Either::Left("System message".to_string()))),
name: None,
},
Message {
role: "user".to_string(),
content: Some(MessageContent(Either::Left("Knock knock.".to_string()))),
name: None,
},
Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left("Who's there?".to_string()))),
name: None,
},
Message {
role: "user".to_string(),
content: Some(MessageContent(Either::Left("Gemma.".to_string()))),
name: None,
},
];
let prompt = build_gemma_prompt(&messages);
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
assert_eq!(prompt, expected);
}
#[tokio::test]
async fn test_reproduce_tensor_shape_mismatch() {
// Create a test app state with Gemma 3 model (same as the failing request)
let mut args = PipelineArgs::default();
args.model_id = "google/gemma-3-1b-it".to_string();
args.which = Which::InstructV3_1B;
println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
// This should reproduce the same conditions as the curl script
let text_generation = build_pipeline(args.clone());
let app_state = AppState {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: "gemma-3-1b-it".to_string(),
build_args: args,
};
#[test]
fn test_empty_messages() {
let messages: Vec<Message> = vec![];
let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>model\n");
}
// Create the same request as the curl script
let request = ChatCompletionRequest {
model: "gemma-3-1b-it".to_string(),
messages: vec![Message {
#[test]
fn test_missing_content() {
let messages = vec![
Message {
role: "user".to_string(),
content: Some(MessageContent(Either::Left("What is the capital of France?".to_string()))),
content: None,
name: None,
}],
max_tokens: Some(128),
stream: Some(true),
temperature: None,
top_p: None,
logprobs: false,
n_choices: 1,
};
}
];
println!("[DEBUG_LOG] Attempting to reproduce tensor shape mismatch error...");
// This should trigger the same error as the curl script
let result = handle_streaming_request(app_state, request).await;
match result {
Ok(_) => {
println!("[DEBUG_LOG] No error occurred - this suggests the issue might be fixed or environmental");
}
Err((status_code, json_error)) => {
println!("[DEBUG_LOG] Error reproduced! Status: {:?}", status_code);
println!("[DEBUG_LOG] Error details: {:?}", json_error);
// Check if this is the expected tensor shape mismatch error
if let Some(error_obj) = json_error.0.as_object() {
if let Some(error_details) = error_obj.get("error").and_then(|e| e.as_object()) {
if let Some(message) = error_details.get("message").and_then(|m| m.as_str()) {
assert!(message.contains("shape mismatch"),
"Expected shape mismatch error, got: {}", message);
println!("[DEBUG_LOG] Successfully reproduced tensor shape mismatch error");
}
}
}
}
}
let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
}
}
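To complement the non-streaming example earlier, a client sketch for the SSE path: each event arrives as a `data: {chunk}` frame and the stream ends with `data: [DONE]`. This assumes reqwest with the stream feature plus futures-util, tokio, and serde_json.

```rust
use futures_util::StreamExt;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "default",
        "messages": [{ "role": "user", "content": "Say hello." }],
        "max_tokens": 64,
        "stream": true
    });
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:8080/v1/chat/completions")
        .json(&body)
        .send()
        .await?;
    let mut stream = resp.bytes_stream();
    while let Some(chunk) = stream.next().await {
        // Raw SSE frames: "data: {ChatCompletionChunk json}\n\n", then "data: [DONE]".
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}
```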

View File

@@ -20,6 +20,8 @@ pub struct TextGeneration {
repeat_last_n: usize,
// Cache for repeat penalty computation to avoid redundant calculations
penalty_cache: HashMap<usize, f32>,
// Context window size for sliding window context (default: 64 tokens)
context_window_size: usize,
}
impl TextGeneration {
@@ -55,6 +57,7 @@ impl TextGeneration {
cpu_device,
try_primary_device,
penalty_cache: HashMap::new(),
context_window_size: 64, // Default sliding window size for better context preservation
}
}
@@ -192,9 +195,14 @@ impl TextGeneration {
// Track overall performance
let start_time = std::time::Instant::now();
// Clear penalty cache for new generation
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache for new generation");
// Keep penalty cache across generation for better repetition prevention
// Only clear cache if it becomes too large to prevent memory bloat
if self.penalty_cache.len() > 10000 {
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache due to size limit");
} else {
tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
@@ -440,9 +448,14 @@ impl TextGeneration {
// Track overall performance
let start_time = std::time::Instant::now();
// Clear penalty cache for new generation
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache for new generation (API mode)");
// Keep penalty cache across generation for better repetition prevention
// Only clear cache if it becomes too large to prevent memory bloat
if self.penalty_cache.len() > 10000 {
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache due to size limit (API mode)");
} else {
tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)");
}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
@@ -573,10 +586,18 @@ impl TextGeneration {
for index in 0..sample_len {
let token_start = std::time::Instant::now();
let context_size = if index > 0 { 1 } else { tokens.len() };
// Use sliding window context instead of single token to preserve context and reduce repetition
let context_size = if index > 0 {
std::cmp::min(self.context_window_size, tokens.len())
} else {
tokens.len()
};
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})",
ctxt.len(), start_pos);
// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
@@ -629,6 +650,266 @@ impl TextGeneration {
Ok(())
}
// Run text generation with streaming callback for each token
pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
where
F: FnMut(&str) -> Result<()>,
{
// Track overall performance
let start_time = std::time::Instant::now();
// Keep penalty cache across generation for better repetition prevention
// Only clear cache if it becomes too large to prevent memory bloat
if self.penalty_cache.len() > 10000 {
self.penalty_cache.clear();
tracing::debug!("Cleared penalty cache due to size limit (streaming mode)");
} else {
tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)");
}
// Phase 1: Tokenize input
let tokenize_start = std::time::Instant::now();
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokenize_time = tokenize_start.elapsed();
tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time);
tracing::debug!("Streaming Input tokens: {}", tokens.len());
// Collect all output for final return
let mut full_output = String::new();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
eos_token
}
};
// Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
let needs_special_handling = match &self.model {
Model::V2(_) => true,
Model::V3(_) => true,
_ => false,
};
// Track generation timing
let start_gen = std::time::Instant::now();
// Track per-token generation timing for performance analysis
let mut token_times = Vec::new();
let mut forward_times = Vec::new();
let mut repeat_penalty_times = Vec::new();
let mut sampling_times = Vec::new();
// For Model2 and Model3, we need to use a special approach for shape compatibility
if needs_special_handling {
tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
tracing::debug!("Streaming: sample_len = {}", sample_len);
// Initial generation with the full prompt
let forward_start = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.execute_with_fallback(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
for gen_index in 0..sample_len {
tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
let token_start = std::time::Instant::now();
// Apply repeat penalty using optimized cached implementation
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
repeat_penalty_times.push(repeat_time);
// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&current_logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);
tokens.push(next_token);
generated_tokens += 1;
tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
next_token, next_token, eos_token, eot_token);
if next_token == eos_token || next_token == eot_token {
tracing::debug!("Streaming: Breaking due to end token");
break;
}
if let Some(token_text) = self.tokenizer.next_token(next_token)? {
full_output.push_str(&token_text);
// Call the streaming callback with this token
token_callback(&token_text)?;
}
// For the next iteration, use single token to avoid shape mismatch
let forward_start = std::time::Instant::now();
tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());
// Use just the last token for subsequent iterations to avoid shape mismatch
// This is required for Gemma model's attention mechanism compatibility
let context_tokens = &tokens[(tokens.len()-1)..];
let start_pos = tokens.len() - 1;
tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})",
context_tokens.len(), start_pos);
let new_input = match Tensor::new(context_tokens, &self.device) {
Ok(tensor) => tensor,
Err(e) => {
tracing::error!("Streaming: Failed to create input tensor: {}", e);
return Err(e.into());
}
};
let new_input = match new_input.unsqueeze(0) {
Ok(tensor) => tensor,
Err(e) => {
tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e);
return Err(e.into());
}
};
tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos);
logits = match self.execute_with_fallback(&new_input, start_pos) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Forward pass failed: {}", e);
return Err(e);
}
};
logits = match logits.squeeze(0) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e);
return Err(e.into());
}
};
logits = match logits.squeeze(0) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e);
return Err(e.into());
}
};
logits = match logits.to_dtype(DType::F32) {
Ok(result) => result,
Err(e) => {
tracing::error!("Streaming: Failed to convert logits to F32: {}", e);
return Err(e.into());
}
};
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);
let token_time = token_start.elapsed();
token_times.push(token_time);
// Yield to allow other async tasks to run
tokio::task::yield_now().await;
}
} else {
// Standard approach for other models
tracing::debug!("Using standard generation approach (streaming)");
for index in 0..sample_len {
let token_start = std::time::Instant::now();
// Use sliding window context instead of single token to preserve context and reduce repetition
let context_size = if index > 0 {
std::cmp::min(self.context_window_size, tokens.len())
} else {
tokens.len()
};
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})",
ctxt.len(), start_pos);
// Track tensor operations and model forward pass
let forward_start = std::time::Instant::now();
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let forward_time = forward_start.elapsed();
forward_times.push(forward_time);
// Apply repeat penalty using optimized cached implementation
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
repeat_penalty_times.push(repeat_time);
// Track token sampling
let sampling_start = std::time::Instant::now();
let next_token = self.logits_processor.sample(&logits)?;
let sampling_time = sampling_start.elapsed();
sampling_times.push(sampling_time);
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(token_text) = self.tokenizer.next_token(next_token)? {
full_output.push_str(&token_text);
// Call the streaming callback with this token
token_callback(&token_text)?;
}
let token_time = token_start.elapsed();
token_times.push(token_time);
}
}
let dt = start_gen.elapsed();
// Phase 3: Final decoding
let decode_start = std::time::Instant::now();
// Decode any remaining tokens but don't send through callback to avoid repetition
// The tokens were already streamed individually in the generation loop above
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
full_output.push_str(&rest);
// Note: NOT calling token_callback(&rest) here to prevent token repetition
// Individual tokens were already streamed via the callback in the generation loop
}
let decode_time = decode_start.elapsed();
// Log performance metrics
Self::log_performance_metrics(
dt, generated_tokens, &token_times, &forward_times,
&repeat_penalty_times, &sampling_times, tokenize_time,
decode_time, start_time, "Streaming"
);
Ok(full_output)
}
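The new run_with_streaming drives a caller-supplied FnMut(&str) callback for each decoded token. A minimal sketch of forwarding those tokens over a channel, mirroring what the SSE handler in this commit does; it assumes a TextGeneration value already built elsewhere (e.g. via build_pipeline):

```rust
use tokio::sync::mpsc;

// Sketch only: assumes `TextGeneration` from this crate is in scope and was
// constructed elsewhere (e.g. via build_pipeline).
async fn stream_to_channel(
    text_gen: &mut TextGeneration,
    prompt: &str,
    max_tokens: usize,
) -> anyhow::Result<String> {
    let (tx, mut rx) = mpsc::unbounded_channel::<String>();

    // Consume tokens as they arrive; a real caller would emit SSE events here.
    tokio::spawn(async move {
        while let Some(tok) = rx.recv().await {
            print!("{tok}");
        }
    });

    // Each decoded token is pushed through the callback as soon as it is ready.
    text_gen
        .run_with_streaming(prompt, max_tokens, move |token| {
            let _ = tx.send(token.to_string());
            Ok(())
        })
        .await
}
```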
// Helper function for logging performance metrics
fn log_performance_metrics(

View File

@@ -40,7 +40,8 @@ impl TokenOutputStream {
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
if text.len() > prev_text.len() {
// Modified to include all tokens, not just alphanumeric ones
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();

View File

@@ -49,4 +49,13 @@ features = [
"fast-rng", # Use a faster (but still sufficiently random) RNG
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
"js", # Enable JavaScript RNG for WASM targets
]
]
[package.metadata.kube]
image = "ghcr.io/geoffsee/leptos-chat:latest"
replicas = 1
port = 8788
resources.cpu = "500m"
resources.memory = "256Mi"
#ingress.host = "my-service.example.com"
#env = { RUST_LOG = "info", DATABASE_URL = "postgres://..." }

View File

@@ -15,7 +15,7 @@ use async_openai_wasm::{
Client,
};
use async_openai_wasm::config::OpenAIConfig;
use async_openai_wasm::types::{ChatCompletionResponseStream, Model};
use async_openai_wasm::types::{ChatCompletionResponseStream, Model, Role, FinishReason};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
@@ -127,7 +127,7 @@ async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
}
async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseStream {
let config = OpenAIConfig::new().with_api_base("http://localhost:8080".to_string());
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);
let mut typed_chat = async_openai_wasm::types::CreateChatCompletionRequest {
@@ -205,7 +205,7 @@ async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseS
// Err("leptos-chat chat request only supported on wasm32 target".to_string())
// }
const DEFAULT_MODEL: &str = "gemma-2b-it";
const DEFAULT_MODEL: &str = "default";
#[component]
fn ChatInterface() -> impl IntoView {
@@ -272,11 +272,37 @@ fn ChatInterface() -> impl IntoView {
let history_count = messages.with_untracked(|msgs| {
let count = msgs.len();
for msg in msgs.iter() {
let message = ChatCompletionRequestUserMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build message");
chat_messages.push(message.into());
match msg.role.as_str() {
"user" => {
let message = ChatCompletionRequestUserMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build user message");
chat_messages.push(message.into());
}
"assistant" => {
let message = ChatCompletionRequestAssistantMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build assistant message");
chat_messages.push(message.into());
}
"system" => {
let message = ChatCompletionRequestSystemMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build system message");
chat_messages.push(message.into());
}
_ => {
// Default to user message for unknown roles
let message = ChatCompletionRequestUserMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build default message");
chat_messages.push(message.into());
}
}
}
count
});
@@ -319,51 +345,69 @@ fn ChatInterface() -> impl IntoView {
Ok(mut stream) => {
log::info!("[DEBUG_LOG] send_message: Successfully created stream, starting to receive response");
// Insert a placeholder assistant message to append into
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id.clone(),
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
// Defer creating assistant message until we receive role=assistant from the stream
let mut assistant_created = false;
let mut content_appended = false;
let mut chunks_received = 0;
// Stream loop: append deltas to the last message
// Stream loop: handle deltas and finish events
while let Some(next) = stream.next().await {
match next {
Ok(chunk) => {
chunks_received += 1;
// Try to pull out the content delta in a tolerant way.
// async-openai 0.28.x stream chunk usually looks like:
// choices[0].delta.content: Option<String>
let mut delta_txt = String::new();
if let Some(choice) = chunk.choices.get(0) {
// Newer message API may expose different shapes; try common ones
// 1) Simple string content delta
if let Some(content) = &choice.delta.content {
delta_txt.push_str(content);
}
// 2) Some providers pack text under .delta.role/.delta.<other>
// If nothing extracted, ignore quietly.
// If a finish_reason arrives, we could stop early,
// but usually the stream naturally ends.
}
if !delta_txt.is_empty() {
set_messages.update(|msgs| {
if let Some(last) = msgs.back_mut() {
if last.role == "assistant" {
last.content.push_str(&delta_txt);
last.timestamp = Date::now();
// 1) Create assistant message when role arrives
if !assistant_created {
if let Some(role) = &choice.delta.role {
if role == &Role::Assistant {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
}
});
}
// 2) Append content tokens when provided
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
// If content arrives before role, create assistant message now
if !assistant_created {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
content_appended = true;
set_messages.update(|msgs| {
if let Some(last) = msgs.back_mut() {
if last.role == "assistant" {
last.content.push_str(content);
last.timestamp = Date::now();
}
}
});
}
}
// 3) Stop on finish_reason=="stop" (mirrors [DONE])
if let Some(reason) = &choice.finish_reason {
if reason == &FinishReason::Stop {
log::info!("[DEBUG_LOG] send_message: Received finish_reason=stop after {} chunks", chunks_received);
break;
}
}
}
}
Err(e) => {
@@ -381,6 +425,21 @@ fn ChatInterface() -> impl IntoView {
}
}
}
// Cleanup: If we created an assistant message but no content ever arrived, remove the empty message
if assistant_created && !content_appended {
set_messages.update(|msgs| {
let should_pop = msgs
.back()
.map(|m| m.role == "assistant" && m.content.is_empty())
.unwrap_or(false);
if should_pop {
log::info!("[DEBUG_LOG] send_message: Removing empty assistant message (no content received)");
msgs.pop_back();
}
});
}
log::info!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
}
Err(e) => {

View File

@@ -24,3 +24,13 @@ embeddings-engine = { path = "../embeddings-engine" }
# Dependencies for inference functionality
inference-engine = { path = "../inference-engine" }
[package.metadata.kube]
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
resources.cpu = "500m"
resources.memory = "256Mi"
#ingress.host = "my-service.example.com"
#env = { RUST_LOG = "info", DATABASE_URL = "postgres://..." }