add test ui for the local inference server

This commit is contained in:
geoffsee
2025-06-05 21:09:49 -04:00
committed by Geoff Seemueller
parent 1704d5cd47
commit 1df24a7d3b
5 changed files with 546 additions and 45 deletions

View File

@@ -0,0 +1,295 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OpenAI-Compatible API Tester</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
line-height: 1.6;
}
h1, h2 {
color: #333;
}
.container {
margin-bottom: 20px;
}
textarea {
width: 100%;
height: 150px;
padding: 10px;
margin-bottom: 10px;
border: 1px solid #ddd;
border-radius: 4px;
font-family: monospace;
}
button {
background-color: #4CAF50;
color: white;
padding: 10px 15px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
}
button:hover {
background-color: #45a049;
}
pre {
background-color: #f5f5f5;
padding: 15px;
border-radius: 4px;
overflow-x: auto;
white-space: pre-wrap;
}
.response {
margin-top: 20px;
}
.error {
color: red;
}
.settings {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 15px;
}
.settings div {
display: flex;
flex-direction: column;
}
label {
margin-bottom: 5px;
font-weight: bold;
}
input {
padding: 8px;
border: 1px solid #ddd;
border-radius: 4px;
}
.examples {
margin-top: 30px;
}
.example-btn {
background-color: #2196F3;
margin-right: 10px;
margin-bottom: 10px;
}
.example-btn:hover {
background-color: #0b7dda;
}
</style>
</head>
<body>
<h1>OpenAI-Compatible API Tester</h1>
<p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>
<div class="container">
<h2>Request Settings</h2>
<div class="settings">
<div>
<label for="serverUrl">Server URL:</label>
<input type="text" id="serverUrl" value="http://localhost:3000" />
</div>
<div>
<label for="model">Model:</label>
<input type="text" id="model" value="gemma-3-1b-it" />
</div>
<div>
<label for="maxTokens">Max Tokens:</label>
<input type="number" id="maxTokens" value="150" />
</div>
<div>
<label for="temperature">Temperature:</label>
<input type="number" id="temperature" value="0.7" step="0.1" min="0" max="2" />
</div>
<div>
<label for="topP">Top P:</label>
<input type="number" id="topP" value="0.9" step="0.1" min="0" max="1" />
</div>
</div>
<h2>Request Body</h2>
<textarea id="requestBody">{
"model": "gemma-3-1b-it",
"messages": [
{
"role": "user",
"content": "Hello, how are you today?"
}
],
"max_tokens": 150,
"temperature": 0.7,
"top_p": 0.9
}</textarea>
<button id="sendRequest">Send Request</button>
<div class="examples">
<h3>Example Requests</h3>
<button class="example-btn" id="example1">Basic Question</button>
<button class="example-btn" id="example2">Multi-turn Conversation</button>
<button class="example-btn" id="example3">Creative Writing</button>
<button class="example-btn" id="example4">Code Generation</button>
</div>
<div class="response">
<h2>Response</h2>
<pre id="responseOutput">Response will appear here...</pre>
</div>
</div>
<script>
document.addEventListener('DOMContentLoaded', function() {
// Update request body when settings change
const serverUrlInput = document.getElementById('serverUrl');
const modelInput = document.getElementById('model');
const maxTokensInput = document.getElementById('maxTokens');
const temperatureInput = document.getElementById('temperature');
const topPInput = document.getElementById('topP');
const requestBodyTextarea = document.getElementById('requestBody');
const responseOutput = document.getElementById('responseOutput');
// Function to update request body from settings
function updateRequestBodyFromSettings() {
try {
const requestBody = JSON.parse(requestBodyTextarea.value);
requestBody.model = modelInput.value;
requestBody.max_tokens = parseInt(maxTokensInput.value);
requestBody.temperature = parseFloat(temperatureInput.value);
requestBody.top_p = parseFloat(topPInput.value);
requestBodyTextarea.value = JSON.stringify(requestBody, null, 2);
} catch (error) {
console.error("Error updating request body:", error);
}
}
// Update settings when request body changes
function updateSettingsFromRequestBody() {
try {
const requestBody = JSON.parse(requestBodyTextarea.value);
if (requestBody.model) modelInput.value = requestBody.model;
if (requestBody.max_tokens) maxTokensInput.value = requestBody.max_tokens;
if (requestBody.temperature) temperatureInput.value = requestBody.temperature;
if (requestBody.top_p) topPInput.value = requestBody.top_p;
} catch (error) {
console.error("Error updating settings:", error);
}
}
// Add event listeners for settings changes
modelInput.addEventListener('change', updateRequestBodyFromSettings);
maxTokensInput.addEventListener('change', updateRequestBodyFromSettings);
temperatureInput.addEventListener('change', updateRequestBodyFromSettings);
topPInput.addEventListener('change', updateRequestBodyFromSettings);
// Add event listener for request body changes
requestBodyTextarea.addEventListener('blur', updateSettingsFromRequestBody);
// Send request button
document.getElementById('sendRequest').addEventListener('click', async function() {
try {
responseOutput.textContent = "Sending request...";
const serverUrl = serverUrlInput.value;
const endpoint = '/v1/chat/completions';
const url = serverUrl + endpoint;
const requestBody = JSON.parse(requestBodyTextarea.value);
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody)
});
const data = await response.json();
responseOutput.textContent = JSON.stringify(data, null, 2);
} catch (error) {
responseOutput.textContent = "Error: " + error.message;
responseOutput.classList.add('error');
}
});
// Example requests
document.getElementById('example1').addEventListener('click', function() {
requestBodyTextarea.value = JSON.stringify({
model: modelInput.value,
messages: [
{
role: "user",
content: "Who was the 16th president of the United States?"
}
],
max_tokens: parseInt(maxTokensInput.value),
temperature: parseFloat(temperatureInput.value),
top_p: parseFloat(topPInput.value)
}, null, 2);
});
document.getElementById('example2').addEventListener('click', function() {
requestBodyTextarea.value = JSON.stringify({
model: modelInput.value,
messages: [
{
role: "system",
content: "You are a helpful assistant that provides concise answers."
},
{
role: "user",
content: "What is machine learning?"
},
{
role: "assistant",
content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
},
{
role: "user",
content: "Give me an example of a machine learning algorithm."
}
],
max_tokens: parseInt(maxTokensInput.value),
temperature: parseFloat(temperatureInput.value),
top_p: parseFloat(topPInput.value)
}, null, 2);
});
document.getElementById('example3').addEventListener('click', function() {
requestBodyTextarea.value = JSON.stringify({
model: modelInput.value,
messages: [
{
role: "user",
content: "Write a short poem about artificial intelligence."
}
],
max_tokens: parseInt(maxTokensInput.value),
temperature: 0.9, // Higher temperature for creative tasks
top_p: 0.9
}, null, 2);
temperatureInput.value = 0.9;
});
document.getElementById('example4').addEventListener('click', function() {
requestBodyTextarea.value = JSON.stringify({
model: modelInput.value,
messages: [
{
role: "user",
content: "Write a Python function to calculate the Fibonacci sequence up to n terms."
}
],
max_tokens: parseInt(maxTokensInput.value),
temperature: 0.3, // Lower temperature for code generation
top_p: 0.9
}, null, 2);
temperatureInput.value = 0.3;
});
});
</script>
</body>
</html>

