cleanup, add ci

This commit is contained in:
geoffsee
2025-08-31 10:31:07 -04:00
parent 419e1c2ea7
commit f5d2a85f2e
42 changed files with 1740 additions and 705 deletions

View File

@@ -3,18 +3,6 @@ name = "inference-engine"
version = "0.1.0"
edition = "2021"
[[bin]]
name="gemma_inference"
path = "src/gemma_inference.rs"
required-features = ["bin"]
[[bin]]
name="llama_inference"
path = "src/llama_inference.rs"
required-features = ["bin"]
[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true }

View File

@@ -30,4 +30,4 @@ pub trait ModelInference {
}
/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

View File

@@ -1,19 +1,19 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod text_generation;
pub mod token_output_stream;
pub mod utilities_lib;
// pub mod cli;
pub mod server;
pub mod inference;
pub mod server;
// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference;
use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,8 +1,8 @@
// use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
}
impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {
pub fn is_instruct_model(&self) -> bool {
match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
Self::Base2B
| Self::Base7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::BaseV2_2B
| Self::BaseV2_9B
| Self::BaseV3_1B => false,
_ => true,
}
}
@@ -100,4 +110,4 @@ impl Which {
pub fn is_llama_model(&self) -> bool {
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
}
}
}

View File

@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
}
}
@@ -213,4 +222,4 @@ pub struct ModelListResponse {
pub object: String,
/// Array of available models
pub data: Vec<Model>,
}
}

View File

@@ -6,19 +6,22 @@ use axum::{
Json, Router,
};
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// -------------------------
// Shared app state
// -------------------------
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {
fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new();
for message in messages {
match message.role.as_str() {
"system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
prompt.push_str(&format!(
"<start_of_turn>system\n{}<end_of_turn>\n",
content
));
}
}
"user" => {
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
_ => {}
}
}
prompt.push_str("<start_of_turn>model\n");
prompt
}
@@ -97,9 +103,13 @@ pub async fn chat_completions(
Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
return Ok(chat_completions_non_streaming_proxy(state, request)
.await
.into_response());
}
Ok(chat_completions_stream(state, request).await.into_response())
Ok(chat_completions_stream(state, request)
.await
.into_response())
}
pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
};
// Get streaming receiver based on model type
let rx = match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
let rx =
match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
})),
));
}
}
}
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
})),
));
}
}
}
};
};
// Collect all tokens from the stream
let mut completion = String::new();
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None },
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
})),
));
}
}
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
})),
));
}
}
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
})),
));
}
}
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
})),
));
}
}
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}
// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];
if last_token == second_last {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
tracing::warn!(
"Detected repetition pattern: '{}' (count: {})",
last_token,
repetition_count
);
if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
break;
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: Some(token) },
delta: Delta {
role: None,
content: Some(token),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json)));
}
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: None },
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
Ok(Sse::new(stream))
}
// -------------------------
// Router
// -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
})
}
#[cfg(test)]
mod tests {
use super::*;
@@ -681,10 +706,7 @@ mod tests {
let prompt = build_gemma_prompt(&messages);
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
assert_eq!(prompt, expected);
}
@@ -698,15 +720,13 @@ mod tests {
#[test]
fn test_missing_content() {
let messages = vec![
Message {
role: "user".to_string(),
content: None,
name: None,
}
];
let messages = vec![Message {
role: "user".to_string(),
content: None,
name: None,
}];
let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
assert_eq!(prompt, "<start_of_turn>model\n");
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -84,4 +84,4 @@ impl TokenOutputStream {
self.prev_index = 0;
self.current_index = 0;
}
}
}

View File

@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}
}

View File

@@ -9,7 +9,10 @@ mod tests {
// Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
assert_eq!(
Which::InstructV1_1_2B.to_model_id(),
"google/gemma-1.1-2b-it"
);
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
@@ -64,4 +67,4 @@ mod tests {
// Note: Testing the Model enum's forward method would require creating actual model instances,
// which is complex and would require loading model weights. This is better suited for
// integration tests or mocking the models.
}
}

View File

