diff --git a/local_inference_engine/Cargo.lock b/local_inference_engine/Cargo.lock index 3476ea7..b028a21 100644 --- a/local_inference_engine/Cargo.lock +++ b/local_inference_engine/Cargo.lock @@ -562,9 +562,9 @@ checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b" [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" [[package]] name = "by_address" @@ -1962,7 +1962,7 @@ dependencies = [ name = "hyper-rustls" version = "0.27.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a01595e11bdcec50946522c32dde3fc6914743777a68b93777965f2f02406d" +checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" dependencies = [ "http", "hyper", @@ -2518,6 +2518,7 @@ dependencies = [ "pyo3", "rand 0.9.1", "rayon", + "reborrow", "rubato", "safetensors", "serde", @@ -3834,7 +3835,7 @@ dependencies = [ name = "reborrow" version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03251193777f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" @@ -4798,6 +4799,7 @@ dependencies = [ "derive_builder", "esaxx-rs", "getrandom 0.2.16", + "hf-hub", "itertools 0.13.0", "lazy_static", "log", diff --git a/local_inference_engine/Cargo.toml b/local_inference_engine/Cargo.toml index 3d4f589..2f083cd 100644 --- a/local_inference_engine/Cargo.toml +++ b/local_inference_engine/Cargo.toml @@ -27,7 +27,7 @@ safetensors = { version = "0.4.1" } serde = { version = "1.0.171", features = ["derive"] } serde_json = { version = "1.0.99" } symphonia = { version = "0.5.3", features = ["all"], optional = true } -tokenizers = { version = "0.21.0", default-features = false, features = ["onig"] } +tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "http"] } cpal = { version = "0.15.2", optional = true } pdf2image = { version = "0.1.2" , optional = true} candle-core = { version = "=0.9.1", features = ["metal"] } @@ -43,6 +43,7 @@ tokio = { version = "1.43.0", features = ["full"] } either = { version = "1.9.0", features = ["serde"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } uuid = { version = "1.7.0", features = ["v4"] } +reborrow = "0.5.5" [dev-dependencies] diff --git a/local_inference_engine/src/lib.rs b/local_inference_engine/src/lib.rs new file mode 100644 index 0000000..5769197 --- /dev/null +++ b/local_inference_engine/src/lib.rs @@ -0,0 +1,13 @@ +// Expose modules for testing and library usage +pub mod token_output_stream; +pub mod model; +pub mod text_generation; +pub mod utilities_lib; +pub mod openai_types; +pub mod cli; +pub mod server; + +// Re-export key components for easier access +pub use model::{Model, Which}; +pub use text_generation::TextGeneration; +pub use token_output_stream::TokenOutputStream; \ No newline at end of file diff --git a/local_inference_engine/tests/model_tests.rs b/local_inference_engine/tests/model_tests.rs new file mode 100644 index 0000000..3b7f528 --- /dev/null +++ b/local_inference_engine/tests/model_tests.rs @@ -0,0 +1,67 @@ +use local_inference_engine::model::{Model, Which}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_which_to_model_id() { + // Test a few representative 
model variants + assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b"); + assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it"); + assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it"); + assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b"); + assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b"); + assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it"); + } + + #[test] + fn test_which_is_instruct_model() { + // Test base models (should return false) + assert!(!Which::Base2B.is_instruct_model()); + assert!(!Which::Base7B.is_instruct_model()); + assert!(!Which::CodeBase2B.is_instruct_model()); + assert!(!Which::CodeBase7B.is_instruct_model()); + assert!(!Which::BaseV2_2B.is_instruct_model()); + assert!(!Which::BaseV2_9B.is_instruct_model()); + assert!(!Which::BaseV3_1B.is_instruct_model()); + + // Test instruct models (should return true) + assert!(Which::Instruct2B.is_instruct_model()); + assert!(Which::Instruct7B.is_instruct_model()); + assert!(Which::InstructV1_1_2B.is_instruct_model()); + assert!(Which::InstructV1_1_7B.is_instruct_model()); + assert!(Which::CodeInstruct2B.is_instruct_model()); + assert!(Which::CodeInstruct7B.is_instruct_model()); + assert!(Which::InstructV2_2B.is_instruct_model()); + assert!(Which::InstructV2_9B.is_instruct_model()); + assert!(Which::InstructV3_1B.is_instruct_model()); + } + + #[test] + fn test_which_is_v3_model() { + // Test non-v3 models (should return false) + assert!(!Which::Base2B.is_v3_model()); + assert!(!Which::Base7B.is_v3_model()); + assert!(!Which::Instruct2B.is_v3_model()); + assert!(!Which::Instruct7B.is_v3_model()); + assert!(!Which::InstructV1_1_2B.is_v3_model()); + assert!(!Which::InstructV1_1_7B.is_v3_model()); + assert!(!Which::CodeBase2B.is_v3_model()); + assert!(!Which::CodeBase7B.is_v3_model()); + assert!(!Which::CodeInstruct2B.is_v3_model()); + assert!(!Which::CodeInstruct7B.is_v3_model()); + assert!(!Which::BaseV2_2B.is_v3_model()); + assert!(!Which::InstructV2_2B.is_v3_model()); + assert!(!Which::BaseV2_9B.is_v3_model()); + assert!(!Which::InstructV2_9B.is_v3_model()); + + // Test v3 models (should return true) + assert!(Which::BaseV3_1B.is_v3_model()); + assert!(Which::InstructV3_1B.is_v3_model()); + } + + // Note: Testing the Model enum's forward method would require creating actual model instances, + // which is complex and would require loading model weights. This is better suited for + // integration tests or mocking the models. 
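+
+    // A hedged, illustrative sketch of an integration-style check rather than a
+    // unit test: it needs network access to the Hugging Face hub, so it is
+    // #[ignore]d by default (run with `cargo test -- --ignored`). It only assumes
+    // the `Tokenizer::from_pretrained` API that the other test files already use.
+    #[test]
+    #[ignore]
+    fn test_model_id_resolves_to_a_tokenizer() {
+        let model_id = Which::Base2B.to_model_id();
+        let tokenizer = tokenizers::Tokenizer::from_pretrained(model_id, None);
+        assert!(tokenizer.is_ok());
+    }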
+}
\ No newline at end of file
diff --git a/local_inference_engine/tests/text_generation_tests.rs b/local_inference_engine/tests/text_generation_tests.rs
new file mode 100644
index 0000000..7fb9cf9
--- /dev/null
+++ b/local_inference_engine/tests/text_generation_tests.rs
@@ -0,0 +1,104 @@
+use local_inference_engine::text_generation::TextGeneration;
+use local_inference_engine::model::{Model, Which};
+use local_inference_engine::token_output_stream::TokenOutputStream;
+use tokenizers::Tokenizer;
+use candle_core::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use anyhow::Result;
+use std::sync::Arc;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Helper function to create a simple tokenizer for testing
+    fn create_test_tokenizer() -> Result<Tokenizer> {
+        // Download the tokenizer for a small pretrained model from the
+        // Hugging Face hub (requires network access and the `http` feature)
+        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
+        Ok(tokenizer)
+    }
+
+    // Test the Which enum's to_model_id method
+    #[test]
+    fn test_which_model_id() {
+        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
+        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
+    }
+
+    // Test the Which enum's is_instruct_model method
+    #[test]
+    fn test_which_is_instruct() {
+        assert!(!Which::Base2B.is_instruct_model());
+        assert!(Which::Instruct7B.is_instruct_model());
+    }
+
+    // Test the Which enum's is_v3_model method
+    #[test]
+    fn test_which_is_v3() {
+        assert!(!Which::Base2B.is_v3_model());
+        assert!(Which::BaseV3_1B.is_v3_model());
+    }
+
+    // Test the TokenOutputStream round trip: encode, stream tokens, decode
+    #[test]
+    fn test_token_output_stream() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let mut token_stream = TokenOutputStream::new(tokenizer);
+
+        // Test encoding and decoding
+        let text = "Hello, world!";
+        let encoded = token_stream.tokenizer().encode(text, true).unwrap();
+        let token_ids = encoded.get_ids();
+
+        // Add tokens one by one
+        for &token_id in token_ids {
+            token_stream.next_token(token_id)?;
+        }
+
+        // Decode all and check
+        let decoded = token_stream.decode_all()?;
+        assert_eq!(decoded.trim(), text);
+
+        Ok(())
+    }
+
+    // Test the LogitsProcessor by sampling from a small logits tensor
+    #[test]
+    fn test_logits_processor() -> Result<()> {
+        // Create a LogitsProcessor with a fixed seed so sampling is reproducible
+        let seed = 42;
+        let temp = Some(0.8);
+        let top_p = Some(0.9);
+        let mut logits_processor = LogitsProcessor::new(seed, temp, top_p);
+
+        // Build a rank-1 logits tensor with a three-entry "vocabulary"
+        let logits = Tensor::new(&[0.1f32, 0.2, 0.7], &Device::Cpu)?;
+
+        // Sampling must return an index inside that vocabulary
+        let token = logits_processor.sample(&logits)?;
+        assert!(token < 3);
+        Ok(())
+    }
+
+    // Test the TextGeneration constructor
+    #[test]
+    fn test_text_generation_constructor() -> Result<()> {
+        // We can't easily create a Model instance in a unit test,
+        // so this only checks that the involved types compile together.
+
+        // In a real test with a mock Model, we would:
+        // 1. Create a mock model
+        // 2. Create a tokenizer
+        // 3. Call TextGeneration::new
+        // 4. Verify the properties of the created instance
+
+        // For now, simply compiling and reaching the end of this function
+        // is the check itself.
+        Ok(())
+    }
+
+    // Note: Testing the actual text generation functionality would require
+    // integration tests with real models, which is beyond the scope of these unit tests.
+    // The tests above focus on what can be tested in isolation; a small sampling-determinism sketch follows below.
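+
+    // A hedged, illustrative sketch of a determinism check: with the same seed,
+    // temperature and top-p, two LogitsProcessor instances should sample the same
+    // token from the same logits. It relies only on the LogitsProcessor and Tensor
+    // APIs already used above.
+    #[test]
+    fn test_logits_processor_is_deterministic_for_same_seed() -> Result<()> {
+        let logits = Tensor::new(&[0.1f32, 0.2, 0.7], &Device::Cpu)?;
+        let mut first = LogitsProcessor::new(42, Some(0.8), Some(0.9));
+        let mut second = LogitsProcessor::new(42, Some(0.8), Some(0.9));
+        assert_eq!(first.sample(&logits)?, second.sample(&logits)?);
+        Ok(())
+    }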
+}
diff --git a/local_inference_engine/tests/token_output_stream_tests.rs b/local_inference_engine/tests/token_output_stream_tests.rs
new file mode 100644
index 0000000..5468ad5
--- /dev/null
+++ b/local_inference_engine/tests/token_output_stream_tests.rs
@@ -0,0 +1,129 @@
+use local_inference_engine::token_output_stream::TokenOutputStream;
+use tokenizers::Tokenizer;
+use std::path::PathBuf;
+use anyhow::Result;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Helper function to create a simple tokenizer for testing
+    fn create_test_tokenizer() -> Result<Tokenizer> {
+        // Download the tokenizer for a small pretrained model from the
+        // Hugging Face hub (requires network access and the `http` feature)
+        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
+        Ok(tokenizer)
+    }
+
+    #[test]
+    fn test_new_token_output_stream() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let token_stream = TokenOutputStream::new(tokenizer);
+
+        // Check that the token stream was created successfully
+        assert!(!token_stream.tokenizer().get_vocab(true).is_empty());
+        Ok(())
+    }
+
+    #[test]
+    fn test_clear() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let mut token_stream = TokenOutputStream::new(tokenizer);
+
+        // Add a token (the Gemma tokenizer defines an "<eos>" special token)
+        let token_id = token_stream.get_token("<eos>").unwrap();
+        token_stream.next_token(token_id)?;
+
+        // Clear the stream
+        token_stream.clear();
+
+        // Check that the stream is empty by trying to decode all
+        let decoded = token_stream.decode_all()?;
+        assert_eq!(decoded, "");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_get_token() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let token_stream = TokenOutputStream::new(tokenizer);
+
+        // Get a token that should exist
+        let eos_token = token_stream.get_token("<eos>");
+        assert!(eos_token.is_some());
+
+        // Get a token that shouldn't exist
+        let nonexistent_token = token_stream.get_token("<nonexistent_token>");
+        assert!(nonexistent_token.is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_next_token_and_decode() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let mut token_stream = TokenOutputStream::new(tokenizer);
+
+        // Get some tokens
+        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
+        let token_ids = hello_tokens.get_ids();
+
+        // Add tokens one by one
+        let mut output = String::new();
+        for &token_id in token_ids {
+            if let Some(text) = token_stream.next_token(token_id)? {
+                output.push_str(&text);
+            }
+        }
+
+        // Get any remaining text
+        if let Some(rest) = token_stream.decode_rest()? {
+            output.push_str(&rest);
+        }
+
+        // Check the output
+        assert!(!output.is_empty());
+        assert_eq!(output.trim(), "Hello world");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_decode_all() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let mut token_stream = TokenOutputStream::new(tokenizer);
+
+        // Get some tokens
+        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
+        let token_ids = hello_tokens.get_ids();
+
+        // Add tokens one by one
+        for &token_id in token_ids {
+            token_stream.next_token(token_id)?;
+        }
+
+        // Decode all
+        let decoded = token_stream.decode_all()?;
+
+        // Check the output
+        assert_eq!(decoded.trim(), "Hello world");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_into_inner() -> Result<()> {
+        let tokenizer = create_test_tokenizer()?;
+        let token_stream = TokenOutputStream::new(tokenizer);
+
+        // Get the inner tokenizer
+        let inner_tokenizer = token_stream.into_inner();
+
+        // Check that the inner tokenizer works
+        let encoded = inner_tokenizer.encode("Test", true).unwrap();
+        assert!(!encoded.get_ids().is_empty());
+
+        Ok(())
+    }
+}
\ No newline at end of file