update project structure
This commit is contained in:
36
crates/agent-server/Cargo.toml
Normal file
36
crates/agent-server/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
# Manifest for the `agent-server` binary crate.
[package]
name = "agent-server"
version = "0.1.0"
edition = "2021"
license = "MIT"

# Single binary target built from `src/main.rs`.
[[bin]]
edition = "2021"
name = "agent-server"
path = "src/main.rs"

[dependencies]
axum = { version = "0.8", features = ["multipart"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
http = "1.1.0"
tokio-stream = "0.1.16"
uuid = { version = "1.11.0", features = ["v4"] }
tokio-util = { version = "0.7", features = ["io"] }
serde_json = "1.0.133"
futures = "0.3.31"
dotenv = "0.15.0"
shell-escape = "0.1.5"
rust-embed = "8.5.0"
bytes = "1.8.0"
lazy_static = "1.5.0"
sled = "0.34.7"
tower-http = { version = "0.6.2", features = ["trace", "cors"] }
tower = "0.5.2"
anyhow = "1.0.97"
base64 = "0.22.1"
fips204 = "0.4.6"
# NOTE(review): this dependency follows the `main` branch of the git repository
# rather than a tagged release or a fixed `rev`; consider pinning a revision so
# API changes upstream cannot break the build unexpectedly — TODO confirm.
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = ["server", "transport-streamable-http-server", "transport-sse-server", "transport-io",] }
mime_guess = "2.0.5"
|
32
crates/agent-server/src/agents/deep_research.rs
Normal file
32
crates/agent-server/src/agents/deep_research.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/deep-research.genai.mts", 60).await
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::agents::deep_research::agent;

    /// Smoke test: spawns the deep-research agent and waits for it to exit.
    ///
    /// The collected output is intentionally not inspected; the test only
    /// checks that the process can be spawned and awaited.
    #[tokio::test]
    async fn test_deepresearch() {
        // a really provocative question for research that generally yields infinite complexity with each run
        let input = "What is a life of meaning?";

        let command = agent("test-deepresearch-agent", input).await.unwrap();

        // `wait_with_output` consumes the child, so no mutable binding is needed.
        let _output = command
            .wait_with_output()
            .await
            .expect("Failed to wait for output");
    }
}
|
10
crates/agent-server/src/agents/image_generator.rs
Normal file
10
crates/agent-server/src/agents/image_generator.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
tracing::debug!(
|
||||
"Running image generator, \ninput: {}",
|
||||
input
|
||||
);
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/image-generator.genai.mts", 10).await
|
||||
}
|
136
crates/agent-server/src/agents/mod.rs
Normal file
136
crates/agent-server/src/agents/mod.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
pub(crate) mod news;
|
||||
pub(crate) mod scrape;
|
||||
pub(crate) mod search;
|
||||
pub(crate) mod image_generator;
|
||||
pub(crate) mod deep_research;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError, RoleServer, ServerHandler, const_string, model::*,
|
||||
service::RequestContext, tool,
|
||||
};
|
||||
use tokio::process::Child;
|
||||
|
||||
/// Stateless handle whose `impl` blocks expose the agent tools (search, news,
/// scrape, image generation, deep research) and the `ServerHandler` impl.
#[derive(Clone)]
pub struct Agents;
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl Agents {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
#[tool(description = "Search the web for information")]
|
||||
async fn search(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The search query")]
|
||||
query: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match search::agent("tool-search", &query).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Search for news articles")]
|
||||
async fn news(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The news search query")]
|
||||
query: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match news::agent("tool-news", &query).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Scrape content from a webpage")]
|
||||
async fn scrape(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The URL to scrape")]
|
||||
url: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match scrape::agent("tool-scrape", &url).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Generate an image based on a description")]
|
||||
async fn generate_image(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The image description")]
|
||||
description: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match image_generator::agent("tool-image", &description).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Perform deep research on a topic")]
|
||||
async fn deep_research(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The research topic")]
|
||||
topic: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match deep_research::agent("tool-research", &topic).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl ServerHandler for Agents {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some("This server provides various agent tools for web search, news search, web scraping, image generation, and deep research.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
|
||||
let initialize_headers = &http_request_part.headers;
|
||||
let initialize_uri = &http_request_part.uri;
|
||||
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
|
||||
}
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_agent_result(mut child: Child) -> Result<CallToolResult, McpError> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
let output = match child.wait_with_output().await {
|
||||
Ok(output) => output,
|
||||
Err(e) => return Err(McpError::internal_error(format!("Failed to get agent output: {}", e), None)),
|
||||
};
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
return Err(McpError::internal_error(
|
||||
format!("Agent failed with status {}: {}", output.status, stderr),
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(stdout)]))
|
||||
}
|
6
crates/agent-server/src/agents/news.rs
Normal file
6
crates/agent-server/src/agents/news.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
|
||||
}
|
6
crates/agent-server/src/agents/scrape.rs
Normal file
6
crates/agent-server/src/agents/scrape.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
|
||||
}
|
27
crates/agent-server/src/agents/search.rs
Normal file
27
crates/agent-server/src/agents/search.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::agents::search::agent;

    /// Smoke test: spawns the web-search agent, waits for it to exit and
    /// prints whatever it wrote to stdout and stderr.
    #[tokio::test]
    async fn test_search_execution() {
        let input = "Who won the 2024 presidential election?";

        let command = agent("test-stream", input).await.unwrap();

        // `wait_with_output` consumes the child, so no mutable binding is needed.
        let output = command
            .wait_with_output()
            .await
            .expect("Failed to wait for output");
        println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
        println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
    }
}
|
27
crates/agent-server/src/config.rs
Normal file
27
crates/agent-server/src/config.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub struct Runtime {
|
||||
pub env_vars: Vec<String>,
|
||||
}
|
||||
|
||||
|
||||
impl Runtime {
|
||||
pub fn configure() -> Self {
|
||||
// automatic configuration between local/docker environments
|
||||
match dotenv::dotenv() {
|
||||
Ok(_) => tracing::debug!("Loaded .env file successfully"),
|
||||
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
|
||||
}
|
||||
|
||||
Self {
|
||||
env_vars: vec![
|
||||
"OPENAI_API_KEY".to_string(),
|
||||
"GENAISCRIPT_MODEL_LARGE".to_string(),
|
||||
"GENAISCRIPT_MODEL_SMALL".to_string(),
|
||||
"SEARXNG_API_BASE_URL".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_env_var(&self, key: &str) -> String {
|
||||
std::env::var(key).unwrap_or_default()
|
||||
}
|
||||
}
|
211
crates/agent-server/src/counter.rs
Normal file
211
crates/agent-server/src/counter.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
|
||||
service::RequestContext, tool,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Aggregated arguments of the `sum` tool: the two integers to add.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct StructRequest {
    /// First addend.
    pub a: i32,
    /// Second addend.
    pub b: i32,
}
|
||||
|
||||
/// Example MCP server state: a shared integer counter.
#[derive(Clone)]
pub struct Counter {
    // Behind `Arc<Mutex<..>>`, so clones of `Counter` share the same value.
    counter: Arc<Mutex<i32>>,
}
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl Counter {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: Arc::new(Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
|
||||
RawResource::new(uri, name.to_string()).no_annotation()
|
||||
}
|
||||
|
||||
#[tool(description = "Increment the counter by 1")]
|
||||
async fn increment(&self) -> Result<CallToolResult, McpError> {
|
||||
let mut counter = self.counter.lock().await;
|
||||
*counter += 1;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Decrement the counter by 1")]
|
||||
async fn decrement(&self) -> Result<CallToolResult, McpError> {
|
||||
let mut counter = self.counter.lock().await;
|
||||
*counter -= 1;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Get the current counter value")]
|
||||
async fn get_value(&self) -> Result<CallToolResult, McpError> {
|
||||
let counter = self.counter.lock().await;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Say hello to the client")]
|
||||
fn say_hello(&self) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text("hello")]))
|
||||
}
|
||||
|
||||
#[tool(description = "Repeat what you say")]
|
||||
fn echo(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "Repeat what you say")]
|
||||
saying: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(saying)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Calculate the sum of two numbers")]
|
||||
fn sum(
|
||||
&self,
|
||||
#[tool(aggr)] StructRequest { a, b }: StructRequest,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
(a + b).to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
// Declares `Echo` for the literal "echo" via rmcp's `const_string!` macro.
// NOTE(review): `Echo` is not referenced anywhere in the visible part of this
// file — confirm whether it is still needed.
const_string!(Echo = "echo");
|
||||
#[tool(tool_box)]
|
||||
impl ServerHandler for Counter {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_prompts()
|
||||
.enable_resources()
|
||||
.enable_tools()
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resources(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListResourcesResult, McpError> {
|
||||
Ok(ListResourcesResult {
|
||||
resources: vec![
|
||||
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
|
||||
self._create_resource_text("memo://insights", "memo-name"),
|
||||
],
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ReadResourceResult, McpError> {
|
||||
match uri.as_str() {
|
||||
"str:////Users/to/some/path/" => {
|
||||
let cwd = "/Users/to/some/path/";
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::text(cwd, uri)],
|
||||
})
|
||||
}
|
||||
"memo://insights" => {
|
||||
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::text(memo, uri)],
|
||||
})
|
||||
}
|
||||
_ => Err(McpError::resource_not_found(
|
||||
"resource_not_found",
|
||||
Some(json!({
|
||||
"uri": uri
|
||||
})),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_prompts(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListPromptsResult, McpError> {
|
||||
Ok(ListPromptsResult {
|
||||
next_cursor: None,
|
||||
prompts: vec![Prompt::new(
|
||||
"example_prompt",
|
||||
Some("This is an example prompt that takes one required argument, message"),
|
||||
Some(vec![PromptArgument {
|
||||
name: "message".to_string(),
|
||||
description: Some("A message to put in the prompt".to_string()),
|
||||
required: Some(true),
|
||||
}]),
|
||||
)],
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_prompt(
|
||||
&self,
|
||||
GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<GetPromptResult, McpError> {
|
||||
match name.as_str() {
|
||||
"example_prompt" => {
|
||||
let message = arguments
|
||||
.and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
|
||||
.ok_or_else(|| {
|
||||
McpError::invalid_params("No message provided to example_prompt", None)
|
||||
})?;
|
||||
|
||||
let prompt =
|
||||
format!("This is an example prompt with your message here: '{message}'");
|
||||
Ok(GetPromptResult {
|
||||
description: None,
|
||||
messages: vec![PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: PromptMessageContent::text(prompt),
|
||||
}],
|
||||
})
|
||||
}
|
||||
_ => Err(McpError::invalid_params("prompt not found", None)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListResourceTemplatesResult, McpError> {
|
||||
Ok(ListResourceTemplatesResult {
|
||||
next_cursor: None,
|
||||
resource_templates: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
|
||||
let initialize_headers = &http_request_part.headers;
|
||||
let initialize_uri = &http_request_part.uri;
|
||||
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
|
||||
}
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
262
crates/agent-server/src/handlers/agents.rs
Normal file
262
crates/agent-server/src/handlers/agents.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sled;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// Process-wide sled database holding the registered stream records, opened
// lazily on first access. Opening panics if the database cannot be opened.
// The handlers in this module take the async mutex before every access, which
// serializes their read-modify-write sequences.
// NOTE(review): the path is relative to the process working directory —
// confirm this matches every deployment layout.
lazy_static! {
    static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
        sled::open("./open-web-agent-rs/db/stream_store").expect("Failed to open sled database")
    ));
}
|
||||
|
||||
pub async fn use_agent(Path(agent_id): Path<String>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
match db.get(&agent_id) {
|
||||
Ok(Some(data)) => {
|
||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Increment the call_count in the database
|
||||
info.call_count += 1;
|
||||
let updated_info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match db.insert(&agent_id, updated_info_bytes) {
|
||||
Ok(_) => {
|
||||
if let Err(e) = db.flush_async().await {
|
||||
tracing::error!(
|
||||
"Failed to persist updated call_count to the database: {}",
|
||||
e
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to update call_count in the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let info: StreamInfo = match db.get(&agent_id) {
|
||||
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found after update: {}", agent_id);
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch updated record from DB: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if (info.call_count > 1) {
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
let resource = info.resource;
|
||||
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
|
||||
|
||||
tracing::debug!(
|
||||
"Executing agent - Type: {}, Id: {}",
|
||||
resource,
|
||||
agent_id
|
||||
);
|
||||
|
||||
let cmd = match resource.as_str() {
|
||||
"web-search" => crate::agents::search::agent(agent_id.as_str(), &*input).await,
|
||||
"news-search" => crate::agents::news::agent(agent_id.as_str(), &*input).await,
|
||||
"image-generator" => {
|
||||
crate::agents::image_generator::agent(agent_id.as_str(), &*input).await
|
||||
}
|
||||
"deep-research" => {
|
||||
crate::agents::deep_research::agent(agent_id.as_str(), &*input).await
|
||||
}
|
||||
"web-scrape" => crate::agents::scrape::agent(agent_id.as_str(), &*input).await,
|
||||
_ => {
|
||||
tracing::error!("Unsupported resource type: {}", resource);
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut cmd = match cmd {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Agent execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let sse_stream = reader_to_stream(reader, agent_id.clone());
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap();
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found: {}", agent_id);
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch from DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_to_stream<R>(
|
||||
reader: BufReader<R>,
|
||||
stream_id: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold(reader, move |mut reader| async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
|
||||
reader,
|
||||
)),
|
||||
Err(e) => Some((Err(e), reader)),
|
||||
}
|
||||
});
|
||||
|
||||
let stream_with_done = stream.chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
/// Request payload wrapper around the agent input.
#[derive(Deserialize, Serialize, Debug)]
struct Payload {
    // Arbitrary JSON; `use_agent` re-serializes it to a string for the agent.
    input: Value,
}
|
||||
|
||||
/// Record stored (as JSON) in the sled database under the stream id.
#[derive(Serialize, Deserialize, Debug)]
struct StreamInfo {
    // Agent type selector, e.g. "web-search"; matched in `use_agent`.
    resource: String,
    // Input handed to the agent.
    payload: Payload,
    // Copied verbatim from the webhook request; not read in this module.
    parent: String,
    // Number of `use_agent` calls so far; starts at 0 in `create_agent`.
    call_count: i32,
}
|
||||
|
||||
/// JSON body accepted by `create_agent`.
#[derive(Deserialize, Serialize, Debug)]
pub struct WebhookPostRequest {
    // Used as the database key and in the returned stream URL.
    id: String,
    // Agent type selector stored in the stream record.
    resource: String,
    // Agent input stored in the stream record.
    payload: Payload,
    // Stored in the stream record as-is.
    parent: String,
}
|
||||
|
||||
/// JSON body returned by `create_agent` on success.
#[derive(Deserialize, Serialize, Debug)]
struct WebhookPostResponse {
    // Path of the form `/agents/<id>`.
    stream_url: String,
}
|
||||
|
||||
pub async fn create_agent(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
|
||||
tracing::info!("Received webhook post request with ID: {}", payload.id);
|
||||
|
||||
let stream_id = payload.id.clone();
|
||||
let info = StreamInfo {
|
||||
resource: payload.resource.clone(),
|
||||
payload: payload.payload,
|
||||
parent: payload.parent.clone(),
|
||||
call_count: 0,
|
||||
};
|
||||
|
||||
let info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Use atomic compare_and_swap operation
|
||||
match db.compare_and_swap(
|
||||
&stream_id,
|
||||
None as Option<&[u8]>,
|
||||
Some(info_bytes.as_slice()),
|
||||
) {
|
||||
Ok(_) => {
|
||||
// Force an immediate sync to disk
|
||||
match db.flush_async().await {
|
||||
Ok(_) => {
|
||||
// Verify the write by attempting to read it back
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(_)) => {
|
||||
let stream_url = format!("/agents/{}", stream_id);
|
||||
tracing::info!(
|
||||
"Successfully created and verified stream URL: {}",
|
||||
stream_url
|
||||
);
|
||||
Json(WebhookPostResponse { stream_url }).into_response()
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error verifying stream creation: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to flush DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to insert stream info: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
3
crates/agent-server/src/handlers/mod.rs
Normal file
3
crates/agent-server/src/handlers/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// HTTP handler modules.
// NOTE(review): `model_context.rs` and `models.rs` appear alongside this file
// in the same directory but are not declared here, so they would not be
// compiled — confirm whether that is intentional.
pub mod not_found;
pub mod ui;
pub mod agents;
|
185
crates/agent-server/src/handlers/model_context.rs
Normal file
185
crates/agent-server/src/handlers/model_context.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Json, http::StatusCode, response::IntoResponse,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
// Custom function to format streaming responses according to OpenAI API format
|
||||
pub fn openai_stream_format<R>(
|
||||
reader: BufReader<R>,
|
||||
request_id: String,
|
||||
model: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold((reader, 0), move |(mut reader, index)| {
|
||||
let request_id = request_id.clone();
|
||||
let model = model.clone();
|
||||
async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => {
|
||||
let content = line.trim();
|
||||
// Skip empty lines
|
||||
if content.is_empty() {
|
||||
return Some((Ok(Bytes::from("")), (reader, index)));
|
||||
}
|
||||
|
||||
// Format as OpenAI API streaming response
|
||||
let chunk = serde_json::json!({
|
||||
"id": format!("chatcmpl-{}", request_id),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": index,
|
||||
"delta": {
|
||||
"content": content
|
||||
},
|
||||
"finish_reason": null
|
||||
}]
|
||||
});
|
||||
|
||||
Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", chunk.to_string()))),
|
||||
(reader, index),
|
||||
))
|
||||
}
|
||||
Err(e) => Some((Err(e), (reader, index))),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add the [DONE] message at the end
|
||||
let stream_with_done = stream.filter(|result| {
|
||||
futures::future::ready(match result {
|
||||
Ok(bytes) => !bytes.is_empty(),
|
||||
Err(_) => true,
|
||||
})
|
||||
}).chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
/// JSON body accepted by `model_context`.
#[derive(Deserialize, Debug)]
pub struct ModelContextRequest {
    // Conversation messages; serialized to JSON and passed to the agent as input.
    messages: Vec<Message>,
    // Echoed back in the response; "default-model" is used when absent.
    model: Option<String>,
    // Requests a streaming response when `true`.
    stream: Option<bool>,
    // Accepted but not read by `model_context`.
    temperature: Option<f32>,
    // Accepted but not read by `model_context`.
    max_tokens: Option<u32>,
}
|
||||
|
||||
/// A single chat message, used in both requests and responses.
#[derive(Deserialize, Serialize, Debug)]
pub struct Message {
    // Speaker role; responses built in this module use "assistant".
    role: String,
    // Message text.
    content: String,
}
|
||||
|
||||
/// Non-streaming response body returned by `model_context`.
#[derive(Serialize, Debug)]
pub struct ModelContextResponse {
    // "chatcmpl-" followed by the generated request id.
    id: String,
    // Object type tag; `model_context` sets "chat.completion".
    object: String,
    // Creation time in seconds since the Unix epoch.
    created: u64,
    // Model name echoed from the request (or "default-model").
    model: String,
    // Generated choices; `model_context` always returns exactly one.
    choices: Vec<Choice>,
}
|
||||
|
||||
/// One completion choice inside a `ModelContextResponse`.
#[derive(Serialize, Debug)]
pub struct Choice {
    // Position of this choice in the `choices` list.
    index: u32,
    // The generated message.
    message: Message,
    // Why generation ended; `model_context` sets "stop".
    finish_reason: String,
}
|
||||
|
||||
pub async fn model_context(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<ModelContextRequest>
|
||||
) -> impl IntoResponse {
|
||||
// Generate a unique ID for this request
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Convert messages to a format that can be passed to the agent
|
||||
let input = serde_json::to_string(&payload.messages).unwrap_or_default();
|
||||
|
||||
// Use the web-search agent for now, but this could be customized based on the model parameter
|
||||
let agent_file = "./packages/genaiscript/genaisrc/web-search.genai.mts";
|
||||
|
||||
tracing::debug!(
|
||||
"Executing model context request - Id: {}",
|
||||
request_id
|
||||
);
|
||||
|
||||
// Default timeout of 60 seconds
|
||||
let mut cmd = match run_agent(&request_id, &input, agent_file, 60).await {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Model context execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Check if streaming is requested either via the stream parameter or Accept header
|
||||
let accept_header = headers.get("accept").and_then(|h| h.to_str().ok()).unwrap_or("");
|
||||
let is_streaming = payload.stream.unwrap_or(false) || accept_header.contains("text/event-stream");
|
||||
|
||||
// If streaming is requested, return a streaming response
|
||||
if is_streaming {
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let model = payload.model.clone().unwrap_or_else(|| "default-model".to_string());
|
||||
let sse_stream = openai_stream_format(reader, request_id.clone(), model);
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap();
|
||||
} else {
|
||||
// For non-streaming responses, we need to collect all output and return it as a single response
|
||||
// This is a simplified implementation and might need to be adjusted based on actual requirements
|
||||
let response = ModelContextResponse {
|
||||
id: format!("chatcmpl-{}", request_id),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: payload.model.unwrap_or_else(|| "default-model".to_string()),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "This is a placeholder response. The actual implementation would process the agent's output.".to_string(),
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
return Json(response).into_response();
|
||||
}
|
||||
}
|
48
crates/agent-server/src/handlers/models.rs
Normal file
48
crates/agent-server/src/handlers/models.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use axum::{
|
||||
extract::Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ModelsResponse {
|
||||
object: String,
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Model {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
owned_by: String,
|
||||
}
|
||||
|
||||
pub async fn list_models() -> impl IntoResponse {
|
||||
let current_time = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
// Create a response with a default model
|
||||
let response = ModelsResponse {
|
||||
object: "list".to_string(),
|
||||
data: vec![
|
||||
Model {
|
||||
id: "gpt-3.5-turbo".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: current_time,
|
||||
owned_by: "open-web-agent-rs".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gpt-4".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: current_time,
|
||||
owned_by: "open-web-agent-rs".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
Json(response)
|
||||
}
|
16
crates/agent-server/src/handlers/not_found.rs
Normal file
16
crates/agent-server/src/handlers/not_found.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn handle_not_found() -> impl IntoResponse {
|
||||
tracing::warn!("404 Not Found error occurred");
|
||||
|
||||
let error_response = serde_json::json!({
|
||||
"error": "Route Not Found",
|
||||
"status": 404
|
||||
});
|
||||
|
||||
(StatusCode::NOT_FOUND, Json(error_response))
|
||||
}
|
0
crates/agent-server/src/handlers/ui.rs
Normal file
0
crates/agent-server/src/handlers/ui.rs
Normal file
37
crates/agent-server/src/main.rs
Normal file
37
crates/agent-server/src/main.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use crate::config::{Runtime};
|
||||
use crate::routes::create_router;
|
||||
use crate::setup::init_logging;
|
||||
|
||||
mod config;
|
||||
mod routes;
|
||||
mod setup;
|
||||
mod handlers;
|
||||
mod agents;
|
||||
mod utils;
|
||||
mod counter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
init_logging();
|
||||
|
||||
Runtime::configure();
|
||||
|
||||
let router = create_router();
|
||||
|
||||
let addr = "0.0.0.0:3006";
|
||||
tracing::info!("Attempting to bind server to {}", addr);
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => {
|
||||
tracing::info!("Successfully bound to {}", l.local_addr().unwrap());
|
||||
l
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to bind to {}: {}", addr, e);
|
||||
panic!("Server failed to start");
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!("Server starting on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, router.into_make_service()).await.unwrap();
|
||||
}
|
140
crates/agent-server/src/routes.rs
Normal file
140
crates/agent-server/src/routes.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use axum::response::Response;
|
||||
use crate::handlers::{not_found::handle_not_found};
|
||||
use axum::routing::{get, Router};
|
||||
use http::StatusCode;
|
||||
use tower_http::trace::{self, TraceLayer};
|
||||
use tracing::Level;
|
||||
|
||||
use rmcp::transport::streamable_http_server::{
|
||||
StreamableHttpService, session::local::LocalSessionManager,
|
||||
};
|
||||
use rust_embed::Embed;
|
||||
use crate::agents::Agents;
|
||||
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "../../node_modules/@modelcontextprotocol/inspector-client/dist"]
|
||||
struct Asset;
|
||||
|
||||
pub struct StaticFile<T>(pub T);
|
||||
|
||||
impl<T> axum::response::IntoResponse for StaticFile<T>
|
||||
where
|
||||
T: Into<String>,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
let path = self.0.into();
|
||||
|
||||
match Asset::get(path.as_str()) {
|
||||
Some(content) => {
|
||||
let mime = mime_guess::from_path(path).first_or_octet_stream();
|
||||
([(http::header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
|
||||
}
|
||||
None => (StatusCode::NOT_FOUND, "404 Not Found").into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ui_index_handler() -> impl axum::response::IntoResponse {
|
||||
StaticFile("index.html")
|
||||
}
|
||||
|
||||
async fn static_handler(uri: http::Uri) -> impl axum::response::IntoResponse {
|
||||
let path = uri.path().trim_start_matches("/").to_string();
|
||||
StaticFile(path)
|
||||
}
|
||||
|
||||
pub fn create_router() -> Router {
|
||||
|
||||
let mcp_service = StreamableHttpService::new(
|
||||
Agents::new,
|
||||
LocalSessionManager::default().into(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
.nest_service("/mcp", mcp_service)
|
||||
.route("/health", get(health))
|
||||
.route("/", get(ui_index_handler))
|
||||
.route("/index.html", get(ui_index_handler))
|
||||
.route("/{*path}", get(static_handler))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
|
||||
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
|
||||
)
|
||||
.layer(tower_http::cors::CorsLayer::very_permissive())
|
||||
.fallback(handle_not_found)
|
||||
}
|
||||
|
||||
/// Liveness probe: always answers with the text `ok`.
async fn health() -> String {
    String::from("ok")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{Body, Bytes};
    use axum::http::{Request, StatusCode};
    use axum::response::Response;
    use tower::ServiceExt;

    /// Builds a GET request with an empty body for the given URI.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .method("GET")
            .body(Body::empty())
            .unwrap()
    }

    // Helper function to extract bytes from a response body
    async fn response_body_bytes(response: Response) -> Bytes {
        // Use a reasonable size limit for the body (16MB)
        axum::body::to_bytes(response.into_body(), 16 * 1024 * 1024)
            .await
            .expect("Failed to read response body")
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        // Call the health function directly
        assert_eq!(health().await, "ok".to_string());
    }

    #[tokio::test]
    async fn test_health_route() {
        // Send a GET /health request through a freshly built router
        let response = create_router()
            .oneshot(get_request("/health"))
            .await
            .unwrap();

        // Check the response status
        assert_eq!(response.status(), StatusCode::OK);

        // Check the response body
        let body = response_body_bytes(response).await;
        assert_eq!(&body[..], b"ok");
    }

    #[tokio::test]
    async fn test_not_found_route() {
        // Send a GET request for a path that does not exist
        let response = create_router()
            .oneshot(get_request("/non-existent"))
            .await
            .unwrap();

        // Check the response status
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}
|
10
crates/agent-server/src/setup.rs
Normal file
10
crates/agent-server/src/setup.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
// src/setup.rs
|
||||
pub fn init_logging() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.init();
|
||||
}
|
65
crates/agent-server/src/utils/base64.rs
Normal file
65
crates/agent-server/src/utils/base64.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::GeneralPurpose;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::engine::general_purpose::STANDARD_NO_PAD;
|
||||
|
||||
pub struct Base64Encoder {
|
||||
payload_engine: GeneralPurpose,
|
||||
signature_engine: GeneralPurpose,
|
||||
public_key_engine: GeneralPurpose,
|
||||
secret_key_engine: GeneralPurpose,
|
||||
}
|
||||
|
||||
impl Base64Encoder {
|
||||
pub(crate) fn b64_encode(&self, p0: &[u8]) -> String {
|
||||
self.payload_engine.encode(p0)
|
||||
}
|
||||
pub(crate) fn b64_decode(&self, p0: String) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
self.payload_engine.decode(p0)
|
||||
}
|
||||
}
|
||||
|
||||
pub const B64_ENCODER: &Base64Encoder = &Base64Encoder::new();
|
||||
|
||||
impl Base64Encoder {
|
||||
pub const fn new() -> Self { // Made new() a const fn
|
||||
Base64Encoder {
|
||||
payload_engine: STANDARD,
|
||||
signature_engine: STANDARD,
|
||||
public_key_engine: STANDARD,
|
||||
secret_key_engine: STANDARD,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b64_encode_payload<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.payload_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_payload<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.payload_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_signature<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.signature_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_signature<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.signature_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_public_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.public_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_public_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.public_key_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.secret_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.secret_key_engine.decode(input)
|
||||
}
|
||||
}
|
2
crates/agent-server/src/utils/mod.rs
Normal file
2
crates/agent-server/src/utils/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod utils;
|
||||
pub mod base64;
|
71
crates/agent-server/src/utils/utils.rs
Normal file
71
crates/agent-server/src/utils/utils.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// utils.rs
|
||||
use tokio::process::{Child, Command}; // Use tokio::process::Child and Command
|
||||
use std::env;
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing;
|
||||
|
||||
|
||||
pub struct ShimBinding {
|
||||
user_input: String,
|
||||
file_path: String,
|
||||
openai_api_key: String,
|
||||
openai_api_base: String,
|
||||
genaiscript_model_large: String,
|
||||
genaiscript_model_small: String,
|
||||
searxng_api_base_url: String,
|
||||
searxng_password: String,
|
||||
}
|
||||
|
||||
impl ShimBinding {
|
||||
pub fn new(user_input: String, file_path: String) -> Self {
|
||||
Self {
|
||||
user_input,
|
||||
file_path, // Initialize the new field
|
||||
openai_api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
|
||||
openai_api_base: env::var("OPENAI_API_BASE").unwrap_or_default(),
|
||||
genaiscript_model_large: env::var("GENAISCRIPT_MODEL_LARGE").unwrap_or_default(),
|
||||
genaiscript_model_small: env::var("GENAISCRIPT_MODEL_SMALL").unwrap_or_default(),
|
||||
searxng_api_base_url: env::var("SEARXNG_API_BASE_URL").unwrap_or_default(),
|
||||
searxng_password: env::var("SEARXNG_PASSWORD").unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> std::io::Result<Child> {
|
||||
let mut command = Command::new("./dist/genaiscript-rust-shim.js");
|
||||
command
|
||||
.arg("--file")
|
||||
.arg(&self.file_path) // Use the file_path field instead of hardcoded value
|
||||
.arg(format!("USER_INPUT={}", self.user_input))
|
||||
.env("OPENAI_API_KEY", &self.openai_api_key)
|
||||
.env("OPENAI_API_BASE", &self.openai_api_base)
|
||||
.env("GENAISCRIPT_MODEL_LARGE", &self.genaiscript_model_large)
|
||||
.env("GENAISCRIPT_MODEL_SMALL", &self.genaiscript_model_small)
|
||||
.env("SEARXNG_API_BASE_URL", &self.searxng_api_base_url)
|
||||
.env("SEARXNG_PASSWORD", &self.searxng_password)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit());
|
||||
|
||||
command.spawn()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// wrapper executes an agent with a timeout
|
||||
pub async fn run_agent(stream_id: &str, input: &str, file_path: &str, timeout_seconds: u64 ) -> Result<Child, String> {
|
||||
tracing::debug!("Initiating agent for stream {} with file path {}", stream_id, file_path);
|
||||
|
||||
let shim_binding = ShimBinding::new(input.to_string(), file_path.to_string());
|
||||
let spawn_future = async move {
|
||||
match shim_binding.execute() {
|
||||
Ok(child) => Ok(child),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to spawn shim process: {}", e);
|
||||
Err(e.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
timeout(Duration::from_secs(timeout_seconds), spawn_future)
|
||||
.await
|
||||
.unwrap_or_else(|_| Err("Command timed out after 10 seconds".to_string()))
|
||||
}
|
6115
crates/local_inference_engine/Cargo.lock
generated
Normal file
6115
crates/local_inference_engine/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
65
crates/local_inference_engine/Cargo.toml
Normal file
65
crates/local_inference_engine/Cargo.toml
Normal file
@@ -0,0 +1,65 @@
|
||||
[package]
|
||||
name = "local_inference_engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { version = "0.3.2", optional = true }
|
||||
candle-datasets = { version = "=0.9.1", optional = true }
|
||||
candle-nn = { version = "=0.9.1", features = ["metal"] }
|
||||
candle-transformers = { version = "=0.9.1", features = ["metal"] }
|
||||
candle-flash-attn = { version = "=0.9.1", optional = true }
|
||||
candle-onnx = { version = "=0.9.1", optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
|
||||
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
|
||||
hf-hub = { version = "0.4.1", features = ["tokio"] }
|
||||
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
|
||||
num-traits = { version = "0.2.15" }
|
||||
palette = { version = "0.7.6", optional = true }
|
||||
enterpolation = { version = "0.2.1", optional = true}
|
||||
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
|
||||
rayon = { version = "1.7.0" }
|
||||
rubato = { version = "0.15.0", optional = true }
|
||||
safetensors = { version = "0.4.1" }
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_json = { version = "1.0.99" }
|
||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||
tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "http"] }
|
||||
cpal = { version = "0.15.2", optional = true }
|
||||
pdf2image = { version = "0.1.2" , optional = true}
|
||||
candle-core = { version = "=0.9.1", features = ["metal"] }
|
||||
anyhow = "1.0.98"
|
||||
clap= { version = "4.2.4", features = ["derive"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
axum = { version = "0.7.4", features = ["json"] }
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
either = { version = "1.9.0", features = ["serde"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reborrow = "0.5.5"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = { version = "1.4.3" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
imageproc = { version = "0.24.0", default-features = false }
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
rand = { version = "0.9.0" }
|
||||
ab_glyph = { version = "0.2.23" }
|
||||
tracing = { version = "0.1.37" }
|
||||
tracing-chrome = { version = "0.7.1" }
|
||||
tracing-subscriber = { version = "0.3.7" }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.43.0"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
206
crates/local_inference_engine/README.md
Normal file
206
crates/local_inference_engine/README.md
Normal file
@@ -0,0 +1,206 @@
|
||||
# Local Inference Engine
|
||||
|
||||
A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
|
||||
|
||||
## Features
|
||||
|
||||
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
|
||||
- CLI mode for direct text generation
|
||||
- Server mode with OpenAI-compatible API
|
||||
- Support for various model configurations (base, instruction-tuned)
|
||||
- Metal acceleration on macOS
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust toolchain (install via [rustup](https://rustup.rs/))
|
||||
- Cargo package manager
|
||||
- For GPU acceleration:
|
||||
- macOS: Metal support
|
||||
- Linux/Windows: CUDA support (requires appropriate drivers)
|
||||
|
||||
### Building from Source
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/seemueller-io/open-web-agent-rs.git
|
||||
cd open-web-agent-rs
|
||||
```
|
||||
|
||||
2. Build the local inference engine:
|
||||
```bash
|
||||
cargo build -p local_inference_engine --release
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI Mode
|
||||
|
||||
Run the inference engine in CLI mode to generate text directly:
|
||||
|
||||
```bash
|
||||
cargo run -p local_inference_engine --release -- --prompt "Your prompt text here" --which 3-1b-it
|
||||
```
|
||||
|
||||
#### CLI Options
|
||||
|
||||
- `--prompt <TEXT>`: The prompt text to generate from
|
||||
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
||||
- `--server`: Run OpenAI compatible server
|
||||
- Available values for `--which`: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
|
||||
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
|
||||
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
|
||||
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
|
||||
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
|
||||
- `--repeat-last-n <INT>`: Context size for repeat penalty (default: 64)
|
||||
- `--cpu`: Run on CPU instead of GPU
|
||||
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)
|
||||
|
||||
### Server Mode with OpenAI-compatible API
|
||||
|
||||
Run the inference engine in server mode to expose an OpenAI-compatible API:
|
||||
|
||||
```bash
|
||||
cargo run -p local_inference_engine --release -- --server --port 3777 --which 3-1b-it
|
||||
```
|
||||
|
||||
This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint.
|
||||
|
||||
#### Server Options
|
||||
|
||||
- `--server`: Run in server mode
|
||||
- `--port <INT>`: Port to use for the server (default: 3777)
|
||||
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
||||
- Other model options as described in CLI mode
|
||||
|
||||
## API Usage
|
||||
|
||||
The server exposes an OpenAI-compatible chat completions endpoint:
|
||||
|
||||
### Chat Completions
|
||||
|
||||
```
|
||||
POST /v1/chat/completions
|
||||
```
|
||||
|
||||
#### Request Format
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemma-3-1b-it",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 256,
|
||||
"top_p": 0.9,
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
#### Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-123abc456def789ghi",
|
||||
"object": "chat.completion",
|
||||
"created": 1677858242,
|
||||
"model": "gemma-3-1b-it",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking! How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 25,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 40
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Using cURL
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:3777/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemma-3-1b-it",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Using Python with OpenAI Client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:3777/v1",
|
||||
api_key="dummy" # API key is not validated but required by the client
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemma-3-1b-it",
|
||||
messages=[
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
### Example: Using JavaScript/TypeScript with OpenAI SDK
|
||||
|
||||
```javascript
|
||||
import OpenAI from 'openai';
|
||||
|
||||
const openai = new OpenAI({
|
||||
baseURL: 'http://localhost:3777/v1',
|
||||
apiKey: 'dummy', // API key is not validated but required by the client
|
||||
});
|
||||
|
||||
async function main() {
|
||||
const response = await openai.chat.completions.create({
|
||||
model: 'gemma-3-1b-it',
|
||||
messages: [
|
||||
{ role: 'user', content: 'What is the capital of France?' }
|
||||
],
|
||||
temperature: 0.7,
|
||||
max_tokens: 100,
|
||||
});
|
||||
|
||||
console.log(response.choices[0].message.content);
|
||||
}
|
||||
|
||||
main();
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.
|
||||
|
||||
2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.
|
||||
|
||||
3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.
|
||||
|
||||
4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the terms specified in the LICENSE file.
|
295
crates/local_inference_engine/api_test.html
Normal file
295
crates/local_inference_engine/api_test.html
Normal file
@@ -0,0 +1,295 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>OpenAI-Compatible API Tester</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
h1, h2 {
|
||||
color: #333;
|
||||
}
|
||||
.container {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
textarea {
|
||||
width: 100%;
|
||||
height: 150px;
|
||||
padding: 10px;
|
||||
margin-bottom: 10px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
button {
|
||||
background-color: #4CAF50;
|
||||
color: white;
|
||||
padding: 10px 15px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
}
|
||||
button:hover {
|
||||
background-color: #45a049;
|
||||
}
|
||||
pre {
|
||||
background-color: #f5f5f5;
|
||||
padding: 15px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
.response {
|
||||
margin-top: 20px;
|
||||
}
|
||||
.error {
|
||||
color: red;
|
||||
}
|
||||
.settings {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.settings div {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
label {
|
||||
margin-bottom: 5px;
|
||||
font-weight: bold;
|
||||
}
|
||||
input {
|
||||
padding: 8px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.examples {
|
||||
margin-top: 30px;
|
||||
}
|
||||
.example-btn {
|
||||
background-color: #2196F3;
|
||||
margin-right: 10px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.example-btn:hover {
|
||||
background-color: #0b7dda;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>OpenAI-Compatible API Tester</h1>
|
||||
<p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>
|
||||
|
||||
<div class="container">
|
||||
<h2>Request Settings</h2>
|
||||
<div class="settings">
|
||||
<div>
|
||||
<label for="serverUrl">Server URL:</label>
|
||||
<input type="text" id="serverUrl" value="http://localhost:3777" />
|
||||
</div>
|
||||
<div>
|
||||
<label for="model">Model:</label>
|
||||
<input type="text" id="model" value="gemma-3-1b-it" />
|
||||
</div>
|
||||
<div>
|
||||
<label for="maxTokens">Max Tokens:</label>
|
||||
<input type="number" id="maxTokens" value="150" />
|
||||
</div>
|
||||
<div>
|
||||
<label for="temperature">Temperature:</label>
|
||||
<input type="number" id="temperature" value="0.7" step="0.1" min="0" max="2" />
|
||||
</div>
|
||||
<div>
|
||||
<label for="topP">Top P:</label>
|
||||
<input type="number" id="topP" value="0.9" step="0.1" min="0" max="1" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2>Request Body</h2>
|
||||
<textarea id="requestBody">{
|
||||
"model": "gemma-3-1b-it",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you today?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9
|
||||
}</textarea>
|
||||
<button id="sendRequest">Send Request</button>
|
||||
|
||||
<div class="examples">
|
||||
<h3>Example Requests</h3>
|
||||
<button class="example-btn" id="example1">Basic Question</button>
|
||||
<button class="example-btn" id="example2">Multi-turn Conversation</button>
|
||||
<button class="example-btn" id="example3">Creative Writing</button>
|
||||
<button class="example-btn" id="example4">Code Generation</button>
|
||||
</div>
|
||||
|
||||
<div class="response">
|
||||
<h2>Response</h2>
|
||||
<pre id="responseOutput">Response will appear here...</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Update request body when settings change
|
||||
const serverUrlInput = document.getElementById('serverUrl');
|
||||
const modelInput = document.getElementById('model');
|
||||
const maxTokensInput = document.getElementById('maxTokens');
|
||||
const temperatureInput = document.getElementById('temperature');
|
||||
const topPInput = document.getElementById('topP');
|
||||
const requestBodyTextarea = document.getElementById('requestBody');
|
||||
const responseOutput = document.getElementById('responseOutput');
|
||||
|
||||
// Function to update request body from settings
|
||||
// Copies the values of the settings inputs into the JSON shown in the
// request-body textarea. If the textarea does not hold valid JSON the error
// is logged and the textarea is left unchanged.
function updateRequestBodyFromSettings() {
    try {
        const body = JSON.parse(requestBodyTextarea.value);
        Object.assign(body, {
            model: modelInput.value,
            max_tokens: parseInt(maxTokensInput.value),
            temperature: parseFloat(temperatureInput.value),
            top_p: parseFloat(topPInput.value)
        });
        requestBodyTextarea.value = JSON.stringify(body, null, 2);
    } catch (error) {
        console.error("Error updating request body:", error);
    }
}
|
||||
|
||||
// Update settings when request body changes
|
||||
// Copies model / max_tokens / temperature / top_p from the JSON in the
// request-body textarea into the matching settings inputs. Fields that are
// absent (or null) leave their input untouched; invalid JSON is logged and
// ignored.
function updateSettingsFromRequestBody() {
    try {
        const requestBody = JSON.parse(requestBodyTextarea.value);
        // Compare against null/undefined rather than truthiness so that valid
        // falsy values such as a temperature of 0 are still synced.
        const syncInput = function(input, value) {
            if (value !== undefined && value !== null) input.value = value;
        };
        syncInput(modelInput, requestBody.model);
        syncInput(maxTokensInput, requestBody.max_tokens);
        syncInput(temperatureInput, requestBody.temperature);
        syncInput(topPInput, requestBody.top_p);
    } catch (error) {
        console.error("Error updating settings:", error);
    }
}
|
||||
|
||||
// Add event listeners for settings changes
|
||||
modelInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||
maxTokensInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||
temperatureInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||
topPInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||
|
||||
// Add event listener for request body changes
|
||||
requestBodyTextarea.addEventListener('blur', updateSettingsFromRequestBody);
|
||||
|
||||
// Send request button
|
||||
// Send request button: POSTs the textarea's JSON to the chat completions
// endpoint of the configured server and shows the pretty-printed response,
// or an error message if anything fails.
document.getElementById('sendRequest').addEventListener('click', async function() {
    // Clear error styling left over from a previous failed request.
    responseOutput.classList.remove('error');
    try {
        responseOutput.textContent = "Sending request...";
        const serverUrl = serverUrlInput.value;
        const endpoint = '/v1/chat/completions';
        const url = serverUrl + endpoint;

        const requestBody = JSON.parse(requestBodyTextarea.value);

        const response = await fetch(url, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(requestBody)
        });

        // fetch() only rejects on network failures, so surface HTTP errors explicitly.
        if (!response.ok) {
            const detail = await response.text();
            throw new Error("HTTP " + response.status + (detail ? ": " + detail : ""));
        }

        const data = await response.json();
        responseOutput.textContent = JSON.stringify(data, null, 2);
    } catch (error) {
        responseOutput.textContent = "Error: " + error.message;
        responseOutput.classList.add('error');
    }
});
|
||||
|
||||
// Example requests
|
||||
// Example 1: a single-turn factual question, using whatever the settings inputs currently hold.
document.getElementById('example1').addEventListener('click', () => {
    const messages = [
        { role: "user", content: "Who was the 16th president of the United States?" },
    ];
    const payload = {
        model: modelInput.value,
        messages,
        max_tokens: parseInt(maxTokensInput.value),
        temperature: parseFloat(temperatureInput.value),
        top_p: parseFloat(topPInput.value),
    };
    requestBodyTextarea.value = JSON.stringify(payload, null, 2);
});
|
||||
|
||||
// Example 2: a multi-turn conversation (system + user + assistant + user),
// using whatever the settings inputs currently hold.
document.getElementById('example2').addEventListener('click', () => {
    const turn = (role, content) => ({ role, content });
    const messages = [
        turn("system", "You are a helpful assistant that provides concise answers."),
        turn("user", "What is machine learning?"),
        turn("assistant", "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."),
        turn("user", "Give me an example of a machine learning algorithm."),
    ];
    const payload = {
        model: modelInput.value,
        messages,
        max_tokens: parseInt(maxTokensInput.value),
        temperature: parseFloat(temperatureInput.value),
        top_p: parseFloat(topPInput.value),
    };
    requestBodyTextarea.value = JSON.stringify(payload, null, 2);
});
|
||||
|
||||
// Example 3: creative-writing prompt with fixed sampling parameters.
document.getElementById('example3').addEventListener('click', function() {
    const temperature = 0.9; // Higher temperature for creative tasks
    const topP = 0.9;

    requestBodyTextarea.value = JSON.stringify({
        model: modelInput.value,
        messages: [
            {
                role: "user",
                content: "Write a short poem about artificial intelligence."
            }
        ],
        max_tokens: parseInt(maxTokensInput.value),
        temperature: temperature,
        top_p: topP
    }, null, 2);

    // Mirror BOTH overridden values into the settings inputs. Only the
    // temperature used to be synced, so the next change to any settings
    // input rewrote top_p in the request body from the stale input value.
    temperatureInput.value = temperature;
    topPInput.value = topP;
});
|
||||
|
||||
// Example 4: code-generation prompt with fixed sampling parameters.
document.getElementById('example4').addEventListener('click', function() {
    const temperature = 0.3; // Lower temperature for code generation
    const topP = 0.9;

    requestBodyTextarea.value = JSON.stringify({
        model: modelInput.value,
        messages: [
            {
                role: "user",
                content: "Write a Python function to calculate the Fibonacci sequence up to n terms."
            }
        ],
        max_tokens: parseInt(maxTokensInput.value),
        temperature: temperature,
        top_p: topP
    }, null, 2);

    // Mirror BOTH overridden values into the settings inputs. Only the
    // temperature used to be synced, so the next change to any settings
    // input rewrote top_p in the request body from the stale input value.
    temperatureInput.value = temperature;
    topPInput.value = topP;
});
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
176
crates/local_inference_engine/openai_api_test.js
Normal file
176
crates/local_inference_engine/openai_api_test.js
Normal file
@@ -0,0 +1,176 @@
|
||||
// Test requests for the OpenAI-compatible endpoint in the inference server
|
||||
// This file contains IIFE (Immediately Invoked Function Expression) JavaScript requests
|
||||
// to test the /v1/chat/completions endpoint
|
||||
|
||||
// Basic chat completion request: one user message, capped at 100 tokens.
(async function testBasicChatCompletion() {
    console.log("Test 1: Basic chat completion request");

    const payload = {
        model: "gemma-2-2b-it",
        messages: [
            { role: "user", content: "Who was the 16th president of the United States?" },
        ],
        max_tokens: 100,
    };

    try {
        const response = await fetch('http://localhost:3777/v1/chat/completions', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload),
        });
        const data = await response.json();
        console.log("Response:", JSON.stringify(data, null, 2));
    } catch (error) {
        // Network failures and unparsable bodies both end up here.
        console.error("Error:", error);
    }
})();
|
||||
|
||||
// Multi-turn conversation: system prompt plus an earlier user/assistant exchange,
// followed by a new user question.
(async function testMultiTurnConversation() {
    console.log("\nTest 2: Multi-turn conversation");

    const payload = {
        model: "gemma-2-2b-it",
        messages: [
            { role: "system", content: "You are a helpful assistant that provides concise answers." },
            { role: "user", content: "What is machine learning?" },
            { role: "assistant", content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed." },
            { role: "user", content: "Give me an example of a machine learning algorithm." },
        ],
        max_tokens: 150,
    };

    try {
        const response = await fetch('http://localhost:3777/v1/chat/completions', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload),
        });
        const data = await response.json();
        console.log("Response:", JSON.stringify(data, null, 2));
    } catch (error) {
        console.error("Error:", error);
    }
})();
|
||||
|
||||
// Request with temperature and top_p parameters set explicitly.
(async function testTemperatureAndTopP() {
    console.log("\nTest 3: Request with temperature and top_p parameters");

    const payload = {
        model: "gemma-2-2b-it",
        messages: [
            { role: "user", content: "Write a short poem about artificial intelligence." },
        ],
        max_tokens: 200,
        temperature: 0.8,
        top_p: 0.9,
    };

    try {
        const response = await fetch('http://localhost:3777/v1/chat/completions', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload),
        });
        const data = await response.json();
        console.log("Response:", JSON.stringify(data, null, 2));
    } catch (error) {
        console.error("Error:", error);
    }
})();
|
||||
|
||||
// Request with streaming enabled.
// Sends `stream: true` and prints either the streamed chunks (when the server
// answers with `text/event-stream`) or the ordinary JSON body.
(async function testStreaming() {
    console.log("\nTest 4: Request with streaming enabled");
    try {
        const response = await fetch('http://localhost:3777/v1/chat/completions', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                model: "gemma-2-2b-it",
                messages: [
                    {
                        role: "user",
                        content: "Explain quantum computing in simple terms."
                    }
                ],
                max_tokens: 150,
                stream: true
            })
        });

        // Note: Streaming might not be implemented yet, this is to test the API's handling of the parameter
        if (response.headers.get('content-type')?.includes('text/event-stream')) {
            console.log("Streaming response detected. Reading stream...");
            const reader = response.body.getReader();
            const decoder = new TextDecoder();

            while (true) {
                const { done, value } = await reader.read();
                if (done) break;

                // `stream: true` makes the decoder hold back an incomplete
                // multi-byte UTF-8 sequence until the next chunk arrives.
                // Decoding each chunk independently would turn a character
                // split across two network chunks into replacement characters.
                const chunk = decoder.decode(value, { stream: true });
                console.log("Chunk:", chunk);
            }

            // Flush whatever the decoder is still buffering after the last chunk.
            const tail = decoder.decode();
            if (tail) {
                console.log("Chunk:", tail);
            }
        } else {
            const data = await response.json();
            console.log("Non-streaming response:", JSON.stringify(data, null, 2));
        }
    } catch (error) {
        console.error("Error:", error);
    }
})();
|
||||
|
||||
// Request intended to exercise a different model.
// NOTE(review): the model name below is the same "gemma-2-2b-it" used by the
// other tests; substitute another model id here once one is available.
(async function testDifferentModel() {
    console.log("\nTest 5: Request with a different model");

    const payload = {
        model: "gemma-2-2b-it",
        messages: [
            { role: "user", content: "What are the benefits of renewable energy?" },
        ],
        max_tokens: 150,
    };

    try {
        const response = await fetch('http://localhost:3777/v1/chat/completions', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload),
        });
        const data = await response.json();
        console.log("Response:", JSON.stringify(data, null, 2));
    } catch (error) {
        console.error("Error:", error);
    }
})();
|
||||
|
||||
// Closing hints. These lines run as soon as the requests above have been
// started, i.e. before their responses are printed.
for (const hint of [
    "\nAll test requests have been sent. Check the server logs for more details.",
    "To run the server, use: cargo run --bin local_inference_engine -- --server",
]) {
    console.log(hint);
}
|
72
crates/local_inference_engine/src/cli.rs
Normal file
72
crates/local_inference_engine/src/cli.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use clap::Parser;
|
||||
use crate::model::Which;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
pub cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
pub tracing: bool,
|
||||
|
||||
/// Run in server mode with OpenAI compatible API
|
||||
#[arg(long)]
|
||||
pub server: bool,
|
||||
|
||||
/// Port to use for the server
|
||||
#[arg(long, default_value_t = 3777)]
|
||||
pub port: u16,
|
||||
|
||||
/// Prompt for text generation (not used in server mode)
|
||||
#[arg(long)]
|
||||
pub prompt: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
pub temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
pub top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
pub seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
pub sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
pub model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
pub revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
pub tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
pub repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
pub repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "3-1b-it")]
|
||||
pub which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
pub use_flash_attn: bool,
|
||||
}
|
13
crates/local_inference_engine/src/lib.rs
Normal file
13
crates/local_inference_engine/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
// Expose modules for testing and library usage
|
||||
pub mod token_output_stream;
|
||||
pub mod model;
|
||||
pub mod text_generation;
|
||||
pub mod utilities_lib;
|
||||
pub mod openai_types;
|
||||
pub mod cli;
|
||||
pub mod server;
|
||||
|
||||
// Re-export key components for easier access
|
||||
pub use model::{Model, Which};
|
||||
pub use text_generation::TextGeneration;
|
||||
pub use token_output_stream::TokenOutputStream;
|
894
crates/local_inference_engine/src/main.rs
Normal file
894
crates/local_inference_engine/src/main.rs
Normal file
@@ -0,0 +1,894 @@
|
||||
mod token_output_stream;
|
||||
mod utilities_lib;
|
||||
|
||||
#[cfg(feature = "intel-mkl-src")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate-src")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
use either::Either;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
// OpenAI API compatible structs
|
||||
|
||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageInnerContent(
|
||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageInnerContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
(
|
||||
"MessageInnerContent",
|
||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
// Either::Left - simple string
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
// Either::Right - object with string values
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Message content that can be either simple text or complex structured content
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContent(
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageContent Schema generation to handle `Either`
|
||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
.item(Schema::Array(
|
||||
ArrayBuilder::new()
|
||||
.items(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::Ref(
|
||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
||||
)))
|
||||
.build(),
|
||||
)))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Represents a single message in a conversation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Message {
|
||||
/// The message content
|
||||
pub content: Option<MessageContent>,
|
||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
||||
pub role: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Stop token configuration for generation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum StopTokens {
|
||||
/// Multiple possible stop sequences
|
||||
Multi(Vec<String>),
|
||||
/// Single stop sequence
|
||||
Single(String),
|
||||
}
|
||||
|
||||
/// Default value helper
///
/// Serde default for boolean request fields: always `false`.
fn default_false() -> bool { false }
|
||||
|
||||
/// Default value helper
///
/// Serde default for count-like request fields: always `1`.
fn default_1usize() -> usize { 1_usize }
|
||||
|
||||
/// Default value helper
///
/// Serde default for the request's `model` field: the literal `"default"`.
fn default_model() -> String {
    String::from("default")
}
|
||||
|
||||
/// Chat completion request following OpenAI's specification
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionRequest {
|
||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
||||
pub messages: Vec<Message>,
|
||||
#[schema(example = "gemma-3-1b-it")]
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_false")]
|
||||
#[schema(example = false)]
|
||||
pub logprobs: bool,
|
||||
#[schema(example = 256)]
|
||||
pub max_tokens: Option<usize>,
|
||||
#[serde(rename = "n")]
|
||||
#[serde(default = "default_1usize")]
|
||||
#[schema(example = 1)]
|
||||
pub n_choices: usize,
|
||||
#[schema(example = 0.7)]
|
||||
pub temperature: Option<f64>,
|
||||
#[schema(example = 0.9)]
|
||||
pub top_p: Option<f64>,
|
||||
#[schema(example = false)]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
/// Chat completion choice
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionChoice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
/// Chat completion response
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Token usage information
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
// Application state shared between handlers
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
text_generation: Arc<Mutex<TextGeneration>>,
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
// Chat completions endpoint handler
|
||||
async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Convert messages to a prompt string
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
// Format based on role
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
}
|
||||
|
||||
// Add the assistant prefix for the response
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Capture the output
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
|
||||
// Buffer to capture the output
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Run text generation
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
|
||||
"type": "unsupported_api"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// Convert buffer to string
|
||||
if let Ok(text) = String::from_utf8(buffer) {
|
||||
output.push(text);
|
||||
}
|
||||
}
|
||||
|
||||
// Create response
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
model: request.model,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent(Either::Left(output.join("")))),
|
||||
name: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: prompt.len() / 4, // Rough estimate
|
||||
completion_tokens: output.join("").len() / 4, // Rough estimate
|
||||
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
||||
},
|
||||
};
|
||||
|
||||
// Return the response as JSON
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
use candle_core::{DType, Device, MetalDevice, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{Repo, RepoType, api::sync::Api};
|
||||
use tokenizers::Tokenizer;
|
||||
use crate::token_output_stream::TokenOutputStream;
|
||||
use crate::utilities_lib::device;
|
||||
|
||||
// Create the router with the chat completions endpoint
|
||||
fn create_router(app_state: AppState) -> Router {
|
||||
// CORS layer to allow requests from any origin
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
// OpenAI compatible endpoints
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
// Add more endpoints as needed
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "7b")]
|
||||
Base7B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
#[value(name = "code-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "code-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "code-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
#[value(name = "3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
V3(Model3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
Self::V3(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run text generation and print to stdout
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
println!(
|
||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||
);
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run text generation and write to a buffer
|
||||
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Write prompt tokens to output
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we're using a Model3 (gemma-3) variant
|
||||
let is_model3 = match &self.model {
|
||||
Model::V3(_) => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
// For Model3, we need to use a different approach
|
||||
if is_model3 {
|
||||
// For gemma-3 models, we'll generate one token at a time with the full context
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
// Initial generation with the full prompt
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let mut logits = self.model.forward(&input, 0)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
for _ in 0..sample_len {
|
||||
// Apply repeat penalty if needed
|
||||
let current_logits = if self.repeat_penalty == 1. {
|
||||
logits.clone()
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(¤t_logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
|
||||
// For the next iteration, just use the new token
|
||||
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
||||
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Standard approach for other models
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Write any remaining tokens
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
write!(output, "{}", rest)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Run in server mode with OpenAI compatible API
|
||||
#[arg(long)]
|
||||
server: bool,
|
||||
|
||||
/// Port to use for the server
|
||||
#[arg(long, default_value_t = 3777)]
|
||||
port: u16,
|
||||
|
||||
/// Prompt for text generation (not used in server mode)
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "3-1b-it")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle_core::utils::with_avx(),
|
||||
candle_core::utils::with_neon(),
|
||||
candle_core::utils::with_simd128(),
|
||||
candle_core::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.which {
|
||||
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||
Which::Base2B => "google/gemma-2b".to_string(),
|
||||
Which::Base7B => "google/gemma-7b".to_string(),
|
||||
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||
Which::CodeBase2B => "google/codegemma-2b".to_string(),
|
||||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id.clone(),
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.which {
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
},
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = utilities_lib::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
// Use the original device and dtype
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
}
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
|
||||
if args.server {
|
||||
// Start the server
|
||||
println!("Starting server on port {}", args.port);
|
||||
|
||||
// Create app state
|
||||
let app_state = AppState {
|
||||
text_generation: Arc::new(Mutex::new(pipeline)),
|
||||
model_id,
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = create_router(app_state);
|
||||
|
||||
// Run the server
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
|
||||
|
||||
// Use tokio to run the server
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async {
|
||||
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
// Run in CLI mode
|
||||
if let Some(prompt_text) = &args.prompt {
|
||||
let prompt = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B => prompt_text.clone(),
|
||||
Which::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||
prompt_text
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut pipeline = pipeline;
|
||||
pipeline.run(&prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
|
||||
}
|
||||
}
|
||||
}
|
90
crates/local_inference_engine/src/model.rs
Normal file
90
crates/local_inference_engine/src/model.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use candle_core::Tensor;
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Which {
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "7b")]
|
||||
Base7B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
#[value(name = "code-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "code-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "code-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
#[value(name = "3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
pub enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
V3(Model3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
Self::V3(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
pub fn to_model_id(&self) -> String {
|
||||
match self {
|
||||
Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||
Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||
Self::Base2B => "google/gemma-2b".to_string(),
|
||||
Self::Base7B => "google/gemma-7b".to_string(),
|
||||
Self::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||
Self::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||
Self::CodeBase2B => "google/codegemma-2b".to_string(),
|
||||
Self::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_instruct_model(&self) -> bool {
|
||||
match self {
|
||||
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_v3_model(&self) -> bool {
|
||||
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
|
||||
}
|
||||
}
|
167
crates/local_inference_engine/src/openai_types.rs
Normal file
167
crates/local_inference_engine/src/openai_types.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use either::Either;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageInnerContent(
|
||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageInnerContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
(
|
||||
"MessageInnerContent",
|
||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
// Either::Left - simple string
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
// Either::Right - object with string values
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Message content that can be either simple text or complex structured content
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContent(
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageContent Schema generation to handle `Either`
|
||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
.item(Schema::Array(
|
||||
ArrayBuilder::new()
|
||||
.items(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::Ref(
|
||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
||||
)))
|
||||
.build(),
|
||||
)))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Represents a single message in a conversation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Message {
|
||||
/// The message content
|
||||
pub content: Option<MessageContent>,
|
||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
||||
pub role: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Stop token configuration for generation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum StopTokens {
|
||||
/// Multiple possible stop sequences
|
||||
Multi(Vec<String>),
|
||||
/// Single stop sequence
|
||||
Single(String),
|
||||
}
|
||||
|
||||
/// Default value helper
///
/// Serde default for boolean fields that are off unless the request sets them; always
/// returns `false`.
pub fn default_false() -> bool {
    bool::default()
}
|
||||
|
||||
/// Default value helper
///
/// Serde default for count fields that fall back to one; always returns `1`.
pub fn default_1usize() -> usize {
    1_usize
}
|
||||
|
||||
/// Default value helper
///
/// Serde default for the request's model name; always returns the string `"default"`.
pub fn default_model() -> String {
    String::from("default")
}
|
||||
|
||||
/// Chat completion request following OpenAI's specification
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionRequest {
|
||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
||||
pub messages: Vec<Message>,
|
||||
#[schema(example = "gemma-3-1b-it")]
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_false")]
|
||||
#[schema(example = false)]
|
||||
pub logprobs: bool,
|
||||
#[schema(example = 256)]
|
||||
pub max_tokens: Option<usize>,
|
||||
#[serde(rename = "n")]
|
||||
#[serde(default = "default_1usize")]
|
||||
#[schema(example = 1)]
|
||||
pub n_choices: usize,
|
||||
#[schema(example = 0.7)]
|
||||
pub temperature: Option<f64>,
|
||||
#[schema(example = 0.9)]
|
||||
pub top_p: Option<f64>,
|
||||
#[schema(example = false)]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
/// Chat completion choice
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionChoice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
/// Chat completion response
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Token usage information
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
126
crates/local_inference_engine/src/server.rs
Normal file
126
crates/local_inference_engine/src/server.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
|
||||
use crate::text_generation::TextGeneration;
|
||||
use either::Either;
|
||||
|
||||
// Application state shared between handlers
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub text_generation: Arc<Mutex<TextGeneration>>,
|
||||
pub model_id: String,
|
||||
}
|
||||
|
||||
// Chat completions endpoint handler
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Convert messages to a prompt string
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
// Format based on role
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
}
|
||||
|
||||
// Add the assistant prefix for the response
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Capture the output
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
|
||||
// Buffer to capture the output
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Run text generation
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
|
||||
"type": "unsupported_api"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// Convert buffer to string
|
||||
if let Ok(text) = String::from_utf8(buffer) {
|
||||
output.push(text);
|
||||
}
|
||||
}
|
||||
|
||||
// Create response
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
model: request.model,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent(Either::Left(output.join("")))),
|
||||
name: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: prompt.len() / 4, // Rough estimate
|
||||
completion_tokens: output.join("").len() / 4, // Rough estimate
|
||||
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
||||
},
|
||||
};
|
||||
|
||||
// Return the response as JSON
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
// Create the router with the chat completions endpoint
|
||||
pub fn create_router(app_state: AppState) -> Router {
|
||||
// CORS layer to allow requests from any origin
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
// OpenAI compatible endpoints
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
// Add more endpoints as needed
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
277
crates/local_inference_engine/src/text_generation.rs
Normal file
277
crates/local_inference_engine/src/text_generation.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use tokenizers::Tokenizer;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::model::Model;
|
||||
use crate::token_output_stream::TokenOutputStream;
|
||||
|
||||
pub struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run text generation and print to stdout
|
||||
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
println!(
|
||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||
);
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run text generation and write to a buffer
|
||||
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Write prompt tokens to output
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we're using a Model3 (gemma-3) variant
|
||||
let is_model3 = match &self.model {
|
||||
Model::V3(_) => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
// For Model3, we need to use a different approach
|
||||
if is_model3 {
|
||||
// For gemma-3 models, we'll generate one token at a time with the full context
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
// Initial generation with the full prompt
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let mut logits = self.model.forward(&input, 0)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
for _ in 0..sample_len {
|
||||
// Apply repeat penalty if needed
|
||||
let current_logits = if self.repeat_penalty == 1. {
|
||||
logits.clone()
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(¤t_logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
|
||||
// For the next iteration, just use the new token
|
||||
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
||||
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Standard approach for other models
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Write any remaining tokens
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
write!(output, "{}", rest)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
86
crates/local_inference_engine/src/token_output_stream.rs
Normal file
86
crates/local_inference_engine/src/token_output_stream.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use candle_core::Result;
|
||||
|
||||
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
|
||||
/// streaming way rather than having to wait for the full decoding.
|
||||
pub struct TokenOutputStream {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
tokens: Vec<u32>,
|
||||
prev_index: usize,
|
||||
current_index: usize,
|
||||
}
|
||||
|
||||
impl TokenOutputStream {
|
||||
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
tokens: Vec::new(),
|
||||
prev_index: 0,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> tokenizers::Tokenizer {
|
||||
self.tokenizer
|
||||
}
|
||||
|
||||
fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
match self.tokenizer.decode(tokens, true) {
|
||||
Ok(str) => Ok(str),
|
||||
Err(err) => candle_core::bail!("cannot decode: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
|
||||
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_rest(&self) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_all(&self) -> Result<String> {
|
||||
self.decode(&self.tokens)
|
||||
}
|
||||
|
||||
pub fn get_token(&self, token_s: &str) -> Option<u32> {
|
||||
self.tokenizer.get_vocab(true).get(token_s).copied()
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.tokens.clear();
|
||||
self.prev_index = 0;
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
|
167
crates/local_inference_engine/src/utilities_lib.rs
Normal file
167
crates/local_inference_engine/src/utilities_lib.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use candle_core::utils::{cuda_is_available, metal_is_available};
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
|
||||
pub fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
println!(
|
||||
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
|
||||
);
|
||||
}
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
{
|
||||
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||
}
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
resize_longest: Option<usize>,
|
||||
) -> Result<(Tensor, usize, usize)> {
|
||||
let img = image::ImageReader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle_core::Error::wrap)?;
|
||||
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
|
||||
let img = match resize_longest {
|
||||
None => img,
|
||||
Some(resize_longest) => {
|
||||
let (height, width) = (img.height(), img.width());
|
||||
let resize_longest = resize_longest as u32;
|
||||
let (height, width) = if height < width {
|
||||
let h = (resize_longest * height) / width;
|
||||
(h, resize_longest)
|
||||
} else {
|
||||
let w = (resize_longest * width) / height;
|
||||
(resize_longest, w)
|
||||
};
|
||||
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
|
||||
}
|
||||
};
|
||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
Ok((data, initial_h, initial_w))
|
||||
}
|
||||
|
||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
width: usize,
|
||||
height: usize,
|
||||
) -> Result<Tensor> {
|
||||
let img = image::ImageReader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle_core::Error::wrap)?
|
||||
.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
/// Saves an image to disk using the image crate, this expects an input with shape
|
||||
/// (c, height, width).
|
||||
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
let p = p.as_ref();
|
||||
let (channel, height, width) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
candle_core::bail!("save_image expects an input of shape (3, height, width)")
|
||||
}
|
||||
let img = img.permute((1, 2, 0))?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => candle_core::bail!("error saving image {p:?}"),
|
||||
};
|
||||
image.save(p).map_err(candle_core::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save_image_resize<P: AsRef<std::path::Path>>(
|
||||
img: &Tensor,
|
||||
p: P,
|
||||
h: usize,
|
||||
w: usize,
|
||||
) -> Result<()> {
|
||||
let p = p.as_ref();
|
||||
let (channel, height, width) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
candle_core::bail!("save_image expects an input of shape (3, height, width)")
|
||||
}
|
||||
let img = img.permute((1, 2, 0))?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => candle_core::bail!("error saving image {p:?}"),
|
||||
};
|
||||
let image = image::DynamicImage::from(image);
|
||||
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
|
||||
image.save(p).map_err(candle_core::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loads the safetensors files for a model from the hub based on a json index file.
|
||||
pub fn hub_load_safetensors(
|
||||
repo: &hf_hub::api::sync::ApiRepo,
|
||||
json_file: &str,
|
||||
) -> Result<Vec<std::path::PathBuf>> {
|
||||
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
|
||||
let json_file = std::fs::File::open(json_file)?;
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => candle_core::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file.to_string());
|
||||
}
|
||||
}
|
||||
let safetensors_files = safetensors_files
|
||||
.iter()
|
||||
.map(|v| repo.get(v).map_err(candle_core::Error::wrap))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
|
||||
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
json_file: &str,
|
||||
) -> Result<Vec<std::path::PathBuf>> {
|
||||
let path = path.as_ref();
|
||||
let jsfile = std::fs::File::open(path.join(json_file))?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => candle_core::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file);
|
||||
}
|
||||
}
|
||||
let safetensors_files: Vec<_> = safetensors_files
|
||||
.into_iter()
|
||||
.map(|v| path.join(v))
|
||||
.collect();
|
||||
Ok(safetensors_files)
|
||||
}
|
17
crates/local_inference_engine/test.sh
Normal file
17
crates/local_inference_engine/test.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env bash
# Smoke test: runs the crate's binary once with a single prompt.
#
# Usage: ./test.sh [PROMPT]
#   PROMPT  optional prompt text; defaults to the question below.

# Abort on the first failing command, on unset variables and on pipeline errors.
set -euo pipefail

PROMPT="${1:-Who was the 16th president}"

# will pull gemma-3-1b-it and run the prompt
cargo run -- --prompt "${PROMPT}"

# Sample output from a previous run with the default prompt:
#avx: false, neon: true, simd128: false, f16c: false
#temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
#retrieved the files in 1.388209ms
#loaded the model in 321.509333ms
# user
#Who was the 16th president
# model
#The 16th President of the United States was **Abraham Lincoln**. He served from March 4, 1861, to March 4, 1865.
#40 tokens generated (31.85 token/s)
|
67
crates/local_inference_engine/tests/model_tests.rs
Normal file
67
crates/local_inference_engine/tests/model_tests.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use local_inference_engine::model::{Model, Which};
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A few representative variants must map to their hub model ids.
    #[test]
    fn test_which_to_model_id() {
        let expected = [
            (Which::Base2B, "google/gemma-2b"),
            (Which::Instruct7B, "google/gemma-7b-it"),
            (Which::InstructV1_1_2B, "google/gemma-1.1-2b-it"),
            (Which::CodeBase2B, "google/codegemma-2b"),
            (Which::BaseV2_2B, "google/gemma-2-2b"),
            (Which::InstructV3_1B, "google/gemma-3-1b-it"),
        ];
        for (which, model_id) in expected {
            assert_eq!(which.to_model_id(), model_id);
        }
    }

    /// Base variants are not instruct models; instruct variants are.
    #[test]
    fn test_which_is_instruct_model() {
        let base_models = [
            Which::Base2B,
            Which::Base7B,
            Which::CodeBase2B,
            Which::CodeBase7B,
            Which::BaseV2_2B,
            Which::BaseV2_9B,
            Which::BaseV3_1B,
        ];
        for which in base_models {
            assert!(!which.is_instruct_model());
        }

        let instruct_models = [
            Which::Instruct2B,
            Which::Instruct7B,
            Which::InstructV1_1_2B,
            Which::InstructV1_1_7B,
            Which::CodeInstruct2B,
            Which::CodeInstruct7B,
            Which::InstructV2_2B,
            Which::InstructV2_9B,
            Which::InstructV3_1B,
        ];
        for which in instruct_models {
            assert!(which.is_instruct_model());
        }
    }

    /// Only the V3 variants report themselves as v3 models.
    #[test]
    fn test_which_is_v3_model() {
        let older_models = [
            Which::Base2B,
            Which::Base7B,
            Which::Instruct2B,
            Which::Instruct7B,
            Which::InstructV1_1_2B,
            Which::InstructV1_1_7B,
            Which::CodeBase2B,
            Which::CodeBase7B,
            Which::CodeInstruct2B,
            Which::CodeInstruct7B,
            Which::BaseV2_2B,
            Which::InstructV2_2B,
            Which::BaseV2_9B,
            Which::InstructV2_9B,
        ];
        for which in older_models {
            assert!(!which.is_v3_model());
        }

        for which in [Which::BaseV3_1B, Which::InstructV3_1B] {
            assert!(which.is_v3_model());
        }
    }

    // Note: Testing the Model enum's forward method would require creating actual model instances,
    // which is complex and would require loading model weights. This is better suited for
    // integration tests or mocking the models.
}
|
104
crates/local_inference_engine/tests/text_generation_tests.rs
Normal file
104
crates/local_inference_engine/tests/text_generation_tests.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use local_inference_engine::text_generation::TextGeneration;
|
||||
use local_inference_engine::model::{Model, Which};
|
||||
use local_inference_engine::token_output_stream::TokenOutputStream;
|
||||
use tokenizers::Tokenizer;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the `google/gemma-2b` tokenizer through `Tokenizer::from_pretrained`.
    ///
    /// A loading failure is returned as an error so the calling test fails with the
    /// underlying cause instead of a panic inside the helper.
    fn create_test_tokenizer() -> Result<Tokenizer> {
        Tokenizer::from_pretrained("google/gemma-2b", None).map_err(anyhow::Error::msg)
    }

    // Test the Which enum's to_model_id method
    #[test]
    fn test_which_model_id() {
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
    }

    // Test the Which enum's is_instruct_model method
    #[test]
    fn test_which_is_instruct() {
        assert!(!Which::Base2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
    }

    // Test the Which enum's is_v3_model method
    #[test]
    fn test_which_is_v3() {
        assert!(!Which::Base2B.is_v3_model());
        assert!(Which::BaseV3_1B.is_v3_model());
    }

    // Test the TokenOutputStream functionality
    #[test]
    fn test_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Encode a short text, then feed the ids through the stream one by one
        let text = "Hello, world!";
        let encoded = token_stream
            .tokenizer()
            .encode(text, true)
            .map_err(anyhow::Error::msg)?;
        for &token_id in encoded.get_ids() {
            token_stream.next_token(token_id)?;
        }

        // Decoding everything that was pushed must give the original text back
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), text);

        Ok(())
    }

    // Test the LogitsProcessor
    #[test]
    fn test_logits_processor() -> Result<()> {
        let seed = 42;
        // Index 2 carries by far the largest logit.
        let logits = Tensor::new(&[0.1f32, 0.2, 5.0, 0.3], &Device::Cpu)?;

        // Without a temperature the processor is expected to pick the arg-max.
        let mut greedy = LogitsProcessor::new(seed, None, None);
        assert_eq!(greedy.sample(&logits)?, 2);

        // With temperature and top-p the pick is random, but it must be a valid index.
        let mut sampler = LogitsProcessor::new(seed, Some(0.8), Some(0.9));
        let sampled = sampler.sample(&logits)? as usize;
        assert!(sampled < logits.elem_count());

        Ok(())
    }

    // Test the TextGeneration constructor
    #[test]
    fn test_text_generation_constructor() -> Result<()> {
        // Placeholder: building a `Model` needs real weights, so nothing is exercised yet.
        //
        // With a mock Model this test would:
        // 1. Create a mock model
        // 2. Create a tokenizer
        // 3. Call TextGeneration::new
        // 4. Verify the properties of the created instance
        Ok(())
    }

    // Note: Testing the actual text generation functionality would require
    // integration tests with real models, which is beyond the scope of these unit tests.
    // The tests above focus on the components that can be tested in isolation.
}
|
129
crates/local_inference_engine/tests/token_output_stream_tests.rs
Normal file
129
crates/local_inference_engine/tests/token_output_stream_tests.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use local_inference_engine::token_output_stream::TokenOutputStream;
|
||||
use tokenizers::Tokenizer;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::Result;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the `google/gemma-2b` tokenizer through `Tokenizer::from_pretrained`.
    ///
    /// A loading failure is returned as an error so the calling test fails with the
    /// underlying cause instead of a panic inside the helper.
    fn create_test_tokenizer() -> Result<Tokenizer> {
        Tokenizer::from_pretrained("google/gemma-2b", None).map_err(anyhow::Error::msg)
    }

    /// Encodes "Hello world" with the stream's own tokenizer and returns the token ids.
    fn hello_world_ids(token_stream: &TokenOutputStream) -> Result<Vec<u32>> {
        let encoding = token_stream
            .tokenizer()
            .encode("Hello world", true)
            .map_err(anyhow::Error::msg)?;
        Ok(encoding.get_ids().to_vec())
    }

    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // The wrapped tokenizer must expose a non-empty vocabulary
        assert!(token_stream.tokenizer().get_vocab_size(true) > 0);
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Add a token
        let token_id = token_stream
            .get_token("<eos>")
            .ok_or_else(|| anyhow::Error::msg("tokenizer has no <eos> token"))?;
        token_stream.next_token(token_id)?;

        // Clear the stream
        token_stream.clear();

        // Check that the stream is empty by trying to decode all
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded, "");

        Ok(())
    }

    #[test]
    fn test_get_token() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get a token that should exist
        assert!(token_stream.get_token("<eos>").is_some());

        // Get a token that shouldn't exist
        assert!(token_stream
            .get_token("<this_token_does_not_exist>")
            .is_none());

        Ok(())
    }

    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);
        let token_ids = hello_world_ids(&token_stream)?;

        // Add tokens one by one, collecting whatever text each call releases
        let mut output = String::new();
        for token_id in token_ids {
            if let Some(text) = token_stream.next_token(token_id)? {
                output.push_str(&text);
            }
        }

        // Get any remaining text
        if let Some(rest) = token_stream.decode_rest()? {
            output.push_str(&rest);
        }

        // Check the output
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_decode_all() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);
        let token_ids = hello_world_ids(&token_stream)?;

        // Add tokens one by one
        for token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all and check the output
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_into_inner() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get the inner tokenizer
        let inner_tokenizer = token_stream.into_inner();

        // Check that the inner tokenizer works
        let encoded = inner_tokenizer
            .encode("Test", true)
            .map_err(anyhow::Error::msg)?;
        assert!(!encoded.get_ids().is_empty());

        Ok(())
    }
}
|
Reference in New Issue
Block a user