Files
open-web-agent-rs/crates/local_inference_engine/tests/text_generation_tests.rs
2025-06-05 22:42:17 -04:00

105 lines
3.4 KiB
Rust

use local_inference_engine::text_generation::TextGeneration;
use local_inference_engine::model::{Model, Which};
use local_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use anyhow::Result;
use std::sync::Arc;
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Result<Tokenizer> {
// Create a simple tokenizer from the pretrained model
// This uses the tokenizer from the Hugging Face hub
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
Ok(tokenizer)
}
// Test the Which enum's to_model_id method
#[test]
fn test_which_model_id() {
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
}
// Test the Which enum's is_instruct_model method
#[test]
fn test_which_is_instruct() {
assert!(!Which::Base2B.is_instruct_model());
assert!(Which::Instruct7B.is_instruct_model());
}
// Test the Which enum's is_v3_model method
#[test]
fn test_which_is_v3() {
assert!(!Which::Base2B.is_v3_model());
assert!(Which::BaseV3_1B.is_v3_model());
}
// Test the TokenOutputStream functionality
#[test]
fn test_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Test encoding and decoding
let text = "Hello, world!";
let encoded = token_stream.tokenizer().encode(text, true).unwrap();
let token_ids = encoded.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all and check
let decoded = token_stream.decode_all()?;
assert_eq!(decoded.trim(), text);
Ok(())
}
// Test the LogitsProcessor
#[test]
fn test_logits_processor() -> Result<()> {
// Create a LogitsProcessor with default settings
let seed = 42;
let temp = Some(0.8);
let top_p = Some(0.9);
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
// Create a simple logits tensor
// In a real test, we would create a tensor with known values and verify
// that sampling produces expected results
// For now, we'll just verify that the LogitsProcessor can be created
assert!(true);
Ok(())
}
// Test the TextGeneration constructor
#[test]
fn test_text_generation_constructor() -> Result<()> {
// We can't easily create a Model instance for testing,
// but we can test that the constructor compiles and the types are correct
// In a real test with a mock Model, we would:
// 1. Create a mock model
// 2. Create a tokenizer
// 3. Call TextGeneration::new
// 4. Verify the properties of the created instance
// For now, we'll just verify that the code compiles
assert!(true);
Ok(())
}
// Note: Testing the actual text generation functionality would require
// integration tests with real models, which is beyond the scope of these unit tests.
// The tests above focus on the components that can be tested in isolation.
}