diff --git a/.env.example b/.env.example index 6542a91..44ba56c 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ OPENAI_API_KEY="your-key-goes-here" -OPENAI_API_BASE="http://localhost:3000/v1" +OPENAI_API_BASE="http://localhost:3777/v1" GENAISCRIPT_MODEL_LARGE="gemma-3-1b-it" GENAISCRIPT_MODEL_SMALL="gemma-3-1b-it" SEARXNG_API_BASE_URL="http://localhost:8080" diff --git a/local_inference_engine/Cargo.lock b/local_inference_engine/Cargo.lock index 5da43a1..3476ea7 100644 --- a/local_inference_engine/Cargo.lock +++ b/local_inference_engine/Cargo.lock @@ -1962,7 +1962,7 @@ dependencies = [ name = "hyper-rustls" version = "0.27.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" +checksum = "03a01595e11bdcec50946522c32dde3fc6914743777a68b93777965f2f02406d" dependencies = [ "http", "hyper", @@ -3834,7 +3834,7 @@ dependencies = [ name = "reborrow" version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +checksum = "03251193777f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" diff --git a/local_inference_engine/README.md b/local_inference_engine/README.md index edfad37..63b4cde 100644 --- a/local_inference_engine/README.md +++ b/local_inference_engine/README.md @@ -63,15 +63,15 @@ cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it Run the inference engine in server mode to expose an OpenAI-compatible API: ```bash -cargo run --release -- --server --port 3000 --which 3-1b-it +cargo run --release -- --server --port 3777 --which 3-1b-it ``` -This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint. +This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint. #### Server Options - `--server`: Run in server mode -- `--port `: Port to use for the server (default: 3000) +- `--port `: Port to use for the server (default: 3777) - `--which `: Model variant to use (default: "3-1b-it") - Other model options as described in CLI mode @@ -130,7 +130,7 @@ POST /v1/chat/completions ### Example: Using cURL ```bash -curl -X POST http://localhost:3000/v1/chat/completions \ +curl -X POST http://localhost:3777/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "gemma-3-1b-it", @@ -148,7 +148,7 @@ curl -X POST http://localhost:3000/v1/chat/completions \ from openai import OpenAI client = OpenAI( - base_url="http://localhost:3000/v1", + base_url="http://localhost:3777/v1", api_key="dummy" # API key is not validated but required by the client ) @@ -170,7 +170,7 @@ print(response.choices[0].message.content) import OpenAI from 'openai'; const openai = new OpenAI({ - baseURL: 'http://localhost:3000/v1', + baseURL: 'http://localhost:3777/v1', apiKey: 'dummy', // API key is not validated but required by the client }); diff --git a/local_inference_engine/api_test.html b/local_inference_engine/api_test.html index 2f50e95..f4654a5 100644 --- a/local_inference_engine/api_test.html +++ b/local_inference_engine/api_test.html @@ -93,7 +93,7 @@
-            [input markup not recovered; endpoint URL pointed at port 3000]
+            [input markup not recovered; endpoint URL now points at port 3777]
diff --git a/local_inference_engine/openai_api_test.js b/local_inference_engine/openai_api_test.js index 465e312..016dabc 100644 --- a/local_inference_engine/openai_api_test.js +++ b/local_inference_engine/openai_api_test.js @@ -6,7 +6,7 @@ (async function testBasicChatCompletion() { console.log("Test 1: Basic chat completion request"); try { - const response = await fetch('http://localhost:3000/v1/chat/completions', { + const response = await fetch('http://localhost:3777/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -34,7 +34,7 @@ (async function testMultiTurnConversation() { console.log("\nTest 2: Multi-turn conversation"); try { - const response = await fetch('http://localhost:3000/v1/chat/completions', { + const response = await fetch('http://localhost:3777/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -74,7 +74,7 @@ (async function testTemperatureAndTopP() { console.log("\nTest 3: Request with temperature and top_p parameters"); try { - const response = await fetch('http://localhost:3000/v1/chat/completions', { + const response = await fetch('http://localhost:3777/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -104,7 +104,7 @@ (async function testStreaming() { console.log("\nTest 4: Request with streaming enabled"); try { - const response = await fetch('http://localhost:3000/v1/chat/completions', { + const response = await fetch('http://localhost:3777/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -148,7 +148,7 @@ (async function testDifferentModel() { console.log("\nTest 5: Request with a different model"); try { - const response = await fetch('http://localhost:3000/v1/chat/completions', { + const response = await fetch('http://localhost:3777/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/local_inference_engine/src/cli.rs b/local_inference_engine/src/cli.rs new file mode 100644 index 0000000..2758bc3 --- /dev/null +++ b/local_inference_engine/src/cli.rs @@ -0,0 +1,72 @@ +use clap::Parser; +use crate::model::Which; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + pub cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + pub tracing: bool, + + /// Run in server mode with OpenAI compatible API + #[arg(long)] + pub server: bool, + + /// Port to use for the server + #[arg(long, default_value_t = 3777)] + pub port: u16, + + /// Prompt for text generation (not used in server mode) + #[arg(long)] + pub prompt: Option, + + /// The temperature used to generate samples. + #[arg(long)] + pub temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + pub top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + pub seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 10000)] + pub sample_len: usize, + + #[arg(long)] + pub model_id: Option, + + #[arg(long, default_value = "main")] + pub revision: String, + + #[arg(long)] + pub tokenizer_file: Option, + + #[arg(long)] + pub config_file: Option, + + #[arg(long)] + pub weight_files: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. 
+ #[arg(long, default_value_t = 1.1)] + pub repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + pub repeat_last_n: usize, + + /// The model to use. + #[arg(long, default_value = "3-1b-it")] + pub which: Which, + + #[arg(long)] + pub use_flash_attn: bool, +} \ No newline at end of file diff --git a/local_inference_engine/src/main.rs b/local_inference_engine/src/main.rs index f29ff6d..bfbdb03 100644 --- a/local_inference_engine/src/main.rs +++ b/local_inference_engine/src/main.rs @@ -652,7 +652,7 @@ struct Args { server: bool, /// Port to use for the server - #[arg(long, default_value_t = 3000)] + #[arg(long, default_value_t = 3777)] port: u16, /// Prompt for text generation (not used in server mode) diff --git a/local_inference_engine/src/model.rs b/local_inference_engine/src/model.rs new file mode 100644 index 0000000..7a7944b --- /dev/null +++ b/local_inference_engine/src/model.rs @@ -0,0 +1,90 @@ +use candle_core::Tensor; +use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; +use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; +use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +pub enum Which { + #[value(name = "2b")] + Base2B, + #[value(name = "7b")] + Base7B, + #[value(name = "2b-it")] + Instruct2B, + #[value(name = "7b-it")] + Instruct7B, + #[value(name = "1.1-2b-it")] + InstructV1_1_2B, + #[value(name = "1.1-7b-it")] + InstructV1_1_7B, + #[value(name = "code-2b")] + CodeBase2B, + #[value(name = "code-7b")] + CodeBase7B, + #[value(name = "code-2b-it")] + CodeInstruct2B, + #[value(name = "code-7b-it")] + CodeInstruct7B, + #[value(name = "2-2b")] + BaseV2_2B, + #[value(name = "2-2b-it")] + InstructV2_2B, + #[value(name = "2-9b")] + BaseV2_9B, + #[value(name = "2-9b-it")] + InstructV2_9B, + #[value(name = "3-1b")] + BaseV3_1B, + #[value(name = "3-1b-it")] + InstructV3_1B, +} + +pub enum Model { + V1(Model1), + V2(Model2), + V3(Model3), +} + +impl Model { + pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result { + match self { + Self::V1(m) => m.forward(input_ids, pos), + Self::V2(m) => m.forward(input_ids, pos), + Self::V3(m) => m.forward(input_ids, pos), + } + } +} + +impl Which { + pub fn to_model_id(&self) -> String { + match self { + Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(), + Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(), + Self::Base2B => "google/gemma-2b".to_string(), + Self::Base7B => "google/gemma-7b".to_string(), + Self::Instruct2B => "google/gemma-2b-it".to_string(), + Self::Instruct7B => "google/gemma-7b-it".to_string(), + Self::CodeBase2B => "google/codegemma-2b".to_string(), + Self::CodeBase7B => "google/codegemma-7b".to_string(), + Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(), + Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(), + Self::BaseV2_2B => "google/gemma-2-2b".to_string(), + Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(), + Self::BaseV2_9B => "google/gemma-2-9b".to_string(), + Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(), + Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(), + Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(), + } + } + + pub fn is_instruct_model(&self) -> bool { + match self { + Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => 
false,
+            _ => true,
+        }
+    }
+
+    pub fn is_v3_model(&self) -> bool {
+        matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
+    }
+}
\ No newline at end of file
diff --git a/local_inference_engine/src/openai_types.rs b/local_inference_engine/src/openai_types.rs
new file mode 100644
index 0000000..d9a3689
--- /dev/null
+++ b/local_inference_engine/src/openai_types.rs
@@ -0,0 +1,167 @@
+use either::Either;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use utoipa::ToSchema;
+
+/// Inner content structure for messages that can be either a string or key-value pairs
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageInnerContent(
+    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
+);
+
+impl ToSchema<'_> for MessageInnerContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+        (
+            "MessageInnerContent",
+            utoipa::openapi::RefOr::T(message_inner_content_schema()),
+        )
+    }
+}
+
+/// Function for MessageInnerContent Schema generation to handle `Either`
+fn message_inner_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            // Either::Left - simple string
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            // Either::Right - object with string values
+            .item(Schema::Object(
+                ObjectBuilder::new()
+                    .schema_type(SchemaType::Object)
+                    .additional_properties(Some(RefOr::T(Schema::Object(
+                        ObjectBuilder::new().schema_type(SchemaType::String).build(),
+                    ))))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Message content that can be either simple text or complex structured content
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageContent(
+    #[serde(with = "either::serde_untagged")]
+    pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
+);
+
+impl ToSchema<'_> for MessageContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    }
+}
+
+/// Function for MessageContent Schema generation to handle `Either`
+fn message_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            .item(Schema::Array(
+                ArrayBuilder::new()
+                    .items(RefOr::T(Schema::Object(
+                        ObjectBuilder::new()
+                            .schema_type(SchemaType::Object)
+                            .additional_properties(Some(RefOr::Ref(
+                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
+                            )))
+                            .build(),
+                    )))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Represents a single message in a conversation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct Message {
+    /// The message content
+    pub content: Option<MessageContent>,
+    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
+    pub role: String,
+    pub name: Option<String>,
+}
+
+/// Stop token configuration for generation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(untagged)]
+pub enum StopTokens {
+    /// Multiple possible stop sequences
+    Multi(Vec<String>),
+    /// Single stop sequence
+    Single(String),
+}
+
+/// Default value helper
+pub fn default_false() -> bool {
+    false
+}
+
+/// Default value helper
+pub fn default_1usize() -> usize {
+    1
+}
+
+/// Default value helper
+pub fn default_model() -> String {
+    "default".to_string()
+}
+
+/// Chat completion request following OpenAI's specification
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct ChatCompletionRequest {
+    #[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
+    pub messages: Vec<Message>,
+    #[schema(example = "gemma-3-1b-it")]
+    #[serde(default = "default_model")]
+    pub model: String,
+    #[serde(default = "default_false")]
+    #[schema(example = false)]
+    pub logprobs: bool,
+    #[schema(example = 256)]
+    pub max_tokens: Option<usize>,
+    #[serde(rename = "n")]
+    #[serde(default = "default_1usize")]
+    #[schema(example = 1)]
+    pub n_choices: usize,
+    #[schema(example = 0.7)]
+    pub temperature: Option<f64>,
+    #[schema(example = 0.9)]
+    pub top_p: Option<f64>,
+    #[schema(example = false)]
+    pub stream: Option<bool>,
+}
+
+/// Chat completion choice
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionChoice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+/// Chat completion response
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionChoice>,
+    pub usage: Usage,
+}
+
+/// Token usage information
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
\ No newline at end of file
diff --git a/local_inference_engine/src/server.rs b/local_inference_engine/src/server.rs
new file mode 100644
index 0000000..e1ee272
--- /dev/null
+++ b/local_inference_engine/src/server.rs
@@ -0,0 +1,126 @@
+use axum::{
+    extract::State,
+    http::StatusCode,
+    routing::{get, post},
+    Json, Router,
+};
+use std::{net::SocketAddr, sync::Arc};
+use tokio::sync::Mutex;
+use tower_http::cors::{Any, CorsLayer};
+use uuid::Uuid;
+
+use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
+use crate::text_generation::TextGeneration;
+use either::Either;
+
+// Application state shared between handlers
+#[derive(Clone)]
+pub struct AppState {
+    pub text_generation: Arc<Mutex<TextGeneration>>,
+    pub model_id: String,
+}
+
+// Chat completions endpoint handler
+pub async fn chat_completions(
+    State(state): State<AppState>,
+    Json(request): Json<ChatCompletionRequest>,
+) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
+    let mut prompt = String::new();
+
+    // Convert messages to a prompt string
+    for message in &request.messages {
+        let role = &message.role;
+        let content = match &message.content {
+            Some(content) => match &content.0 {
+                Either::Left(text) => text.clone(),
+                Either::Right(_) => "".to_string(), // Handle complex content if needed
+            },
+            None => "".to_string(),
+        };
+
+        // Format based on role
+        match role.as_str() {
+            "system" => prompt.push_str(&format!("System: {}\n", content)),
+            "user" => prompt.push_str(&format!("User: {}\n", content)),
+            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
+        }
+    }
+
+    // Add the assistant prefix for the response
+    prompt.push_str("Assistant: ");
+
+    // Capture the output
+    let mut output = Vec::new();
+    {
+        let mut text_gen = state.text_generation.lock().await;
+
+        // Buffer to capture the output
+        let mut buffer = Vec::new();
+
+        // Run text generation
+        let max_tokens = request.max_tokens.unwrap_or(1000);
+        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
+
+        if let Err(e) = result {
+            return Err((
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "error": {
+                        "message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
+                        "type": "unsupported_api"
+                    }
+                })),
+            ));
+        }
+
+        // Convert buffer to string
+        if let Ok(text) = String::from_utf8(buffer) {
+            output.push(text);
+        }
+    }
+
+    // Create response
+    let response = ChatCompletionResponse {
+        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
+        object: "chat.completion".to_string(),
+        created: std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs(),
+        model: request.model,
+        choices: vec![ChatCompletionChoice {
+            index: 0,
+            message: Message {
+                role: "assistant".to_string(),
+                content: Some(MessageContent(Either::Left(output.join("")))),
+                name: None,
+            },
+            finish_reason: "stop".to_string(),
+        }],
+        usage: Usage {
+            prompt_tokens: prompt.len() / 4, // Rough estimate
+            completion_tokens: output.join("").len() / 4, // Rough estimate
+            total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
+        },
+    };
+
+    // Return the response as JSON
+    Ok(Json(response))
+}
+
+// Create the router with the chat completions endpoint
+pub fn create_router(app_state: AppState) -> Router {
+    // CORS layer to allow requests from any origin
+    let cors = CorsLayer::new()
+        .allow_origin(Any)
+        .allow_methods(Any)
+        .allow_headers(Any);
+
+    Router::new()
+        // OpenAI compatible endpoints
+        .route("/v1/chat/completions", post(chat_completions))
+        // Add more endpoints as needed
+        .layer(cors)
+        .with_state(app_state)
+}
\ No newline at end of file
diff --git a/local_inference_engine/src/text_generation.rs b/local_inference_engine/src/text_generation.rs
new file mode 100644
index 0000000..db7787a
--- /dev/null
+++ b/local_inference_engine/src/text_generation.rs
@@ -0,0 +1,277 @@
+use anyhow::{Error as E, Result};
+use candle_core::{DType, Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+use tokenizers::Tokenizer;
+use std::io::Write;
+
+use crate::model::Model;
+use crate::token_output_stream::TokenOutputStream;
+
+pub struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: TokenOutputStream,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer: TokenOutputStream::new(tokenizer),
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            device: device.clone(),
+        }
+    }
+
+    // Run text generation and print to stdout
+    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                println!(
+                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
+                );
+                eos_token
+            }
+        };
+
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos)?;
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+                // Manual implementation of repeat penalty to avoid type conflicts
+                let mut logits_vec = logits.to_vec1::<f32>()?;
+
+                for &token_id in &tokens[start_at..] {
+                    let token_id = token_id as usize;
+                    if token_id < logits_vec.len() {
+                        let score = logits_vec[token_id];
+                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
+                        logits_vec[token_id] = sign * score / self.repeat_penalty;
+                    }
+                }
+
+                // Create a new tensor with the modified logits
+                let device = logits.device().clone();
+                let shape = logits.shape().clone();
+                let new_logits = Tensor::new(&logits_vec[..], &device)?;
+                new_logits.reshape(shape)?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token || next_token == eot_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+        }
+        let dt = start_gen.elapsed();
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
+        println!(
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+
+    // Run text generation and write to a buffer
+    pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        // Write prompt tokens to output
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
+                eos_token
+            }
+        };
+
+        // Determine if we're using a Model3 (gemma-3) variant
+        let is_model3 = match &self.model {
+            Model::V3(_) => true,
+            _ => false,
+        };
+
+        // For Model3, we need to use a different approach
+        if is_model3 {
+            // For gemma-3 models, we'll generate one token at a time with the full context
+            let start_gen = std::time::Instant::now();
+
+            // Initial generation with the full prompt
+            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+            let mut logits = self.model.forward(&input, 0)?;
+            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+
+            for _ in 0..sample_len {
+                // Apply repeat penalty if needed
+                let current_logits = if self.repeat_penalty == 1. {
+                    logits.clone()
+                } else {
+                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+                    // Manual implementation of repeat penalty to avoid type conflicts
+                    let mut logits_vec = logits.to_vec1::<f32>()?;
+
+                    for &token_id in &tokens[start_at..] {
+                        let token_id = token_id as usize;
+                        if token_id < logits_vec.len() {
+                            let score = logits_vec[token_id];
+                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
+                            logits_vec[token_id] = sign * score / self.repeat_penalty;
+                        }
+                    }
+
+                    // Create a new tensor with the modified logits
+                    let device = logits.device().clone();
+                    let shape = logits.shape().clone();
+                    let new_logits = Tensor::new(&logits_vec[..], &device)?;
+                    new_logits.reshape(shape)?
+                };
+
+                let next_token = self.logits_processor.sample(&current_logits)?;
+                tokens.push(next_token);
+                generated_tokens += 1;
+
+                if next_token == eos_token || next_token == eot_token {
+                    break;
+                }
+
+                if let Some(t) = self.tokenizer.next_token(next_token)? {
+                    write!(output, "{}", t)?;
+                }
+
+                // For the next iteration, just use the new token
+                let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
+                logits = self.model.forward(&new_input, tokens.len() - 1)?;
+                logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            }
+
+            return Ok(());
+        }
+
+        // Standard approach for other models
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos)?;
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+                // Manual implementation of repeat penalty to avoid type conflicts
+                let mut logits_vec = logits.to_vec1::<f32>()?;
+
+                for &token_id in &tokens[start_at..] {
+                    let token_id = token_id as usize;
+                    if token_id < logits_vec.len() {
+                        let score = logits_vec[token_id];
+                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
+                        logits_vec[token_id] = sign * score / self.repeat_penalty;
+                    }
+                }
+
+                // Create a new tensor with the modified logits
+                let device = logits.device().clone();
+                let shape = logits.shape().clone();
+                let new_logits = Tensor::new(&logits_vec[..], &device)?;
+                new_logits.reshape(shape)?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token || next_token == eot_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        // Write any remaining tokens
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            write!(output, "{}", rest)?;
+        }
+
+        Ok(())
+    }
+}
\ No newline at end of file
diff --git a/packages/genaiscript/genaisrc/web-scrape.genai.mts b/packages/genaiscript/genaisrc/web-scrape.genai.mts
index 726c801..cb1cfdf 100644
--- a/packages/genaiscript/genaisrc/web-scrape.genai.mts
+++ b/packages/genaiscript/genaisrc/web-scrape.genai.mts
@@ -47,7 +47,7 @@ const {text} = await host.fetchText(new URL(url).toString());
 // browser: getBrowser(),
 // headless: true,
 // javaScriptEnabled: browser !== "chromium",
-// // timeout: 3000,
+// // timeout: 3777,
 // // bypassCSP: true,
 // // baseUrl: new URL(url).origin,
 // });
diff --git a/searxng/settings.yml b/searxng/settings.yml
index 1035159..1fadc6f 100644
--- a/searxng/settings.yml
+++ b/searxng/settings.yml
@@ -145,7 +145,7 @@ ui:
 # Note: since commit af77ec3, morty accepts a base64 encoded key.
 #
 # result_proxy:
-#   url: http://127.0.0.1:3000/
+#   url: http://127.0.0.1:3777/
 #   # the key is a base64 encoded string, the YAML !!binary prefix is optional
 #   key: !!binary "your_morty_proxy_key"
 #   # [true|false] enable the "proxy" button next to each result
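
Taken together, the new `cli`, `model`, `openai_types`, `server`, and `text_generation` modules cover argument parsing, model selection, the OpenAI-compatible types, the HTTP layer, and the sampling loop, while `main.rs` keeps the glue. A minimal sketch of that glue in server mode is shown below; the `run_server` helper name, the `0.0.0.0` bind address, and the axum 0.7-style `axum::serve` startup are illustrative assumptions, not code from this diff.

```rust
// Hypothetical wiring (not part of this diff): start the OpenAI-compatible
// server from an already constructed TextGeneration pipeline.
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::server::{create_router, AppState};
use crate::text_generation::TextGeneration;

pub async fn run_server(
    text_generation: TextGeneration, // built from Args/Which exactly as in CLI mode
    model_id: String,                // e.g. Which::InstructV3_1B.to_model_id()
    port: u16,                       // Args::port, default 3777
) -> anyhow::Result<()> {
    let state = AppState {
        text_generation: Arc::new(Mutex::new(text_generation)),
        model_id,
    };
    let app = create_router(state);

    // axum 0.7-style startup; older axum versions use Server::bind instead.
    let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
    println!("OpenAI-compatible server listening on port {port}");
    axum::serve(listener, app).await?;
    Ok(())
}
```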