Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
@@ -1,7 +1,7 @@
-use inference_engine::token_output_stream::TokenOutputStream;
-use tokenizers::Tokenizer;
-use std::path::PathBuf;
 use anyhow::Result;
+use inference_engine::token_output_stream::TokenOutputStream;
+use std::path::PathBuf;
+use tokenizers::Tokenizer;
 
 #[cfg(test)]
 mod tests {
@@ -19,7 +19,7 @@ mod tests {
     fn test_new_token_output_stream() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
         let token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Check that the token stream was created successfully
         assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
         Ok(())
@@ -29,18 +29,18 @@ mod tests {
     fn test_clear() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
         let mut token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Add a token
         let token_id = token_stream.get_token("<eos>").unwrap();
         token_stream.next_token(token_id)?;
-
+
         // Clear the stream
         token_stream.clear();
-
+
         // Check that the stream is empty by trying to decode all
         let decoded = token_stream.decode_all()?;
         assert_eq!(decoded, "");
-
+
         Ok(())
     }
 
@@ -48,15 +48,15 @@ mod tests {
     fn test_get_token() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
         let token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Get a token that should exist
         let eos_token = token_stream.get_token("<eos>");
         assert!(eos_token.is_some());
-
+
         // Get a token that shouldn't exist
         let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
         assert!(nonexistent_token.is_none());
-
+
         Ok(())
     }
 
@@ -64,11 +64,14 @@ mod tests {
     fn test_next_token_and_decode() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
         let mut token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Get some tokens
-        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
+        let hello_tokens = token_stream
+            .tokenizer()
+            .encode("Hello world", true)
+            .unwrap();
         let token_ids = hello_tokens.get_ids();
-
+
         // Add tokens one by one
         let mut output = String::new();
         for &token_id in token_ids {
@@ -76,16 +79,16 @@ mod tests {
                 output.push_str(&text);
             }
         }
-
+
         // Get any remaining text
         if let Some(rest) = token_stream.decode_rest()? {
             output.push_str(&rest);
         }
-
+
         // Check the output
         assert!(!output.is_empty());
         assert_eq!(output.trim(), "Hello world");
-
+
         Ok(())
     }
 
@@ -93,22 +96,25 @@ mod tests {
     fn test_decode_all() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
         let mut token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Get some tokens
-        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
+        let hello_tokens = token_stream
+            .tokenizer()
+            .encode("Hello world", true)
+            .unwrap();
         let token_ids = hello_tokens.get_ids();
-
+
         // Add tokens one by one
         for &token_id in token_ids {
             token_stream.next_token(token_id)?;
         }
-
+
         // Decode all
         let decoded = token_stream.decode_all()?;
-
+
         // Check the output
         assert_eq!(decoded.trim(), "Hello world");
-
+
         Ok(())
     }
 
@@ -116,14 +122,14 @@ mod tests {
     fn test_into_inner() -> Result<()> {
         let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);
-
+
         // Get the inner tokenizer
         let inner_tokenizer = token_stream.into_inner();
-
+
         // Check that the inner tokenizer works
         let encoded = inner_tokenizer.encode("Test", true).unwrap();
         assert!(encoded.get_ids().len() > 0);
-
+
         Ok(())
     }
 }
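For context, the decode loop these tests exercise follows the pattern below. This is a minimal sketch, not code from the commit: the token ids are assumed to come from some generation step, and next_token is assumed to return Result<Option<String>>, as the usage in test_next_token_and_decode and test_decode_all suggests.

    use anyhow::Result;
    use inference_engine::token_output_stream::TokenOutputStream;
    use tokenizers::Tokenizer;

    // Sketch: incrementally decode generated token ids into text.
    fn stream_decode(tokenizer: Tokenizer, generated_ids: &[u32]) -> Result<String> {
        let mut stream = TokenOutputStream::new(tokenizer);
        let mut output = String::new();

        // Feed ids one at a time; next_token yields a text fragment once the
        // buffered ids decode to something printable (assumed Option<String>).
        for &id in generated_ids {
            if let Some(text) = stream.next_token(id)? {
                output.push_str(&text);
            }
        }

        // Flush anything still buffered after the last token.
        if let Some(rest) = stream.decode_rest()? {
            output.push_str(&rest);
        }

        Ok(output)
    }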