@@ -106,7 +106,7 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
// Create a mock TextGeneration instance
// Since we can't easily create a full TextGeneration instance without a model,
// we'll test the logic by creating a simple struct with the necessary fields
@@ -115,7 +115,7 @@ mod tests {
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -167,16 +167,17 @@ mod tests {
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 1.0, // No penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// With no penalty, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
@@ -189,13 +190,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -238,16 +239,17 @@ mod tests {
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, // Apply penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
assert_eq!(result_data, expected);
@@ -261,13 +263,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -308,20 +310,21 @@ mod tests {
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
// First call should cache the penalty for token 1
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let (_result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
// Cache should contain the penalized value for token 1
assert!(mock_gen.penalty_cache.contains_key(&1));
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
Ok(())
}
@@ -332,13 +335,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens: Vec<u32> = vec![]; // Empty tokens
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -379,16 +382,17 @@ mod tests {
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// With empty tokens, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
@@ -401,13 +405,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -448,16 +452,17 @@ mod tests {
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
assert_eq!(result_data, expected);
@@ -471,52 +476,52 @@ mod tests {
// Since creating a real TextGeneration instance requires a Model which needs model weights,
// we'll create a test that demonstrates the method is now public and can be accessed.
// The comprehensive functionality testing is already covered by the mock tests above.
// Test data setup
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
// Test that we can create the necessary components
let tokenizer = create_test_tokenizer()?;
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
// This test verifies the method signature and that it's accessible from external code
// We could create a TextGeneration instance if we had a way to mock the Model,
// but for now we confirm that the existing mock tests cover the functionality
// and the method is properly exposed as public
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
assert!(true);
Ok(())
}
// Integration test that demonstrates the method usage pattern
#[test]
#[test]
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
// This test demonstrates how the apply_cached_repeat_penalty method would be used
// in practice, even though we can't create a full TextGeneration instance in unit tests
let device = Device::Cpu;
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
// Test parameters that would be used with TextGeneration
let repeat_penalty = 1.2f32;
let repeat_last_n = 3usize;
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
// Simulate the method's logic to verify it works as expected
let start_time = std::time::Instant::now();
if repeat_penalty != 1.0 {
let start_at = tokens.len().saturating_sub(repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
@@ -531,14 +536,14 @@ mod tests {
}
}
}
let _duration = start_time.elapsed();
// Verify that tokens were processed correctly
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
Ok(())
}

View File

@@ -1,7 +1,7 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;
use inference_engine::token_output_stream::TokenOutputStream;
use std::path::PathBuf;
use tokenizers::Tokenizer;
#[cfg(test)]
mod tests {
@@ -19,7 +19,7 @@ mod tests {
fn test_new_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Check that the token stream was created successfully
assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
Ok(())
@@ -29,18 +29,18 @@ mod tests {
fn test_clear() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Add a token
let token_id = token_stream.get_token("<eos>").unwrap();
token_stream.next_token(token_id)?;
// Clear the stream
token_stream.clear();
// Check that the stream is empty by trying to decode all
let decoded = token_stream.decode_all()?;
assert_eq!(decoded, "");
Ok(())
}
@@ -48,15 +48,15 @@ mod tests {
fn test_get_token() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get a token that should exist
let eos_token = token_stream.get_token("<eos>");
assert!(eos_token.is_some());
// Get a token that shouldn't exist
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
assert!(nonexistent_token.is_none());
Ok(())
}
@@ -64,11 +64,14 @@ mod tests {
fn test_next_token_and_decode() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
let mut output = String::new();
for &token_id in token_ids {
@@ -76,16 +79,16 @@ mod tests {
output.push_str(&text);
}
}
// Get any remaining text
if let Some(rest) = token_stream.decode_rest()? {
output.push_str(&rest);
}
// Check the output
assert!(!output.is_empty());
assert_eq!(output.trim(), "Hello world");
Ok(())
}
@@ -93,22 +96,25 @@ mod tests {
fn test_decode_all() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all
let decoded = token_stream.decode_all()?;
// Check the output
assert_eq!(decoded.trim(), "Hello world");
Ok(())
}
@@ -116,14 +122,14 @@ mod tests {
fn test_into_inner() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get the inner tokenizer
let inner_tokenizer = token_stream.into_inner();
// Check that the inner tokenizer works
let encoded = inner_tokenizer.encode("Test", true).unwrap();
assert!(encoded.get_ids().len() > 0);
Ok(())
}
}
}