From 1df24a7d3b64322006409ea14200ed11d0e0332e Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Thu, 5 Jun 2025 21:09:49 -0400 Subject: [PATCH] add test ui for the local inference server --- local_inference_engine/api_test.html | 295 ++++++++++++++++++++++ local_inference_engine/openai_api_test.js | 176 +++++++++++++ local_inference_engine/src/main.rs | 84 +++++- src/handlers/ui.rs | 34 --- src/routes.rs | 2 +- 5 files changed, 546 insertions(+), 45 deletions(-) create mode 100644 local_inference_engine/api_test.html create mode 100644 local_inference_engine/openai_api_test.js diff --git a/local_inference_engine/api_test.html b/local_inference_engine/api_test.html new file mode 100644 index 0000000..2f50e95 --- /dev/null +++ b/local_inference_engine/api_test.html @@ -0,0 +1,295 @@ + + + + + + OpenAI-Compatible API Tester + + + +

OpenAI-Compatible API Tester

+

Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.

+ +
+

Request Settings

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ +

Request Body

+ + + +
+

Example Requests

+ + + + +
+ +
+

Response

+
Response will appear here...
+
+
+ + + + \ No newline at end of file diff --git a/local_inference_engine/openai_api_test.js b/local_inference_engine/openai_api_test.js new file mode 100644 index 0000000..465e312 --- /dev/null +++ b/local_inference_engine/openai_api_test.js @@ -0,0 +1,176 @@ +// Test requests for the OpenAI-compatible endpoint in the inference server +// This file contains IIFE (Immediately Invoked Function Expression) JavaScript requests +// to test the /v1/chat/completions endpoint + +// Basic chat completion request +(async function testBasicChatCompletion() { + console.log("Test 1: Basic chat completion request"); + try { + const response = await fetch('http://localhost:3000/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: "gemma-2-2b-it", + messages: [ + { + role: "user", + content: "Who was the 16th president of the United States?" + } + ], + max_tokens: 100 + }) + }); + + const data = await response.json(); + console.log("Response:", JSON.stringify(data, null, 2)); + } catch (error) { + console.error("Error:", error); + } +})(); + +// Multi-turn conversation +(async function testMultiTurnConversation() { + console.log("\nTest 2: Multi-turn conversation"); + try { + const response = await fetch('http://localhost:3000/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: "gemma-2-2b-it", + messages: [ + { + role: "system", + content: "You are a helpful assistant that provides concise answers." + }, + { + role: "user", + content: "What is machine learning?" + }, + { + role: "assistant", + content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed." + }, + { + role: "user", + content: "Give me an example of a machine learning algorithm." + } + ], + max_tokens: 150 + }) + }); + + const data = await response.json(); + console.log("Response:", JSON.stringify(data, null, 2)); + } catch (error) { + console.error("Error:", error); + } +})(); + +// Request with temperature and top_p parameters +(async function testTemperatureAndTopP() { + console.log("\nTest 3: Request with temperature and top_p parameters"); + try { + const response = await fetch('http://localhost:3000/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: "gemma-2-2b-it", + messages: [ + { + role: "user", + content: "Write a short poem about artificial intelligence." + } + ], + max_tokens: 200, + temperature: 0.8, + top_p: 0.9 + }) + }); + + const data = await response.json(); + console.log("Response:", JSON.stringify(data, null, 2)); + } catch (error) { + console.error("Error:", error); + } +})(); + +// Request with streaming enabled +(async function testStreaming() { + console.log("\nTest 4: Request with streaming enabled"); + try { + const response = await fetch('http://localhost:3000/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: "gemma-2-2b-it", + messages: [ + { + role: "user", + content: "Explain quantum computing in simple terms." + } + ], + max_tokens: 150, + stream: true + }) + }); + + // Note: Streaming might not be implemented yet, this is to test the API's handling of the parameter + if (response.headers.get('content-type')?.includes('text/event-stream')) { + console.log("Streaming response detected. Reading stream..."); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + console.log("Chunk:", chunk); + } + } else { + const data = await response.json(); + console.log("Non-streaming response:", JSON.stringify(data, null, 2)); + } + } catch (error) { + console.error("Error:", error); + } +})(); + +// Request with a different model +(async function testDifferentModel() { + console.log("\nTest 5: Request with a different model"); + try { + const response = await fetch('http://localhost:3000/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: "gemma-2-2b-it", // Using a different model if available + messages: [ + { + role: "user", + content: "What are the benefits of renewable energy?" + } + ], + max_tokens: 150 + }) + }); + + const data = await response.json(); + console.log("Response:", JSON.stringify(data, null, 2)); + } catch (error) { + console.error("Error:", error); + } +})(); + +console.log("\nAll test requests have been sent. Check the server logs for more details."); +console.log("To run the server, use: cargo run --bin local_inference_engine -- --server"); diff --git a/local_inference_engine/src/main.rs b/local_inference_engine/src/main.rs index c6ba3e5..f29ff6d 100644 --- a/local_inference_engine/src/main.rs +++ b/local_inference_engine/src/main.rs @@ -242,16 +242,16 @@ async fn chat_completions( let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer); if let Err(e) = result { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { - "message": format!("Error generating text: {}", e), - "type": "internal_server_error" - } - })), - )); - } + return Err(( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": { + "message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"", + "type": "unsupported_api" + } + })), + )); +} // Convert buffer to string if let Ok(text) = String::from_utf8(buffer) { @@ -520,6 +520,70 @@ impl TextGeneration { } }; + // Determine if we're using a Model3 (gemma-3) variant + let is_model3 = match &self.model { + Model::V3(_) => true, + _ => false, + }; + + // For Model3, we need to use a different approach + if is_model3 { + // For gemma-3 models, we'll generate one token at a time with the full context + let start_gen = std::time::Instant::now(); + + // Initial generation with the full prompt + let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let mut logits = self.model.forward(&input, 0)?; + logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + + for _ in 0..sample_len { + // Apply repeat penalty if needed + let current_logits = if self.repeat_penalty == 1. { + logits.clone() + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + + // Manual implementation of repeat penalty to avoid type conflicts + let mut logits_vec = logits.to_vec1::()?; + + for &token_id in &tokens[start_at..] { + let token_id = token_id as usize; + if token_id < logits_vec.len() { + let score = logits_vec[token_id]; + let sign = if score < 0.0 { -1.0 } else { 1.0 }; + logits_vec[token_id] = sign * score / self.repeat_penalty; + } + } + + // Create a new tensor with the modified logits + let device = logits.device().clone(); + let shape = logits.shape().clone(); + let new_logits = Tensor::new(&logits_vec[..], &device)?; + new_logits.reshape(shape)? + }; + + let next_token = self.logits_processor.sample(¤t_logits)?; + tokens.push(next_token); + generated_tokens += 1; + + if next_token == eos_token || next_token == eot_token { + break; + } + + if let Some(t) = self.tokenizer.next_token(next_token)? { + write!(output, "{}", t)?; + } + + // For the next iteration, just use the new token + let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; + logits = self.model.forward(&new_input, tokens.len() - 1)?; + logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + } + + return Ok(()); + } + + // Standard approach for other models let start_gen = std::time::Instant::now(); for index in 0..sample_len { let context_size = if index > 0 { 1 } else { tokens.len() }; diff --git a/src/handlers/ui.rs b/src/handlers/ui.rs index bcf5967..e69de29 100644 --- a/src/handlers/ui.rs +++ b/src/handlers/ui.rs @@ -1,34 +0,0 @@ -use axum::{ - body::Body, - http::{StatusCode, header::CONTENT_TYPE}, - response::{IntoResponse, Response}, -}; -use rust_embed::RustEmbed; -use tracing::{debug, error}; - -#[derive(RustEmbed)] -#[folder = "assets/"] -struct Asset; - -pub async fn serve_ui() -> impl IntoResponse { - debug!("Serving UI request"); - - // Attempt to retrieve the embedded "index.html" - match Asset::get("index.html") { - Some(content) => { - debug!("Successfully retrieved index.html"); - Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, "text/html") - .body(Body::from(content.data)) - .unwrap() - } - None => { - error!("index.html not found in embedded assets"); - Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from("404 Not Found")) - .unwrap() - } - } -} \ No newline at end of file diff --git a/src/routes.rs b/src/routes.rs index 67b7cd4..a68dab4 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,5 +1,5 @@ use axum::response::Response; -use crate::handlers::{not_found::handle_not_found, ui::serve_ui}; +use crate::handlers::{not_found::handle_not_found}; use axum::routing::{get, Router}; use http::StatusCode; use tower_http::trace::{self, TraceLayer};