update structure to improve testability

This commit is contained in:
geoffsee
2025-06-05 22:03:04 -04:00
parent 7f641559ae
commit d483a53376
13 changed files with 750 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
OPENAI_API_KEY="your-key-goes-here" OPENAI_API_KEY="your-key-goes-here"
OPENAI_API_BASE="http://localhost:3000/v1" OPENAI_API_BASE="http://localhost:3777/v1"
GENAISCRIPT_MODEL_LARGE="gemma-3-1b-it" GENAISCRIPT_MODEL_LARGE="gemma-3-1b-it"
GENAISCRIPT_MODEL_SMALL="gemma-3-1b-it" GENAISCRIPT_MODEL_SMALL="gemma-3-1b-it"
SEARXNG_API_BASE_URL="http://localhost:8080" SEARXNG_API_BASE_URL="http://localhost:8080"

View File

@@ -1962,7 +1962,7 @@ dependencies = [
name = "hyper-rustls" name = "hyper-rustls"
version = "0.27.6" version = "0.27.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" checksum = "03a01595e11bdcec50946522c32dde3fc6914743777a68b93777965f2f02406d"
dependencies = [ dependencies = [
"http", "http",
"hyper", "hyper",
@@ -3834,7 +3834,7 @@ dependencies = [
name = "reborrow" name = "reborrow"
version = "0.5.5" version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" checksum = "03251193777f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"

View File

@@ -63,15 +63,15 @@ cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it
Run the inference engine in server mode to expose an OpenAI-compatible API: Run the inference engine in server mode to expose an OpenAI-compatible API:
```bash ```bash
cargo run --release -- --server --port 3000 --which 3-1b-it cargo run --release -- --server --port 3777 --which 3-1b-it
``` ```
This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint. This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint.
#### Server Options #### Server Options
- `--server`: Run in server mode - `--server`: Run in server mode
- `--port <INT>`: Port to use for the server (default: 3000) - `--port <INT>`: Port to use for the server (default: 3777)
- `--which <MODEL>`: Model variant to use (default: "3-1b-it") - `--which <MODEL>`: Model variant to use (default: "3-1b-it")
- Other model options as described in CLI mode - Other model options as described in CLI mode
@@ -130,7 +130,7 @@ POST /v1/chat/completions
### Example: Using cURL ### Example: Using cURL
```bash ```bash
curl -X POST http://localhost:3000/v1/chat/completions \ curl -X POST http://localhost:3777/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "gemma-3-1b-it", "model": "gemma-3-1b-it",
@@ -148,7 +148,7 @@ curl -X POST http://localhost:3000/v1/chat/completions \
from openai import OpenAI from openai import OpenAI
client = OpenAI( client = OpenAI(
base_url="http://localhost:3000/v1", base_url="http://localhost:3777/v1",
api_key="dummy" # API key is not validated but required by the client api_key="dummy" # API key is not validated but required by the client
) )
@@ -170,7 +170,7 @@ print(response.choices[0].message.content)
import OpenAI from 'openai'; import OpenAI from 'openai';
const openai = new OpenAI({ const openai = new OpenAI({
baseURL: 'http://localhost:3000/v1', baseURL: 'http://localhost:3777/v1',
apiKey: 'dummy', // API key is not validated but required by the client apiKey: 'dummy', // API key is not validated but required by the client
}); });

View File

@@ -93,7 +93,7 @@
<div class="settings"> <div class="settings">
<div> <div>
<label for="serverUrl">Server URL:</label> <label for="serverUrl">Server URL:</label>
<input type="text" id="serverUrl" value="http://localhost:3000" /> <input type="text" id="serverUrl" value="http://localhost:3777" />
</div> </div>
<div> <div>
<label for="model">Model:</label> <label for="model">Model:</label>

View File

@@ -6,7 +6,7 @@
(async function testBasicChatCompletion() { (async function testBasicChatCompletion() {
console.log("Test 1: Basic chat completion request"); console.log("Test 1: Basic chat completion request");
try { try {
const response = await fetch('http://localhost:3000/v1/chat/completions', { const response = await fetch('http://localhost:3777/v1/chat/completions', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -34,7 +34,7 @@
(async function testMultiTurnConversation() { (async function testMultiTurnConversation() {
console.log("\nTest 2: Multi-turn conversation"); console.log("\nTest 2: Multi-turn conversation");
try { try {
const response = await fetch('http://localhost:3000/v1/chat/completions', { const response = await fetch('http://localhost:3777/v1/chat/completions', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -74,7 +74,7 @@
(async function testTemperatureAndTopP() { (async function testTemperatureAndTopP() {
console.log("\nTest 3: Request with temperature and top_p parameters"); console.log("\nTest 3: Request with temperature and top_p parameters");
try { try {
const response = await fetch('http://localhost:3000/v1/chat/completions', { const response = await fetch('http://localhost:3777/v1/chat/completions', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -104,7 +104,7 @@
(async function testStreaming() { (async function testStreaming() {
console.log("\nTest 4: Request with streaming enabled"); console.log("\nTest 4: Request with streaming enabled");
try { try {
const response = await fetch('http://localhost:3000/v1/chat/completions', { const response = await fetch('http://localhost:3777/v1/chat/completions', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -148,7 +148,7 @@
(async function testDifferentModel() { (async function testDifferentModel() {
console.log("\nTest 5: Request with a different model"); console.log("\nTest 5: Request with a different model");
try { try {
const response = await fetch('http://localhost:3000/v1/chat/completions', { const response = await fetch('http://localhost:3777/v1/chat/completions', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -0,0 +1,72 @@
use clap::Parser;
use crate::model::Which;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
pub cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
pub tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
pub server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
pub port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
pub prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
pub temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
pub top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
pub seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
pub sample_len: usize,
#[arg(long)]
pub model_id: Option<String>,
#[arg(long, default_value = "main")]
pub revision: String,
#[arg(long)]
pub tokenizer_file: Option<String>,
#[arg(long)]
pub config_file: Option<String>,
#[arg(long)]
pub weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
pub repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
pub which: Which,
#[arg(long)]
pub use_flash_attn: bool,
}

View File

@@ -652,7 +652,7 @@ struct Args {
server: bool, server: bool,
/// Port to use for the server /// Port to use for the server
#[arg(long, default_value_t = 3000)] #[arg(long, default_value_t = 3777)]
port: u16, port: u16,
/// Prompt for text generation (not used in server mode) /// Prompt for text generation (not used in server mode)

View File

@@ -0,0 +1,90 @@
use candle_core::Tensor;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
pub enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
impl Which {
pub fn to_model_id(&self) -> String {
match self {
Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Self::Base2B => "google/gemma-2b".to_string(),
Self::Base7B => "google/gemma-7b".to_string(),
Self::Instruct2B => "google/gemma-2b-it".to_string(),
Self::Instruct7B => "google/gemma-7b-it".to_string(),
Self::CodeBase2B => "google/codegemma-2b".to_string(),
Self::CodeBase7B => "google/codegemma-7b".to_string(),
Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
}
}
pub fn is_instruct_model(&self) -> bool {
match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
_ => true,
}
}
pub fn is_v3_model(&self) -> bool {
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
}
}

View File

@@ -0,0 +1,167 @@
use either::Either;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use utoipa::ToSchema;
/// Inner content structure for messages that can be either a string or key-value pairs
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
)
}
}
/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
// Either::Left - simple string
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
// Either::Right - object with string values
.item(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::T(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))))
.build(),
))
.build(),
)
}
/// Message content that can be either simple text or complex structured content
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
}
}
/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
.item(Schema::Array(
ArrayBuilder::new()
.items(RefOr::T(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::Ref(
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
)))
.build(),
)))
.build(),
))
.build(),
)
}
/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
pub name: Option<String>,
}
/// Stop token configuration for generation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Default value helper
pub fn default_false() -> bool {
false
}
/// Default value helper
pub fn default_1usize() -> usize {
1
}
/// Default value helper
pub fn default_model() -> String {
"default".to_string()
}
/// Chat completion request following OpenAI's specification
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
#[schema(example = 256)]
pub max_tokens: Option<usize>,
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
#[schema(example = 0.7)]
pub temperature: Option<f64>,
#[schema(example = 0.9)]
pub top_p: Option<f64>,
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
}
/// Token usage information
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}

View File

@@ -0,0 +1,126 @@
use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
use crate::text_generation::TextGeneration;
use either::Either;
// Application state shared between handlers
#[derive(Clone)]
pub struct AppState {
pub text_generation: Arc<Mutex<TextGeneration>>,
pub model_id: String,
}
// Chat completions endpoint handler
pub async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
// Create the router with the chat completions endpoint
pub fn create_router(app_state: AppState) -> Router {
// CORS layer to allow requests from any origin
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
// OpenAI compatible endpoints
.route("/v1/chat/completions", post(chat_completions))
// Add more endpoints as needed
.layer(cors)
.with_state(app_state)
}

View File

@@ -0,0 +1,277 @@
use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
use std::io::Write;
use crate::model::Model;
use crate::token_output_stream::TokenOutputStream;
pub struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
pub fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
// Run text generation and print to stdout
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
// Run text generation and write to a buffer
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}

View File

@@ -47,7 +47,7 @@ const {text} = await host.fetchText(new URL(url).toString());
// browser: getBrowser(), // browser: getBrowser(),
// headless: true, // headless: true,
// javaScriptEnabled: browser !== "chromium", // javaScriptEnabled: browser !== "chromium",
// // timeout: 3000, // // timeout: 3777,
// // bypassCSP: true, // // bypassCSP: true,
// // baseUrl: new URL(url).origin, // // baseUrl: new URL(url).origin,
// }); // });

View File

@@ -145,7 +145,7 @@ ui:
# Note: since commit af77ec3, morty accepts a base64 encoded key. # Note: since commit af77ec3, morty accepts a base64 encoded key.
# #
# result_proxy: # result_proxy:
# url: http://127.0.0.1:3000/ # url: http://127.0.0.1:3777/
# # the key is a base64 encoded string, the YAML !!binary prefix is optional # # the key is a base64 encoded string, the YAML !!binary prefix is optional
# key: !!binary "your_morty_proxy_key" # key: !!binary "your_morty_proxy_key"
# # [true|false] enable the "proxy" button next to each result # # [true|false] enable the "proxy" button next to each result