add tests for the local inference server

geoffsee
2025-06-05 22:26:06 -04:00
parent d483a53376
commit 1270a6b0ba
6 changed files with 321 additions and 5 deletions

View File

@@ -562,9 +562,9 @@ checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b"
[[package]]
name = "bumpalo"
version = "3.17.0"
version = "3.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
[[package]]
name = "by_address"
@@ -1962,7 +1962,7 @@ dependencies = [
name = "hyper-rustls"
version = "0.27.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a01595e11bdcec50946522c32dde3fc6914743777a68b93777965f2f02406d"
checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d"
dependencies = [
"http",
"hyper",
@@ -2518,6 +2518,7 @@ dependencies = [
"pyo3",
"rand 0.9.1",
"rayon",
"reborrow",
"rubato",
"safetensors",
"serde",
@@ -3834,7 +3835,7 @@ dependencies = [
name = "reborrow"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193777f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "redox_syscall"
@@ -4798,6 +4799,7 @@ dependencies = [
"derive_builder",
"esaxx-rs",
"getrandom 0.2.16",
"hf-hub",
"itertools 0.13.0",
"lazy_static",
"log",

View File

@@ -27,7 +27,7 @@ safetensors = { version = "0.4.1" }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = { version = "1.0.99" }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
-tokenizers = { version = "0.21.0", default-features = false, features = ["onig"] }
+tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "http"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
candle-core = { version = "=0.9.1", features = ["metal"] }
@@ -43,6 +43,7 @@ tokio = { version = "1.43.0", features = ["full"] }
either = { version = "1.9.0", features = ["serde"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
[dev-dependencies]
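The "http" feature enabled above is what the new tests depend on: it turns on `Tokenizer::from_pretrained`, which fetches a tokenizer from the Hugging Face Hub rather than loading it from a local file. A minimal sketch of that call (assumes network access; gated models such as google/gemma-2b may additionally require Hub credentials):

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Downloads tokenizer.json from the Hub; available only with the "http" feature.
    let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None)?;
    println!("vocab size: {}", tokenizer.get_vocab_size(true));
    Ok(())
}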

View File

@@ -0,0 +1,13 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod cli;
pub mod server;
// Re-export key components for easier access
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
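// With these re-exports in place, downstream code (including the new tests) can
// pull the key types straight from the crate root, for example:
// use local_inference_engine::{Model, TextGeneration, TokenOutputStream, Which};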

View File

@@ -0,0 +1,67 @@
use local_inference_engine::model::Which;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_which_to_model_id() {
// Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
}
#[test]
fn test_which_is_instruct_model() {
// Test base models (should return false)
assert!(!Which::Base2B.is_instruct_model());
assert!(!Which::Base7B.is_instruct_model());
assert!(!Which::CodeBase2B.is_instruct_model());
assert!(!Which::CodeBase7B.is_instruct_model());
assert!(!Which::BaseV2_2B.is_instruct_model());
assert!(!Which::BaseV2_9B.is_instruct_model());
assert!(!Which::BaseV3_1B.is_instruct_model());
// Test instruct models (should return true)
assert!(Which::Instruct2B.is_instruct_model());
assert!(Which::Instruct7B.is_instruct_model());
assert!(Which::InstructV1_1_2B.is_instruct_model());
assert!(Which::InstructV1_1_7B.is_instruct_model());
assert!(Which::CodeInstruct2B.is_instruct_model());
assert!(Which::CodeInstruct7B.is_instruct_model());
assert!(Which::InstructV2_2B.is_instruct_model());
assert!(Which::InstructV2_9B.is_instruct_model());
assert!(Which::InstructV3_1B.is_instruct_model());
}
#[test]
fn test_which_is_v3_model() {
// Test non-v3 models (should return false)
assert!(!Which::Base2B.is_v3_model());
assert!(!Which::Base7B.is_v3_model());
assert!(!Which::Instruct2B.is_v3_model());
assert!(!Which::Instruct7B.is_v3_model());
assert!(!Which::InstructV1_1_2B.is_v3_model());
assert!(!Which::InstructV1_1_7B.is_v3_model());
assert!(!Which::CodeBase2B.is_v3_model());
assert!(!Which::CodeBase7B.is_v3_model());
assert!(!Which::CodeInstruct2B.is_v3_model());
assert!(!Which::CodeInstruct7B.is_v3_model());
assert!(!Which::BaseV2_2B.is_v3_model());
assert!(!Which::InstructV2_2B.is_v3_model());
assert!(!Which::BaseV2_9B.is_v3_model());
assert!(!Which::InstructV2_9B.is_v3_model());
// Test v3 models (should return true)
assert!(Which::BaseV3_1B.is_v3_model());
assert!(Which::InstructV3_1B.is_v3_model());
}
// Note: Testing the Model enum's forward method would require creating actual model instances,
// which is complex and would require loading model weights. This is better suited for
// integration tests or mocking the models.
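// A minimal sketch of the mocking approach suggested above (the trait and mock
// are hypothetical, not part of the crate): hide the forward pass behind a
// trait so a test can substitute a deterministic fake for a real Model.
trait ForwardPass {
fn forward_logits(&mut self, input_ids: &[u32]) -> Vec<f32>;
}
struct MockModel {
vocab_size: usize,
}
impl ForwardPass for MockModel {
fn forward_logits(&mut self, _input_ids: &[u32]) -> Vec<f32> {
vec![0.0; self.vocab_size] // uniform logits; no weights needed
}
}
#[test]
fn test_mock_forward_logits_shape() {
let mut mock = MockModel { vocab_size: 8 };
assert_eq!(mock.forward_logits(&[1, 2, 3]).len(), 8);
}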
}

View File

@@ -0,0 +1,104 @@
use local_inference_engine::model::Which;
use local_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use anyhow::Result;
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Result<Tokenizer> {
// Create a simple tokenizer from the pretrained model
// This uses the tokenizer from the Hugging Face hub
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
Ok(tokenizer)
}
// Test the Which enum's to_model_id method
#[test]
fn test_which_model_id() {
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
}
// Test the Which enum's is_instruct_model method
#[test]
fn test_which_is_instruct() {
assert!(!Which::Base2B.is_instruct_model());
assert!(Which::Instruct7B.is_instruct_model());
}
// Test the Which enum's is_v3_model method
#[test]
fn test_which_is_v3() {
assert!(!Which::Base2B.is_v3_model());
assert!(Which::BaseV3_1B.is_v3_model());
}
// Test the TokenOutputStream functionality
#[test]
fn test_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Test encoding and decoding
let text = "Hello, world!";
let encoded = token_stream.tokenizer().encode(text, true).unwrap();
let token_ids = encoded.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all and check
let decoded = token_stream.decode_all()?;
assert_eq!(decoded.trim(), text);
Ok(())
}
// Test the LogitsProcessor
#[test]
fn test_logits_processor() -> Result<()> {
// Create a LogitsProcessor with a fixed seed so sampling is reproducible.
let seed = 42;
let temp = Some(0.8);
let top_p = Some(0.9);
let mut logits_processor = LogitsProcessor::new(seed, temp, top_p);
// Sample from a small, hand-built logits tensor and check that the sampled
// token id is a valid index into it.
let logits = Tensor::new(&[0.1f32, 0.2, 5.0, 0.3], &Device::Cpu)?;
let token = logits_processor.sample(&logits)?;
assert!(token < 4);
Ok(())
}
// Test the TextGeneration constructor
#[test]
fn test_text_generation_constructor() -> Result<()> {
// We can't easily create a Model instance for testing,
// but we can test that the constructor compiles and the types are correct
// In a real test with a mock Model, we would:
// 1. Create a mock model
// 2. Create a tokenizer
// 3. Call TextGeneration::new
// 4. Verify the properties of the created instance
// For now, this test passes if the module compiles; there is nothing to assert at runtime.
Ok(())
}
// Note: Testing the actual text generation functionality would require
// integration tests with real models, which is beyond the scope of these unit tests.
// The tests above focus on the components that can be tested in isolation.
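// A hedged sketch of one such gated check (a hypothetical test, not part of this
// commit): it needs network access to fetch the tokenizer, so it is #[ignore]d
// by default and runs only via `cargo test -- --ignored`.
#[test]
#[ignore]
fn test_tokenizer_roundtrip_ignored() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let encoded = tokenizer.encode("The quick brown fox", true).unwrap();
assert!(!encoded.get_ids().is_empty());
Ok(())
}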
}

View File

@@ -0,0 +1,129 @@
use local_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use anyhow::Result;
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Result<Tokenizer> {
// Create a simple tokenizer from the pretrained model
// This uses the tokenizer from the Hugging Face hub
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
Ok(tokenizer)
}
#[test]
fn test_new_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Check that the token stream was created successfully
assert!(!token_stream.tokenizer().get_vocab(true).is_empty());
Ok(())
}
#[test]
fn test_clear() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Add a token
let token_id = token_stream.get_token("<eos>").unwrap();
token_stream.next_token(token_id)?;
// Clear the stream
token_stream.clear();
// Check that the stream is empty by trying to decode all
let decoded = token_stream.decode_all()?;
assert_eq!(decoded, "");
Ok(())
}
#[test]
fn test_get_token() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get a token that should exist
let eos_token = token_stream.get_token("<eos>");
assert!(eos_token.is_some());
// Get a token that shouldn't exist
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
assert!(nonexistent_token.is_none());
Ok(())
}
#[test]
fn test_next_token_and_decode() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
let mut output = String::new();
for &token_id in token_ids {
if let Some(text) = token_stream.next_token(token_id)? {
output.push_str(&text);
}
}
// Get any remaining text
if let Some(rest) = token_stream.decode_rest()? {
output.push_str(&rest);
}
// Check the output
assert!(!output.is_empty());
assert_eq!(output.trim(), "Hello world");
Ok(())
}
#[test]
fn test_decode_all() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all
let decoded = token_stream.decode_all()?;
// Check the output
assert_eq!(decoded.trim(), "Hello world");
Ok(())
}
#[test]
fn test_into_inner() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get the inner tokenizer
let inner_tokenizer = token_stream.into_inner();
// Check that the inner tokenizer works
let encoded = inner_tokenizer.encode("Test", true).unwrap();
assert!(!encoded.get_ids().is_empty());
Ok(())
}
}