From 1df24a7d3b64322006409ea14200ed11d0e0332e Mon Sep 17 00:00:00 2001
From: geoffsee <>
Date: Thu, 5 Jun 2025 21:09:49 -0400
Subject: [PATCH] add test ui for the local inference server
---
local_inference_engine/api_test.html | 295 ++++++++++++++++++++++
local_inference_engine/openai_api_test.js | 176 +++++++++++++
local_inference_engine/src/main.rs | 84 +++++-
src/handlers/ui.rs | 34 ---
src/routes.rs | 2 +-
5 files changed, 546 insertions(+), 45 deletions(-)
create mode 100644 local_inference_engine/api_test.html
create mode 100644 local_inference_engine/openai_api_test.js
diff --git a/local_inference_engine/api_test.html b/local_inference_engine/api_test.html
new file mode 100644
index 0000000..2f50e95
--- /dev/null
+++ b/local_inference_engine/api_test.html
@@ -0,0 +1,295 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>OpenAI-Compatible API Tester</title>
+</head>
+<body>
+    <h1>OpenAI-Compatible API Tester</h1>
+    <p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>
+
+    <h2>Request Settings</h2>
+
+    <h2>Request Body</h2>
+
+    <h2>Example Requests</h2>
+
+    <h2>Response</h2>
+    <pre id="response">Response will appear here...</pre>
+</body>
+</html>
\ No newline at end of file
diff --git a/local_inference_engine/openai_api_test.js b/local_inference_engine/openai_api_test.js
new file mode 100644
index 0000000..465e312
--- /dev/null
+++ b/local_inference_engine/openai_api_test.js
@@ -0,0 +1,176 @@
+// Test requests for the OpenAI-compatible endpoint of the local inference server.
+// Each test below is an IIFE (Immediately Invoked Function Expression) that sends a
+// request to the /v1/chat/completions endpoint and logs the JSON response.
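+//
+// All of these tests assume the server is already running and listening on
+// http://localhost:3000 (see the note at the bottom of this file for how to start it).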
+
+// Basic chat completion request
+(async function testBasicChatCompletion() {
+ console.log("Test 1: Basic chat completion request");
+ try {
+ const response = await fetch('http://localhost:3000/v1/chat/completions', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: "gemma-2-2b-it",
+ messages: [
+ {
+ role: "user",
+ content: "Who was the 16th president of the United States?"
+ }
+ ],
+ max_tokens: 100
+ })
+ });
+
+ const data = await response.json();
+ console.log("Response:", JSON.stringify(data, null, 2));
+ } catch (error) {
+ console.error("Error:", error);
+ }
+})();
+
+// Multi-turn conversation
+(async function testMultiTurnConversation() {
+ console.log("\nTest 2: Multi-turn conversation");
+ try {
+ const response = await fetch('http://localhost:3000/v1/chat/completions', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: "gemma-2-2b-it",
+ messages: [
+ {
+ role: "system",
+ content: "You are a helpful assistant that provides concise answers."
+ },
+ {
+ role: "user",
+ content: "What is machine learning?"
+ },
+ {
+ role: "assistant",
+ content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
+ },
+ {
+ role: "user",
+ content: "Give me an example of a machine learning algorithm."
+ }
+ ],
+ max_tokens: 150
+ })
+ });
+
+ const data = await response.json();
+ console.log("Response:", JSON.stringify(data, null, 2));
+ } catch (error) {
+ console.error("Error:", error);
+ }
+})();
+
+// Request with temperature and top_p parameters
+(async function testTemperatureAndTopP() {
+ console.log("\nTest 3: Request with temperature and top_p parameters");
+ try {
+ const response = await fetch('http://localhost:3000/v1/chat/completions', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: "gemma-2-2b-it",
+ messages: [
+ {
+ role: "user",
+ content: "Write a short poem about artificial intelligence."
+ }
+ ],
+ max_tokens: 200,
+ temperature: 0.8,
+ top_p: 0.9
+ })
+ });
+
+ const data = await response.json();
+ console.log("Response:", JSON.stringify(data, null, 2));
+ } catch (error) {
+ console.error("Error:", error);
+ }
+})();
+
+// Request with streaming enabled
+(async function testStreaming() {
+ console.log("\nTest 4: Request with streaming enabled");
+ try {
+ const response = await fetch('http://localhost:3000/v1/chat/completions', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: "gemma-2-2b-it",
+ messages: [
+ {
+ role: "user",
+ content: "Explain quantum computing in simple terms."
+ }
+ ],
+ max_tokens: 150,
+ stream: true
+ })
+ });
+
+ // Note: streaming may not be implemented server-side yet; this test checks how the API handles the stream parameter
+ if (response.headers.get('content-type')?.includes('text/event-stream')) {
+ console.log("Streaming response detected. Reading stream...");
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+
+ while (true) {
+ const { done, value } = await reader.read();
+ if (done) break;
+
+ const chunk = decoder.decode(value);
+ console.log("Chunk:", chunk);
+ }
+ } else {
+ const data = await response.json();
+ console.log("Non-streaming response:", JSON.stringify(data, null, 2));
+ }
+ } catch (error) {
+ console.error("Error:", error);
+ }
+})();
+
+// Request with a different model
+(async function testDifferentModel() {
+ console.log("\nTest 5: Request with a different model");
+ try {
+ const response = await fetch('http://localhost:3000/v1/chat/completions', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model: "gemma-2-2b-it", // Replace with a different model id here if one is available
+ messages: [
+ {
+ role: "user",
+ content: "What are the benefits of renewable energy?"
+ }
+ ],
+ max_tokens: 150
+ })
+ });
+
+ const data = await response.json();
+ console.log("Response:", JSON.stringify(data, null, 2));
+ } catch (error) {
+ console.error("Error:", error);
+ }
+})();
+
+console.log("\nAll test requests have been dispatched; responses will be logged as they arrive. Check the server logs for more details.");
+console.log("To run the server, use: cargo run --bin local_inference_engine -- --server");
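+
+// One way to run this file (assuming Node.js 18+, which provides a global fetch implementation):
+//   node openai_api_test.js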
diff --git a/local_inference_engine/src/main.rs b/local_inference_engine/src/main.rs
index c6ba3e5..f29ff6d 100644
--- a/local_inference_engine/src/main.rs
+++ b/local_inference_engine/src/main.rs
@@ -242,16 +242,16 @@ async fn chat_completions(
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
- return Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": {
- "message": format!("Error generating text: {}", e),
- "type": "internal_server_error"
- }
- })),
- ));
- }
+ return Err((
+ StatusCode::BAD_REQUEST,
+ Json(serde_json::json!({
+ "error": {
+ "message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
+ "type": "unsupported_api"
+ }
+ })),
+ ));
+}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
@@ -520,6 +520,70 @@ impl TextGeneration {
}
};
+ // Determine if we're using a Model3 (gemma-3) variant
+ let is_model3 = match &self.model {
+ Model::V3(_) => true,
+ _ => false,
+ };
+
+ // For Model3, we need to use a different approach
+ if is_model3 {
+ // For gemma-3 models, run the full prompt through the model once, then generate one token at a time
+ let start_gen = std::time::Instant::now();
+
+ // Initial generation with the full prompt
+ let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+ let mut logits = self.model.forward(&input, 0)?;
+ logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+
+ for _ in 0..sample_len {
+ // Apply repeat penalty if needed
+ let current_logits = if self.repeat_penalty == 1. {
+ logits.clone()
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+ // Manual implementation of repeat penalty to avoid type conflicts
+ let mut logits_vec = logits.to_vec1::<f32>()?;
+
+ for &token_id in &tokens[start_at..] {
+ let token_id = token_id as usize;
+ if token_id < logits_vec.len() {
+ let score = logits_vec[token_id];
+ // Penalize repeats: damp positive logits and push negative logits further down
+ logits_vec[token_id] = if score < 0.0 { score * self.repeat_penalty } else { score / self.repeat_penalty };
+ }
+ }
+
+ // Create a new tensor with the modified logits
+ let device = logits.device().clone();
+ let shape = logits.shape().clone();
+ let new_logits = Tensor::new(&logits_vec[..], &device)?;
+ new_logits.reshape(shape)?
+ };
+
+ let next_token = self.logits_processor.sample(&current_logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+
+ if next_token == eos_token || next_token == eot_token {
+ break;
+ }
+
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ write!(output, "{}", t)?;
+ }
+
+ // For the next iteration, just use the new token
+ let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
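+ // The second argument (tokens.len() - 1) is the index of the token just appended, i.e. the number of tokens already processed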
+ logits = self.model.forward(&new_input, tokens.len() - 1)?;
+ logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+ }
+
+ return Ok(());
+ }
+
+ // Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
diff --git a/src/handlers/ui.rs b/src/handlers/ui.rs
index bcf5967..e69de29 100644
--- a/src/handlers/ui.rs
+++ b/src/handlers/ui.rs
@@ -1,34 +0,0 @@
-use axum::{
- body::Body,
- http::{StatusCode, header::CONTENT_TYPE},
- response::{IntoResponse, Response},
-};
-use rust_embed::RustEmbed;
-use tracing::{debug, error};
-
-#[derive(RustEmbed)]
-#[folder = "assets/"]
-struct Asset;
-
-pub async fn serve_ui() -> impl IntoResponse {
- debug!("Serving UI request");
-
- // Attempt to retrieve the embedded "index.html"
- match Asset::get("index.html") {
- Some(content) => {
- debug!("Successfully retrieved index.html");
- Response::builder()
- .status(StatusCode::OK)
- .header(CONTENT_TYPE, "text/html")
- .body(Body::from(content.data))
- .unwrap()
- }
- None => {
- error!("index.html not found in embedded assets");
- Response::builder()
- .status(StatusCode::NOT_FOUND)
- .body(Body::from("404 Not Found"))
- .unwrap()
- }
- }
-}
\ No newline at end of file
diff --git a/src/routes.rs b/src/routes.rs
index 67b7cd4..a68dab4 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,5 +1,5 @@
use axum::response::Response;
-use crate::handlers::{not_found::handle_not_found, ui::serve_ui};
+use crate::handlers::not_found::handle_not_found;
use axum::routing::{get, Router};
use http::StatusCode;
use tower_http::trace::{self, TraceLayer};