Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

Commit: cleanup, add ci
@@ -3,18 +3,6 @@ name = "inference-engine"
version = "0.1.0"
edition = "2021"

[[bin]]
name="gemma_inference"
path = "src/gemma_inference.rs"
required-features = ["bin"]

[[bin]]
name="llama_inference"
path = "src/llama_inference.rs"
required-features = ["bin"]

[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true }
@@ -30,4 +30,4 @@ pub trait ModelInference {
}

/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
@@ -1,19 +1,19 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod text_generation;
pub mod token_output_stream;
pub mod utilities_lib;
// pub mod cli;
pub mod server;
pub mod inference;
pub mod server;

// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference;

use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
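With the re-exports above in place, downstream code can import the key types from the crate root instead of the individual modules. A minimal usage sketch (the helper function below is illustrative only, not part of this commit):

use inference_engine::{create_router, AppState, Model, ModelInference, TextGeneration, TokenOutputStream, Which};

// Illustrative helper: Which is Copy and exposes the predicates shown in model.rs below.
fn wants_gemma_chat_template(which: Which) -> bool {
    which.is_instruct_model() && !which.is_llama_model()
}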
@@ -1,8 +1,8 @@
// use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
}

impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {

pub fn is_instruct_model(&self) -> bool {
match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
Self::Base2B
| Self::Base7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::BaseV2_2B
| Self::BaseV2_9B
| Self::BaseV3_1B => false,
_ => true,
}
}
@@ -100,4 +110,4 @@ impl Which {
pub fn is_llama_model(&self) -> bool {
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
}
}
}
@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
);

impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);

impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
}
}

@@ -213,4 +222,4 @@ pub struct ModelListResponse {
pub object: String,
/// Array of available models
pub data: Vec<Model>,
}
}
@@ -6,19 +6,22 @@ use axum::{
Json, Router,
};
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// -------------------------
// Shared app state
// -------------------------
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {

fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new();

for message in messages {
match message.role.as_str() {
"system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
prompt.push_str(&format!(
"<start_of_turn>system\n{}<end_of_turn>\n",
content
));
}
}
"user" => {
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
_ => {}
}
}

prompt.push_str("<start_of_turn>model\n");
prompt
}
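The turn format produced by build_gemma_prompt is easiest to see on a concrete conversation. A minimal, self-contained sketch (the helper name and messages are illustrative; the tags and trailing generation cue match the test expectations further down):

fn turn(role: &str, content: &str) -> String {
    format!("<start_of_turn>{}\n{}<end_of_turn>\n", role, content)
}

fn main() {
    let mut prompt = String::new();
    prompt.push_str(&turn("system", "System message"));
    prompt.push_str(&turn("user", "Knock knock."));
    prompt.push_str(&turn("model", "Who's there?")); // assistant messages map to the "model" role
    prompt.push_str(&turn("user", "Gemma."));
    prompt.push_str("<start_of_turn>model\n"); // cue the model to generate the next turn
    assert!(prompt.ends_with("<start_of_turn>model\n"));
}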
@@ -97,9 +103,13 @@ pub async fn chat_completions(
Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
return Ok(chat_completions_non_streaming_proxy(state, request)
.await
.into_response());
}
Ok(chat_completions_stream(state, request).await.into_response())
Ok(chat_completions_stream(state, request)
.await
.into_response())
}
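The handler above branches purely on the `stream` flag of the incoming ChatCompletionRequest. A sketch of a request body it would accept (field names follow the OpenAI-style types used in this file; the model id and messages are illustrative):

use serde_json::json;

fn main() {
    // Only `stream` decides whether the streaming or non-streaming path runs.
    let body = json!({
        "model": "gemma-3-1b-it",
        "stream": true,
        "messages": [
            { "role": "user", "content": "Knock knock." }
        ]
    });
    println!("{}", body);
}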
pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
};

// Get streaming receiver based on model type
let rx = match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
let rx =
match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
})),
));
}
}
}
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
})),
));
}
}
}
};
};

// Collect all tokens from the stream
let mut completion = String::new();
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None },
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
})),
));
}
}
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
})),
));
}
}
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
})),
));
}
}
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
})),
));
}
}
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}

// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];

if last_token == second_last {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);

tracing::warn!(
"Detected repetition pattern: '{}' (count: {})",
last_token,
repetition_count
);

if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
break;
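The repetition guard above keeps a sliding window of recent tokens and stops the stream once the same token keeps repeating. A standalone sketch of that rule (REPETITION_WINDOW and MAX_REPETITION_COUNT are defined elsewhere in server.rs and not shown in this diff, so the values below are assumed placeholders):

const REPETITION_WINDOW: usize = 8;    // assumed value
const MAX_REPETITION_COUNT: usize = 3; // assumed value

// Returns true when generation should stop because the last token keeps repeating.
fn should_stop(recent_tokens: &mut Vec<String>, repetition_count: &mut usize, token: &str) -> bool {
    recent_tokens.push(token.to_string());
    if recent_tokens.len() > REPETITION_WINDOW {
        recent_tokens.remove(0);
    }
    if recent_tokens.len() >= 4 {
        let last = &recent_tokens[recent_tokens.len() - 1];
        let second_last = &recent_tokens[recent_tokens.len() - 2];
        if last == second_last {
            *repetition_count += 1;
            return *repetition_count >= MAX_REPETITION_COUNT;
        }
    }
    false
}

fn main() {
    let mut recent = Vec::new();
    let mut count = 0;
    for _ in 0..8 {
        if should_stop(&mut recent, &mut count, "the") {
            println!("stopping due to repetition");
            break;
        }
    }
}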
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: Some(token) },
delta: Delta {
role: None,
content: Some(token),
},
finish_reason: None,
}],
};

if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json)));
}
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: None },
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
Ok(Sse::new(stream))
}

// -------------------------
// Router
// -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
})
}

#[cfg(test)]
mod tests {
use super::*;
@@ -681,10 +706,7 @@ mod tests {

let prompt = build_gemma_prompt(&messages);

let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";

assert_eq!(prompt, expected);
}
@@ -698,15 +720,13 @@ mod tests {

#[test]
fn test_missing_content() {
let messages = vec![
Message {
role: "user".to_string(),
content: None,
name: None,
}
];
let messages = vec![Message {
role: "user".to_string(),
content: None,
name: None,
}];

let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
assert_eq!(prompt, "<start_of_turn>model\n");
}
}
[File diff suppressed because it is too large.]
@@ -84,4 +84,4 @@ impl TokenOutputStream {
self.prev_index = 0;
self.current_index = 0;
}
}
}
@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}
}
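hub_load_local_safetensors reads the safetensors index JSON, collects the unique shard files from its weight_map, and returns their full paths. A hypothetical call site (the module path, directory, and index file name are assumptions for illustration, not taken from the repository):

use anyhow::Result;

fn load_shards() -> Result<Vec<std::path::PathBuf>> {
    // Assumes the helper lives in utilities_lib, as the module list in lib.rs suggests.
    let shards = inference_engine::utilities_lib::hub_load_local_safetensors(
        "./models/gemma-2b",                  // illustrative model directory
        "model.safetensors.index.json",       // illustrative index file name
    )?;
    for shard in &shards {
        println!("shard: {}", shard.display());
    }
    Ok(shards)
}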
@@ -9,7 +9,10 @@ mod tests {
// Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
assert_eq!(
Which::InstructV1_1_2B.to_model_id(),
"google/gemma-1.1-2b-it"
);
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
@@ -64,4 +67,4 @@ mod tests {
// Note: Testing the Model enum's forward method would require creating actual model instances,
// which is complex and would require loading model weights. This is better suited for
// integration tests or mocking the models.
}
}
@@ -106,7 +106,7 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];

// Create a mock TextGeneration instance
// Since we can't easily create a full TextGeneration instance without a model,
// we'll test the logic by creating a simple struct with the necessary fields
@@ -115,7 +115,7 @@ mod tests {
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}

impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -167,16 +167,17 @@ mod tests {
Ok((result, elapsed))
}
}

let mut mock_gen = MockTextGeneration {
repeat_penalty: 1.0, // No penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};

let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;

// With no penalty, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
@@ -189,13 +190,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];

struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}

impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -238,16 +239,17 @@ mod tests {
Ok((result, elapsed))
}
}

let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, // Apply penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};

let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;

// Tokens 1, 2, 3 should be penalized (divided by 2.0)
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
assert_eq!(result_data, expected);
@@ -261,13 +263,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache

struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}

impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -308,20 +310,21 @@ mod tests {
Ok((result, elapsed))
}
}

let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};

// First call should cache the penalty for token 1
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

let (_result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

// Cache should contain the penalized value for token 1
assert!(mock_gen.penalty_cache.contains_key(&1));
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0

Ok(())
}

@@ -332,13 +335,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens: Vec<u32> = vec![]; // Empty tokens

struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}

impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -379,16 +382,17 @@ mod tests {
Ok((result, elapsed))
}
}

let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};

let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;

// With empty tokens, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
@@ -401,13 +405,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds

struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}

impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
@@ -448,16 +452,17 @@ mod tests {
Ok((result, elapsed))
}
}

let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};

let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;

// Only token 1 should be penalized, out-of-bounds tokens should be ignored
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
assert_eq!(result_data, expected);
@@ -471,52 +476,52 @@ mod tests {
// Since creating a real TextGeneration instance requires a Model which needs model weights,
// we'll create a test that demonstrates the method is now public and can be accessed.
// The comprehensive functionality testing is already covered by the mock tests above.

// Test data setup
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];

// Test that we can create the necessary components
let tokenizer = create_test_tokenizer()?;

// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
// This test verifies the method signature and that it's accessible from external code

// We could create a TextGeneration instance if we had a way to mock the Model,
// but for now we confirm that the existing mock tests cover the functionality
// and the method is properly exposed as public

println!("apply_cached_repeat_penalty method is now public and accessible for testing");
assert!(true);
Ok(())
}

// Integration test that demonstrates the method usage pattern
#[test]
#[test]
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
// This test demonstrates how the apply_cached_repeat_penalty method would be used
// in practice, even though we can't create a full TextGeneration instance in unit tests

let device = Device::Cpu;
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching

// Test parameters that would be used with TextGeneration
let repeat_penalty = 1.2f32;
let repeat_last_n = 3usize;
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();

// Simulate the method's logic to verify it works as expected
let start_time = std::time::Instant::now();

if repeat_penalty != 1.0 {
let start_at = tokens.len().saturating_sub(repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;

for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
@@ -531,14 +536,14 @@ mod tests {
}
}
}

let _duration = start_time.elapsed();

// Verify that tokens were processed correctly
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached

println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
Ok(())
}
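The mock tests above all exercise the same cached repeat-penalty rule: divide the logit of each recent token by the penalty, and remember the penalized value so a repeated token is not divided twice. A minimal sketch of that rule on plain f32 slices, using the same numbers as the tests (the real method works on candle tensors and its exact caching strategy may differ):

use std::collections::HashMap;

fn apply_repeat_penalty(
    logits: &mut [f32],
    tokens: &[u32],
    repeat_penalty: f32,
    repeat_last_n: usize,
    penalty_cache: &mut HashMap<usize, f32>,
) {
    if repeat_penalty == 1.0 {
        return; // no penalty: logits stay unchanged
    }
    let start_at = tokens.len().saturating_sub(repeat_last_n);
    for &token_id in &tokens[start_at..] {
        let idx = token_id as usize;
        if idx < logits.len() {
            // Cache the penalized value so a repeated token is only divided once.
            let penalized = *penalty_cache.entry(idx).or_insert(logits[idx] / repeat_penalty);
            logits[idx] = penalized;
        }
    }
}

fn main() {
    let mut logits = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
    let mut cache = HashMap::new();
    apply_repeat_penalty(&mut logits, &[1, 2, 3], 2.0, 3, &mut cache);
    assert_eq!(logits, vec![1.0, 1.0, 1.5, 2.0, 5.0]); // matches the expected values in the tests
    assert_eq!(cache.get(&1), Some(&1.0));             // 2.0 / 2.0, as asserted above
}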
@@ -1,7 +1,7 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;
use inference_engine::token_output_stream::TokenOutputStream;
use std::path::PathBuf;
use tokenizers::Tokenizer;

#[cfg(test)]
mod tests {
@@ -19,7 +19,7 @@ mod tests {
fn test_new_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);

// Check that the token stream was created successfully
assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
Ok(())
@@ -29,18 +29,18 @@ mod tests {
fn test_clear() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);

// Add a token
let token_id = token_stream.get_token("<eos>").unwrap();
token_stream.next_token(token_id)?;

// Clear the stream
token_stream.clear();

// Check that the stream is empty by trying to decode all
let decoded = token_stream.decode_all()?;
assert_eq!(decoded, "");

Ok(())
}

@@ -48,15 +48,15 @@ mod tests {
fn test_get_token() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);

// Get a token that should exist
let eos_token = token_stream.get_token("<eos>");
assert!(eos_token.is_some());

// Get a token that shouldn't exist
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
assert!(nonexistent_token.is_none());

Ok(())
}

@@ -64,11 +64,14 @@ mod tests {
fn test_next_token_and_decode() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);

// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids();

// Add tokens one by one
let mut output = String::new();
for &token_id in token_ids {
@@ -76,16 +79,16 @@ mod tests {
output.push_str(&text);
}
}

// Get any remaining text
if let Some(rest) = token_stream.decode_rest()? {
output.push_str(&rest);
}

// Check the output
assert!(!output.is_empty());
assert_eq!(output.trim(), "Hello world");

Ok(())
}

@@ -93,22 +96,25 @@ mod tests {
fn test_decode_all() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);

// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids();

// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}

// Decode all
let decoded = token_stream.decode_all()?;

// Check the output
assert_eq!(decoded.trim(), "Hello world");

Ok(())
}

@@ -116,14 +122,14 @@ mod tests {
fn test_into_inner() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);

// Get the inner tokenizer
let inner_tokenizer = token_stream.into_inner();

// Check that the inner tokenizer works
let encoded = inner_tokenizer.encode("Test", true).unwrap();
assert!(encoded.get_ids().len() > 0);

Ok(())
}
}
}