View File

@@ -0,0 +1,176 @@
// Test requests for the OpenAI-compatible endpoint in the inference server
// This file contains IIFE (Immediately Invoked Function Expression) JavaScript requests
// to test the /v1/chat/completions endpoint
// Basic chat completion request
(async function testBasicChatCompletion() {
console.log("Test 1: Basic chat completion request");
try {
const response = await fetch('http://localhost:3000/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: "gemma-2-2b-it",
messages: [
{
role: "user",
content: "Who was the 16th president of the United States?"
}
],
max_tokens: 100
})
});
const data = await response.json();
console.log("Response:", JSON.stringify(data, null, 2));
} catch (error) {
console.error("Error:", error);
}
})();
// Multi-turn conversation
(async function testMultiTurnConversation() {
console.log("\nTest 2: Multi-turn conversation");
try {
const response = await fetch('http://localhost:3000/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: "gemma-2-2b-it",
messages: [
{
role: "system",
content: "You are a helpful assistant that provides concise answers."
},
{
role: "user",
content: "What is machine learning?"
},
{
role: "assistant",
content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
},
{
role: "user",
content: "Give me an example of a machine learning algorithm."
}
],
max_tokens: 150
})
});
const data = await response.json();
console.log("Response:", JSON.stringify(data, null, 2));
} catch (error) {
console.error("Error:", error);
}
})();
// Request with temperature and top_p parameters
(async function testTemperatureAndTopP() {
console.log("\nTest 3: Request with temperature and top_p parameters");
try {
const response = await fetch('http://localhost:3000/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: "gemma-2-2b-it",
messages: [
{
role: "user",
content: "Write a short poem about artificial intelligence."
}
],
max_tokens: 200,
temperature: 0.8,
top_p: 0.9
})
});
const data = await response.json();
console.log("Response:", JSON.stringify(data, null, 2));
} catch (error) {
console.error("Error:", error);
}
})();
// Request with streaming enabled
(async function testStreaming() {
console.log("\nTest 4: Request with streaming enabled");
try {
const response = await fetch('http://localhost:3000/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: "gemma-2-2b-it",
messages: [
{
role: "user",
content: "Explain quantum computing in simple terms."
}
],
max_tokens: 150,
stream: true
})
});
// Note: Streaming might not be implemented yet, this is to test the API's handling of the parameter
if (response.headers.get('content-type')?.includes('text/event-stream')) {
console.log("Streaming response detected. Reading stream...");
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
console.log("Chunk:", chunk);
}
} else {
const data = await response.json();
console.log("Non-streaming response:", JSON.stringify(data, null, 2));
}
} catch (error) {
console.error("Error:", error);
}
})();
// Request with a different model
(async function testDifferentModel() {
console.log("\nTest 5: Request with a different model");
try {
const response = await fetch('http://localhost:3000/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: "gemma-2-2b-it", // Using a different model if available
messages: [
{
role: "user",
content: "What are the benefits of renewable energy?"
}
],
max_tokens: 150
})
});
const data = await response.json();
console.log("Response:", JSON.stringify(data, null, 2));
} catch (error) {
console.error("Error:", error);
}
})();
console.log("\nAll test requests have been sent. Check the server logs for more details.");
console.log("To run the server, use: cargo run --bin local_inference_engine -- --server");

View File

@@ -242,16 +242,16 @@ async fn chat_completions(
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": {
"message": format!("Error generating text: {}", e),
"type": "internal_server_error"
}
})),
));
}
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
@@ -520,6 +520,70 @@ impl TextGeneration {
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };

View File

@@ -1,34 +0,0 @@
use axum::{
body::Body,
http::{StatusCode, header::CONTENT_TYPE},
response::{IntoResponse, Response},
};
use rust_embed::RustEmbed;
use tracing::{debug, error};
#[derive(RustEmbed)]
#[folder = "assets/"]
struct Asset;
pub async fn serve_ui() -> impl IntoResponse {
debug!("Serving UI request");
// Attempt to retrieve the embedded "index.html"
match Asset::get("index.html") {
Some(content) => {
debug!("Successfully retrieved index.html");
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(content.data))
.unwrap()
}
None => {
error!("index.html not found in embedded assets");
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("404 Not Found"))
.unwrap()
}
}
}

View File

@@ -1,5 +1,5 @@
use axum::response::Response;
use crate::handlers::{not_found::handle_not_found, ui::serve_ui};
use crate::handlers::{not_found::handle_not_found};
use axum::routing::{get, Router};
use http::StatusCode;
use tower_http::trace::{self, TraceLayer};