add test ui for the local inference server
This commit is contained in:

committed by
Geoff Seemueller

parent
1704d5cd47
commit
1df24a7d3b
295
local_inference_engine/api_test.html
Normal file
295
local_inference_engine/api_test.html
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>OpenAI-Compatible API Tester</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
max-width: 800px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 20px;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
h1, h2 {
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
textarea {
|
||||||
|
width: 100%;
|
||||||
|
height: 150px;
|
||||||
|
padding: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
background-color: #4CAF50;
|
||||||
|
color: white;
|
||||||
|
padding: 10px 15px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 16px;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
background-color: #45a049;
|
||||||
|
}
|
||||||
|
pre {
|
||||||
|
background-color: #f5f5f5;
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow-x: auto;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
.response {
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
.error {
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
.settings {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 10px;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
.settings div {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
margin-bottom: 5px;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
padding: 8px;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.examples {
|
||||||
|
margin-top: 30px;
|
||||||
|
}
|
||||||
|
.example-btn {
|
||||||
|
background-color: #2196F3;
|
||||||
|
margin-right: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
.example-btn:hover {
|
||||||
|
background-color: #0b7dda;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>OpenAI-Compatible API Tester</h1>
|
||||||
|
<p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>
|
||||||
|
|
||||||
|
<div class="container">
|
||||||
|
<h2>Request Settings</h2>
|
||||||
|
<div class="settings">
|
||||||
|
<div>
|
||||||
|
<label for="serverUrl">Server URL:</label>
|
||||||
|
<input type="text" id="serverUrl" value="http://localhost:3000" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="model">Model:</label>
|
||||||
|
<input type="text" id="model" value="gemma-3-1b-it" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="maxTokens">Max Tokens:</label>
|
||||||
|
<input type="number" id="maxTokens" value="150" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="temperature">Temperature:</label>
|
||||||
|
<input type="number" id="temperature" value="0.7" step="0.1" min="0" max="2" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="topP">Top P:</label>
|
||||||
|
<input type="number" id="topP" value="0.9" step="0.1" min="0" max="1" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>Request Body</h2>
|
||||||
|
<textarea id="requestBody">{
|
||||||
|
"model": "gemma-3-1b-it",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, how are you today?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 150,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9
|
||||||
|
}</textarea>
|
||||||
|
<button id="sendRequest">Send Request</button>
|
||||||
|
|
||||||
|
<div class="examples">
|
||||||
|
<h3>Example Requests</h3>
|
||||||
|
<button class="example-btn" id="example1">Basic Question</button>
|
||||||
|
<button class="example-btn" id="example2">Multi-turn Conversation</button>
|
||||||
|
<button class="example-btn" id="example3">Creative Writing</button>
|
||||||
|
<button class="example-btn" id="example4">Code Generation</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="response">
|
||||||
|
<h2>Response</h2>
|
||||||
|
<pre id="responseOutput">Response will appear here...</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
// Update request body when settings change
|
||||||
|
const serverUrlInput = document.getElementById('serverUrl');
|
||||||
|
const modelInput = document.getElementById('model');
|
||||||
|
const maxTokensInput = document.getElementById('maxTokens');
|
||||||
|
const temperatureInput = document.getElementById('temperature');
|
||||||
|
const topPInput = document.getElementById('topP');
|
||||||
|
const requestBodyTextarea = document.getElementById('requestBody');
|
||||||
|
const responseOutput = document.getElementById('responseOutput');
|
||||||
|
|
||||||
|
// Function to update request body from settings
|
||||||
|
function updateRequestBodyFromSettings() {
|
||||||
|
try {
|
||||||
|
const requestBody = JSON.parse(requestBodyTextarea.value);
|
||||||
|
requestBody.model = modelInput.value;
|
||||||
|
requestBody.max_tokens = parseInt(maxTokensInput.value);
|
||||||
|
requestBody.temperature = parseFloat(temperatureInput.value);
|
||||||
|
requestBody.top_p = parseFloat(topPInput.value);
|
||||||
|
requestBodyTextarea.value = JSON.stringify(requestBody, null, 2);
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error updating request body:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update settings when request body changes
|
||||||
|
function updateSettingsFromRequestBody() {
|
||||||
|
try {
|
||||||
|
const requestBody = JSON.parse(requestBodyTextarea.value);
|
||||||
|
if (requestBody.model) modelInput.value = requestBody.model;
|
||||||
|
if (requestBody.max_tokens) maxTokensInput.value = requestBody.max_tokens;
|
||||||
|
if (requestBody.temperature) temperatureInput.value = requestBody.temperature;
|
||||||
|
if (requestBody.top_p) topPInput.value = requestBody.top_p;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error updating settings:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add event listeners for settings changes
|
||||||
|
modelInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||||
|
maxTokensInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||||
|
temperatureInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||||
|
topPInput.addEventListener('change', updateRequestBodyFromSettings);
|
||||||
|
|
||||||
|
// Add event listener for request body changes
|
||||||
|
requestBodyTextarea.addEventListener('blur', updateSettingsFromRequestBody);
|
||||||
|
|
||||||
|
// Send request button
|
||||||
|
document.getElementById('sendRequest').addEventListener('click', async function() {
|
||||||
|
try {
|
||||||
|
responseOutput.textContent = "Sending request...";
|
||||||
|
const serverUrl = serverUrlInput.value;
|
||||||
|
const endpoint = '/v1/chat/completions';
|
||||||
|
const url = serverUrl + endpoint;
|
||||||
|
|
||||||
|
const requestBody = JSON.parse(requestBodyTextarea.value);
|
||||||
|
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify(requestBody)
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
responseOutput.textContent = JSON.stringify(data, null, 2);
|
||||||
|
} catch (error) {
|
||||||
|
responseOutput.textContent = "Error: " + error.message;
|
||||||
|
responseOutput.classList.add('error');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Example requests
|
||||||
|
document.getElementById('example1').addEventListener('click', function() {
|
||||||
|
requestBodyTextarea.value = JSON.stringify({
|
||||||
|
model: modelInput.value,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Who was the 16th president of the United States?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: parseInt(maxTokensInput.value),
|
||||||
|
temperature: parseFloat(temperatureInput.value),
|
||||||
|
top_p: parseFloat(topPInput.value)
|
||||||
|
}, null, 2);
|
||||||
|
});
|
||||||
|
|
||||||
|
document.getElementById('example2').addEventListener('click', function() {
|
||||||
|
requestBodyTextarea.value = JSON.stringify({
|
||||||
|
model: modelInput.value,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "You are a helpful assistant that provides concise answers."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "What is machine learning?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Give me an example of a machine learning algorithm."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: parseInt(maxTokensInput.value),
|
||||||
|
temperature: parseFloat(temperatureInput.value),
|
||||||
|
top_p: parseFloat(topPInput.value)
|
||||||
|
}, null, 2);
|
||||||
|
});
|
||||||
|
|
||||||
|
document.getElementById('example3').addEventListener('click', function() {
|
||||||
|
requestBodyTextarea.value = JSON.stringify({
|
||||||
|
model: modelInput.value,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Write a short poem about artificial intelligence."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: parseInt(maxTokensInput.value),
|
||||||
|
temperature: 0.9, // Higher temperature for creative tasks
|
||||||
|
top_p: 0.9
|
||||||
|
}, null, 2);
|
||||||
|
temperatureInput.value = 0.9;
|
||||||
|
});
|
||||||
|
|
||||||
|
document.getElementById('example4').addEventListener('click', function() {
|
||||||
|
requestBodyTextarea.value = JSON.stringify({
|
||||||
|
model: modelInput.value,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Write a Python function to calculate the Fibonacci sequence up to n terms."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: parseInt(maxTokensInput.value),
|
||||||
|
temperature: 0.3, // Lower temperature for code generation
|
||||||
|
top_p: 0.9
|
||||||
|
}, null, 2);
|
||||||
|
temperatureInput.value = 0.3;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
176
local_inference_engine/openai_api_test.js
Normal file
176
local_inference_engine/openai_api_test.js
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
// Test requests for the OpenAI-compatible endpoint in the inference server
|
||||||
|
// This file contains IIFE (Immediately Invoked Function Expression) JavaScript requests
|
||||||
|
// to test the /v1/chat/completions endpoint
|
||||||
|
|
||||||
|
// Basic chat completion request
|
||||||
|
(async function testBasicChatCompletion() {
|
||||||
|
console.log("Test 1: Basic chat completion request");
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "gemma-2-2b-it",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Who was the 16th president of the United States?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: 100
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Response:", JSON.stringify(data, null, 2));
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error:", error);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
// Multi-turn conversation
|
||||||
|
(async function testMultiTurnConversation() {
|
||||||
|
console.log("\nTest 2: Multi-turn conversation");
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "gemma-2-2b-it",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "You are a helpful assistant that provides concise answers."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "What is machine learning?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Give me an example of a machine learning algorithm."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: 150
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Response:", JSON.stringify(data, null, 2));
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error:", error);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
// Request with temperature and top_p parameters
|
||||||
|
(async function testTemperatureAndTopP() {
|
||||||
|
console.log("\nTest 3: Request with temperature and top_p parameters");
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "gemma-2-2b-it",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Write a short poem about artificial intelligence."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: 200,
|
||||||
|
temperature: 0.8,
|
||||||
|
top_p: 0.9
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Response:", JSON.stringify(data, null, 2));
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error:", error);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
// Request with streaming enabled
|
||||||
|
(async function testStreaming() {
|
||||||
|
console.log("\nTest 4: Request with streaming enabled");
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "gemma-2-2b-it",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "Explain quantum computing in simple terms."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: 150,
|
||||||
|
stream: true
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
// Note: Streaming might not be implemented yet, this is to test the API's handling of the parameter
|
||||||
|
if (response.headers.get('content-type')?.includes('text/event-stream')) {
|
||||||
|
console.log("Streaming response detected. Reading stream...");
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
const chunk = decoder.decode(value);
|
||||||
|
console.log("Chunk:", chunk);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Non-streaming response:", JSON.stringify(data, null, 2));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error:", error);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
// Request with a different model
|
||||||
|
(async function testDifferentModel() {
|
||||||
|
console.log("\nTest 5: Request with a different model");
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "gemma-2-2b-it", // Using a different model if available
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: "What are the benefits of renewable energy?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_tokens: 150
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Response:", JSON.stringify(data, null, 2));
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error:", error);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
console.log("\nAll test requests have been sent. Check the server logs for more details.");
|
||||||
|
console.log("To run the server, use: cargo run --bin local_inference_engine -- --server");
|
@@ -243,11 +243,11 @@ async fn chat_completions(
|
|||||||
|
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"error": {
|
"error": {
|
||||||
"message": format!("Error generating text: {}", e),
|
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
|
||||||
"type": "internal_server_error"
|
"type": "unsupported_api"
|
||||||
}
|
}
|
||||||
})),
|
})),
|
||||||
));
|
));
|
||||||
@@ -520,6 +520,70 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Determine if we're using a Model3 (gemma-3) variant
|
||||||
|
let is_model3 = match &self.model {
|
||||||
|
Model::V3(_) => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// For Model3, we need to use a different approach
|
||||||
|
if is_model3 {
|
||||||
|
// For gemma-3 models, we'll generate one token at a time with the full context
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
|
||||||
|
// Initial generation with the full prompt
|
||||||
|
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||||
|
let mut logits = self.model.forward(&input, 0)?;
|
||||||
|
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
|
for _ in 0..sample_len {
|
||||||
|
// Apply repeat penalty if needed
|
||||||
|
let current_logits = if self.repeat_penalty == 1. {
|
||||||
|
logits.clone()
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
|
||||||
|
// Manual implementation of repeat penalty to avoid type conflicts
|
||||||
|
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
for &token_id in &tokens[start_at..] {
|
||||||
|
let token_id = token_id as usize;
|
||||||
|
if token_id < logits_vec.len() {
|
||||||
|
let score = logits_vec[token_id];
|
||||||
|
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||||
|
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new tensor with the modified logits
|
||||||
|
let device = logits.device().clone();
|
||||||
|
let shape = logits.shape().clone();
|
||||||
|
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||||
|
new_logits.reshape(shape)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(¤t_logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
|
||||||
|
if next_token == eos_token || next_token == eot_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
write!(output, "{}", t)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the next iteration, just use the new token
|
||||||
|
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
||||||
|
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
||||||
|
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard approach for other models
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
@@ -1,34 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
body::Body,
|
|
||||||
http::{StatusCode, header::CONTENT_TYPE},
|
|
||||||
response::{IntoResponse, Response},
|
|
||||||
};
|
|
||||||
use rust_embed::RustEmbed;
|
|
||||||
use tracing::{debug, error};
|
|
||||||
|
|
||||||
#[derive(RustEmbed)]
|
|
||||||
#[folder = "assets/"]
|
|
||||||
struct Asset;
|
|
||||||
|
|
||||||
pub async fn serve_ui() -> impl IntoResponse {
|
|
||||||
debug!("Serving UI request");
|
|
||||||
|
|
||||||
// Attempt to retrieve the embedded "index.html"
|
|
||||||
match Asset::get("index.html") {
|
|
||||||
Some(content) => {
|
|
||||||
debug!("Successfully retrieved index.html");
|
|
||||||
Response::builder()
|
|
||||||
.status(StatusCode::OK)
|
|
||||||
.header(CONTENT_TYPE, "text/html")
|
|
||||||
.body(Body::from(content.data))
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
error!("index.html not found in embedded assets");
|
|
||||||
Response::builder()
|
|
||||||
.status(StatusCode::NOT_FOUND)
|
|
||||||
.body(Body::from("404 Not Found"))
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,5 +1,5 @@
|
|||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use crate::handlers::{not_found::handle_not_found, ui::serve_ui};
|
use crate::handlers::{not_found::handle_not_found};
|
||||||
use axum::routing::{get, Router};
|
use axum::routing::{get, Router};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use tower_http::trace::{self, TraceLayer};
|
use tower_http::trace::{self, TraceLayer};
|
||||||
|
Reference in New Issue
Block a user