Refactor apply_cached_repeat_penalty for optimized caching and reuse, add extensive unit tests, and integrate special handling for gemma-specific models.

Removed `test_request.sh`, deprecated functionality, and unused imports; introduced a new CLI tool (`cli.ts`) for testing inference engine and adjusted handling of non-streaming/streaming chat completions.

- Add CPU fallback support for text generation when primary device is unsupported
- Introduce `execute_with_fallback` method to handle device compatibility and shape mismatch errors
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations
- Increase HTTP timeout limits in `curl_chat_stream.sh` script for reliable API testing

chat completion endpoint functions with gemma3 (no streaming)

Add benchmarking guide with HTML reporting, Leptos chat crate, and middleware for metrics tracking
This commit is contained in:
geoffsee
2025-08-26 01:30:26 -04:00
parent 7dd23213c9
commit 8338750beb
64 changed files with 14997 additions and 220 deletions

6115
crates/legacy-inference-engine/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,77 @@
[package]
name = "legacy-inference-engine"
version = "0.1.0"
edition = "2021"
[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true }
candle-nn = { version = "=0.9.1" }
candle-transformers = { version = "=0.9.1" }
candle-flash-attn = { version = "=0.9.1", optional = true }
candle-onnx = { version = "=0.9.1", optional = true }
csv = "1.3.0"
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
hf-hub = { version = "0.4.1", features = ["tokio"] }
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
num-traits = { version = "0.2.15" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
rayon = "1.7.0"
rubato = { version = "0.15.0", optional = true }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "http"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
anyhow = "1.0.98"
clap= { version = "4.2.4", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = { version = "0.7.4", features = ["json"] }
tower = "0.4.13"
tower-http = { version = "0.5.1", features = ["cors"] }
tokio = { version = "1.43.0", features = ["full"] }
either = { version = "1.9.0", features = ["serde"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
# --- Add this section for conditional compilation ---
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "=0.9.1", features = ["metal"] }
metal = { version = "0.32.0", features = ["mps"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA
# If you're building on Linux with a CUDA-enabled GPU:
candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features
# If you're building on Linux with only CPU:
# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit
# --- End of conditional compilation section ---
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
byteorder = { version = "1.4.3" }
clap = { version = "4.2.4", features = ["derive"] }
imageproc = { version = "0.24.0", default-features = false }
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
rand = { version = "0.9.0" }
ab_glyph = { version = "0.2.23" }
tracing = { version = "0.1.37" }
tracing-chrome = { version = "0.7.1" }
tracing-subscriber = { version = "0.3.7" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = { version = "0.1.1", optional = true }

View File

@@ -0,0 +1,210 @@
# @open-web-agent-rs/legacy-inference-engine
## Note
This is here as a reference implementation. This is harder than it looks.
A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
## Features
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
- CLI mode for direct text generation
- Server mode with OpenAI-compatible API
- Support for various model configurations (base, instruction-tuned)
- Metal acceleration on macOS
## Installation
### Prerequisites
- Rust toolchain (install via [rustup](https://rustup.rs/))
- Cargo package manager
- For GPU acceleration:
- macOS: Metal support
- Linux/Windows: CUDA support (requires appropriate drivers)
### Building from Source
1. Clone the repository:
```bash
git clone https://github.com/seemueller-io/open-web-agent-rs.git
cd open-web-agent-rs
```
2. Build the local inference engine:
```bash
cargo build -p legacy-inference-engine --release
```
## Usage
### CLI Mode
Run the inference engine in CLI mode to generate text directly:
```bash
cargo run -p legacy-inference-engine --release -- --prompt 'Name the 16th President of the USA.' --which 3-1b-it
```
#### CLI Options
- `--prompt <TEXT>`: The prompt text to generate from
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
- `--server`: Run OpenAI compatible server
- Available values for `--which`: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
- `--repeat-last-n <INT>`: Context size for repeat penalty (default: 64)
- `--cpu`: Run on CPU instead of GPU
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)
### Server Mode with OpenAI-compatible API
Run the inference engine in server mode to expose an OpenAI-compatible API:
```bash
cargo run -p legacy-inference-engine --release -- --server --port 3777 --which 3-1b-it
```
This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint.
#### Server Options
- `--server`: Run in server mode
- `--port <INT>`: Port to use for the server (default: 3777)
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
- Other model options as described in CLI mode
## API Usage
The server exposes an OpenAI-compatible chat completions endpoint:
### Chat Completions
```
POST /v1/chat/completions
```
#### Request Format
```json
{
"model": "gemma-3-1b-it",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
],
"temperature": 0.7,
"max_tokens": 256,
"top_p": 0.9,
"stream": false
}
```
#### Response Format
```json
{
"id": "chatcmpl-123abc456def789ghi",
"object": "chat.completion",
"created": 1677858242,
"model": "gemma-3-1b-it",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'm doing well, thank you for asking! How can I assist you today?"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 25,
"completion_tokens": 15,
"total_tokens": 40
}
}
```
### Example: Using cURL
```bash
curl -X POST http://localhost:3777/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gemma-3-1b-it",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
],
"temperature": 0.7,
"max_tokens": 100
}'
```
### Example: Using Python with OpenAI Client
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:3777/v1",
api_key="dummy" # API key is not validated but required by the client
)
response = client.chat.completions.create(
model="gemma-3-1b-it",
messages=[
{"role": "user", "content": "What is the capital of France?"}
],
temperature=0.7,
max_tokens=100
)
print(response.choices[0].message.content)
```
### Example: Using JavaScript/TypeScript with OpenAI SDK
```javascript
import OpenAI from 'openai';
const openai = new OpenAI({
baseURL: 'http://localhost:3777/v1',
apiKey: 'dummy', // API key is not validated but required by the client
});
async function main() {
const response = await openai.chat.completions.create({
model: 'gemma-3-1b-it',
messages: [
{ role: 'user', content: 'What is the capital of France?' }
],
temperature: 0.7,
max_tokens: 100,
});
console.log(response.choices[0].message.content);
}
main();
```
## Troubleshooting
### Common Issues
1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.
2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.
3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.
4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.
## License
This project is licensed under the terms specified in the LICENSE file.

View File

@@ -0,0 +1,127 @@
# Root Cause Analysis: Metal error "no metal implementation for rotary-emb"
Date: 2025-08-27
Component: crates/legacy-inference-engine
Command to reproduce: crates/legacy-inference-engine/test_cli.sh
## Summary
Running the CLI with the default model (--which 3-1b-it, i.e., Gemma 3 1B Instruct) on an Apple Silicon Mac results in a runtime failure:
```
modelError: Metal error no metal implementation for rotary-emb
Caused by:
no metal implementation for rotary-emb
```
This occurs because the project targets the Candle Metal (MPS) backend on macOS, but the Candle version in use (0.9.1) does not provide a Metal kernel implementation for the rotary embedding operation required by Gemma 3 models. The program selects the Metal device by default on macOS and hits this missing kernel during the attention computation.
## Environment and build configuration
- Machine: 2024 MacBook Pro, Apple Silicon (M4 Max)
- Crate: legacy-inference-engine
- Candle versions: pinned to =0.9.1
- candle-core = "=0.9.1"
- candle-transformers = "=0.9.1"
- macOS-specific dependency enabling Metal (file: crates/legacy-inference-engine/Cargo.toml):
```text
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "=0.9.1", features = ["metal"] }
metal = { version = "0.32.0", features = ["mps"] }
```
- Run command (attached script): crates/legacy-inference-engine/test_cli.sh
```text
cargo run -p legacy-inference-engine --release -- --prompt 'Name the 16th President of the USA.' --which 3-1b-it
```
## What the code does at runtime
1) Device selection (defaults to Metal on macOS if available):
- File: crates/legacy-inference-engine/src/utilities_lib.rs (lines 4–12)
```text
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
// ... falls back to CPU
Ok(Device::Cpu)
}
}
```
- The CLI does not pass --cpu, so on Apple Silicon with Metal available, Device::new_metal(0) is selected.
2) Default model selection is Gemma 3 1B Instruct:
- File: crates/legacy-inference-engine/src/main.rs
- Arg default (lines 705–707):
```text
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
which: Which,
```
- Model id resolution (lines 758–760):
```text
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
```
- Model loading uses Model3 (Gemma 3) for Which::BaseV3_1B | Which::InstructV3_1B (lines 817–821).
3) During generation, the Gemma 3 attention path requires rotary embeddings. On the Metal backend in Candle 0.9.1, the rotary embedding op is not implemented, resulting in the runtime error.
## Additional build-time signal (misleading but not causal)
- File: crates/legacy-inference-engine/src/main.rs (lines 10–11)
```text
#[cfg(feature = "metal")]
extern crate metal_src;
```
- Build warning: unexpected cfg condition value: metal
Explanation: The project does not define a Cargo feature named "metal"; instead, Metal is enabled via target-specific dependency features in Cargo.toml. This cfg gate is ineffective and triggers a warning. It does not cause the runtime failure; it just indicates confusing/obsolete gating.
## Root cause
- The program runs on the Candle Metal backend (MPS) due to device auto-selection on macOS.
- The selected model (Gemma 3 1B Instruct) requires the rotary embedding operation in its attention mechanism.
- Candle 0.9.1's Metal backend lacks an implementation for the rotary-emb kernel. When the model executes on Metal, it attempts to invoke this operation and fails with: "no metal implementation for rotary-emb".
## Evidence
- Runtime log shows the failure immediately after model load when inference begins.
- Code paths confirm: device defaults to Metal on macOS; default model is Gemma 3; Gemma 3 uses rotary embeddings.
- Candle version pinned to 0.9.1 where rotary-emb on Metal is not available.
## Impact
- Any attempt to run Gemma 3 (and possibly other rotary-embedding reliant models) on the Metal backend with Candle 0.9.1 will fail at runtime on macOS.
## Workarounds and remediation options
1) Immediate workarounds:
- Run on CPU: add the --cpu flag to force CPU backend.
- Example: cargo run -p legacy-inference-engine --release -- --cpu --prompt '...' --which 3-1b-it
- Use a model variant that does not hit the unimplemented kernel on Metal (e.g., older Gemma v1/v2), though many modern LLMs rely on rotary embeddings, so this may not help.
2) Recommended remediation (code/dependency changes):
- Upgrade Candle crates (candle-core, candle-transformers, etc.) to a version where the Metal backend implements rotary embeddings. Review Candle's changelog/PRs for Metal/MPS kernel support and update to the first version that includes rotary-emb on Metal.
- Alternatively, implement a CPU fallback path for rotary-emb when running on Metal (hybrid execution). This is non-trivial and may degrade performance.
- Provide a configuration/flag to disable Metal by default on macOS for models known to require missing ops until Candle is upgraded.
- Clean up the misleading #[cfg(feature = "metal")] gate in main.rs to avoid confusion; Metal enablement is already handled in Cargo.toml via target-specific features.
## Suggested next steps
- Short term: document and expose --cpu usage in README and/or make the default model a Metal-compatible one until dependency upgrade.
- Medium term: bump Candle dependencies and test Gemma 3 on Metal; remove the obsolete cfg(feature = "metal") gate.
- Long term: integrate a device capability check and automatic fallback (informative log) when encountering unsupported kernels on the selected backend.
## References (code locations)
- crates/legacy-inference-engine/src/utilities_lib.rs lines 4–12: device selection (Metal default on macOS if available).
- crates/legacy-inference-engine/src/main.rs lines 705–707: default which = 3-1b-it.
- crates/legacy-inference-engine/src/main.rs lines 758–760 and 817–821: Gemma 3 model selection and instantiation.
- crates/legacy-inference-engine/Cargo.toml macOS target section: Candle with features = ["metal"].
- crates/legacy-inference-engine/src/main.rs lines 10–11: obsolete #[cfg(feature = "metal")] gate that triggers a warning.

View File

@@ -0,0 +1,295 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OpenAI-Compatible API Tester</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
line-height: 1.6;
}
h1, h2 {
color: #333;
}
.container {
margin-bottom: 20px;
}
textarea {
width: 100%;
height: 150px;
padding: 10px;
margin-bottom: 10px;
border: 1px solid #ddd;
border-radius: 4px;
font-family: monospace;
}
button {
background-color: #4CAF50;
color: white;
padding: 10px 15px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
}
button:hover {
background-color: #45a049;
}
pre {
background-color: #f5f5f5;
padding: 15px;
border-radius: 4px;
overflow-x: auto;
white-space: pre-wrap;
}
.response {
margin-top: 20px;
}
.error {
color: red;
}
.settings {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 15px;
}
.settings div {
display: flex;
flex-direction: column;
}
label {
margin-bottom: 5px;
font-weight: bold;
}
input {
padding: 8px;
border: 1px solid #ddd;
border-radius: 4px;
}
.examples {
margin-top: 30px;
}
.example-btn {
background-color: #2196F3;
margin-right: 10px;
margin-bottom: 10px;
}
.example-btn:hover {
background-color: #0b7dda;
}
</style>
</head>
<body>
<h1>OpenAI-Compatible API Tester</h1>
<p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>
<div class="container">
<h2>Request Settings</h2>
<div class="settings">
<div>
<label for="serverUrl">Server URL:</label>
<input type="text" id="serverUrl" value="http://localhost:3777" />
</div>
<div>
<label for="model">Model:</label>
<input type="text" id="model" value="gemma-3-1b-it" />
</div>
<div>
<label for="maxTokens">Max Tokens:</label>
<input type="number" id="maxTokens" value="150" />
</div>
<div>
<label for="temperature">Temperature:</label>
<input type="number" id="temperature" value="0.7" step="0.1" min="0" max="2" />
</div>
<div>
<label for="topP">Top P:</label>
<input type="number" id="topP" value="0.9" step="0.1" min="0" max="1" />
</div>
</div>
<h2>Request Body</h2>
<textarea id="requestBody">{
"model": "gemma-3-1b-it",
"messages": [
{
"role": "user",
"content": "Hello, how are you today?"
}
],
"max_tokens": 150,
"temperature": 0.7,
"top_p": 0.9
}</textarea>
<button id="sendRequest">Send Request</button>
<div class="examples">
<h3>Example Requests</h3>
<button class="example-btn" id="example1">Basic Question</button>
<button class="example-btn" id="example2">Multi-turn Conversation</button>
<button class="example-btn" id="example3">Creative Writing</button>
<button class="example-btn" id="example4">Code Generation</button>
</div>
<div class="response">
<h2>Response</h2>
<pre id="responseOutput">Response will appear here...</pre>
</div>
</div>
<script>
document.addEventListener('DOMContentLoaded', function() {
// Look up the settings inputs, the request body editor, and the response area once,
// so the handlers below can share these references.
const serverUrlInput = document.getElementById('serverUrl');
const modelInput = document.getElementById('model');
const maxTokensInput = document.getElementById('maxTokens');
const temperatureInput = document.getElementById('temperature');
const topPInput = document.getElementById('topP');
const requestBodyTextarea = document.getElementById('requestBody');
const responseOutput = document.getElementById('responseOutput');
// Copy the current values of the settings inputs into the JSON request body.
// If the textarea does not hold valid JSON, the body is left alone and the
// problem is reported on the console.
function updateRequestBodyFromSettings() {
  try {
    const parsedBody = JSON.parse(requestBodyTextarea.value);
    Object.assign(parsedBody, {
      model: modelInput.value,
      max_tokens: parseInt(maxTokensInput.value),
      temperature: parseFloat(temperatureInput.value),
      top_p: parseFloat(topPInput.value)
    });
    const prettyBody = JSON.stringify(parsedBody, null, 2);
    requestBodyTextarea.value = prettyBody;
  } catch (error) {
    console.error("Error updating request body:", error);
  }
}
// Update settings when request body changes.
// Copies model, max_tokens, temperature and top_p from the JSON request body into the
// matching settings inputs. A field is copied whenever it is present (not undefined or
// null), so legitimate falsy values such as `temperature: 0` are reflected as well.
// If the textarea does not hold valid JSON, the inputs are left alone and the problem is
// reported on the console.
function updateSettingsFromRequestBody() {
  // Assign only when the body actually carries a value for the field.
  function syncInput(input, value) {
    if (value !== undefined && value !== null) {
      input.value = value;
    }
  }
  try {
    const requestBody = JSON.parse(requestBodyTextarea.value);
    syncInput(modelInput, requestBody.model);
    syncInput(maxTokensInput, requestBody.max_tokens);
    syncInput(temperatureInput, requestBody.temperature);
    syncInput(topPInput, requestBody.top_p);
  } catch (error) {
    console.error("Error updating settings:", error);
  }
}
// Add event listeners for settings changes
modelInput.addEventListener('change', updateRequestBodyFromSettings);
maxTokensInput.addEventListener('change', updateRequestBodyFromSettings);
temperatureInput.addEventListener('change', updateRequestBodyFromSettings);
topPInput.addEventListener('change', updateRequestBodyFromSettings);
// Add event listener for request body changes
requestBodyTextarea.addEventListener('blur', updateSettingsFromRequestBody);
// Send request button.
// POSTs the JSON in the request body textarea to <server URL>/v1/chat/completions and
// shows the response. The response area gets the 'error' class when the request fails
// or the server answers with a non-2xx status, and loses it again on the next attempt.
document.getElementById('sendRequest').addEventListener('click', async function() {
  // Clear error styling left over from a previous failed request.
  responseOutput.classList.remove('error');
  try {
    responseOutput.textContent = "Sending request...";
    // Drop trailing slashes so "http://host/" does not produce "//v1/...".
    const serverUrl = serverUrlInput.value.replace(/\/+$/, '');
    const endpoint = '/v1/chat/completions';
    const url = serverUrl + endpoint;
    const requestBody = JSON.parse(requestBodyTextarea.value);
    const response = await fetch(url, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(requestBody)
    });
    // Read the body as text first so a non-JSON error page is still displayed.
    const rawText = await response.text();
    let display;
    try {
      display = JSON.stringify(JSON.parse(rawText), null, 2);
    } catch (parseError) {
      display = rawText;
    }
    if (!response.ok) {
      responseOutput.classList.add('error');
    }
    responseOutput.textContent = display;
  } catch (error) {
    responseOutput.textContent = "Error: " + error.message;
    responseOutput.classList.add('error');
  }
});
// Example requests
document.getElementById('example1').addEventListener('click', function() {
  // "Basic Question": one user message, sampling values taken from the settings inputs.
  const basicQuestion = {
    model: modelInput.value,
    messages: [
      { role: "user", content: "Who was the 16th president of the United States?" }
    ],
    max_tokens: parseInt(maxTokensInput.value),
    temperature: parseFloat(temperatureInput.value),
    top_p: parseFloat(topPInput.value)
  };
  requestBodyTextarea.value = JSON.stringify(basicQuestion, null, 2);
});
document.getElementById('example2').addEventListener('click', function() {
  // "Multi-turn Conversation": system prompt plus a user/assistant/user exchange.
  const conversation = [
    { role: "system", content: "You are a helpful assistant that provides concise answers." },
    { role: "user", content: "What is machine learning?" },
    { role: "assistant", content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed." },
    { role: "user", content: "Give me an example of a machine learning algorithm." }
  ];
  const multiTurnRequest = {
    model: modelInput.value,
    messages: conversation,
    max_tokens: parseInt(maxTokensInput.value),
    temperature: parseFloat(temperatureInput.value),
    top_p: parseFloat(topPInput.value)
  };
  requestBodyTextarea.value = JSON.stringify(multiTurnRequest, null, 2);
});
document.getElementById('example3').addEventListener('click', function() {
  // "Creative Writing": uses fixed sampling values instead of the settings inputs.
  const creativeTemperature = 0.9; // Higher temperature for creative tasks
  const creativeTopP = 0.9;
  requestBodyTextarea.value = JSON.stringify({
    model: modelInput.value,
    messages: [
      {
        role: "user",
        content: "Write a short poem about artificial intelligence."
      }
    ],
    max_tokens: parseInt(maxTokensInput.value),
    temperature: creativeTemperature,
    top_p: creativeTopP
  }, null, 2);
  // Keep the settings panel in agreement with the values written into the body.
  temperatureInput.value = creativeTemperature;
  topPInput.value = creativeTopP;
});
document.getElementById('example4').addEventListener('click', function() {
  // "Code Generation": uses fixed sampling values instead of the settings inputs.
  const codeTemperature = 0.3; // Lower temperature for code generation
  const codeTopP = 0.9;
  requestBodyTextarea.value = JSON.stringify({
    model: modelInput.value,
    messages: [
      {
        role: "user",
        content: "Write a Python function to calculate the Fibonacci sequence up to n terms."
      }
    ],
    max_tokens: parseInt(maxTokensInput.value),
    temperature: codeTemperature,
    top_p: codeTopP
  }, null, 2);
  // Keep the settings panel in agreement with the values written into the body.
  temperatureInput.value = codeTemperature;
  topPInput.value = codeTopP;
});
});
</script>
</body>
</html>

View File

@@ -0,0 +1,72 @@
use clap::Parser;
use crate::model::Which;
// Command-line arguments for the inference engine, parsed with clap's derive API.
// The existing `///` comments double as clap help text, so the notes added in this block
// use plain `//` comments and leave the generated `--help` output unchanged.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
pub cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
pub tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
pub server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
pub port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
pub prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
pub temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
pub top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
pub seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
pub sample_len: usize,
// NOTE(review): presumably overrides the model repository implied by `which` — confirm
// against the model-loading code, which is not visible here.
#[arg(long)]
pub model_id: Option<String>,
// NOTE(review): presumably the repository revision to fetch; defaults to "main".
#[arg(long, default_value = "main")]
pub revision: String,
// NOTE(review): presumably a local tokenizer file used instead of a downloaded one — confirm.
#[arg(long)]
pub tokenizer_file: Option<String>,
// NOTE(review): presumably a local model config file used instead of a downloaded one — confirm.
#[arg(long)]
pub config_file: Option<String>,
// NOTE(review): presumably local weight file path(s); the expected format (single path
// vs. separated list) is not visible here — confirm against the loader.
#[arg(long)]
pub weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
pub repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
pub which: Which,
// NOTE(review): presumably enables a flash-attention code path in the model — confirm.
#[arg(long)]
pub use_flash_attn: bool,
}

View File

@@ -0,0 +1,13 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod cli;
pub mod server;
// Re-export key components for easier access
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;

View File

@@ -0,0 +1,912 @@
mod token_output_stream;
mod utilities_lib;
#[cfg(feature = "intel-mkl-src")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate-src")]
extern crate accelerate_src;
#[cfg(feature = "metal")]
extern crate metal_src;
use anyhow::{Error as E, Result};
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use either::Either;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use utoipa::ToSchema;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// OpenAI API compatible structs
/// Inner content structure for messages that can be either a string or key-value pairs
///
/// `Either::Left` holds a plain string; `Either::Right` holds a map from string keys to
/// string values. (De)serialization goes through `either::serde_untagged`, so no variant
/// tag appears in the JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
    /// Returns the schema name `"MessageInnerContent"` paired with the hand-written schema
    /// from `message_inner_content_schema`.
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        let schema = message_inner_content_schema();
        ("MessageInnerContent", utoipa::openapi::RefOr::T(schema))
    }
}
/// Function for MessageInnerContent Schema generation to handle `Either`
///
/// Builds a `oneOf` schema with two alternatives: a plain string (`Either::Left`) and an
/// object whose additional properties are strings (`Either::Right`).
fn message_inner_content_schema() -> utoipa::openapi::Schema {
    // `ArrayBuilder` is not needed here; importing it only produced an unused-import warning.
    use utoipa::openapi::{ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    Schema::OneOf(
        OneOfBuilder::new()
            // Either::Left - simple string
            .item(Schema::Object(
                ObjectBuilder::new().schema_type(SchemaType::String).build(),
            ))
            // Either::Right - object with string values
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Object)
                    .additional_properties(Some(RefOr::T(Schema::Object(
                        ObjectBuilder::new().schema_type(SchemaType::String).build(),
                    ))))
                    .build(),
            ))
            .build(),
    )
}
/// Message content that can be either simple text or complex structured content
///
/// `Either::Left` holds plain text; `Either::Right` holds a list of maps from string keys
/// to [`MessageInnerContent`] values. (De)serialization goes through
/// `either::serde_untagged`, so no variant tag appears in the JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
    /// Returns the schema name `"MessageContent"` paired with the hand-written schema from
    /// `message_content_schema`.
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        let schema = message_content_schema();
        ("MessageContent", utoipa::openapi::RefOr::T(schema))
    }
}
/// Function for MessageContent Schema generation to handle `Either`
///
/// Produces a `oneOf` schema: either a plain string, or an array of objects whose
/// additional properties reference the `MessageInnerContent` schema by name.
fn message_content_schema() -> utoipa::openapi::Schema {
    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    // Either::Left: plain text content.
    let text_variant = Schema::Object(ObjectBuilder::new().schema_type(SchemaType::String).build());

    // Either::Right: each array element is an object keyed by arbitrary strings.
    let part_object = ObjectBuilder::new()
        .schema_type(SchemaType::Object)
        .additional_properties(Some(RefOr::Ref(
            utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
        )))
        .build();
    let parts_variant = Schema::Array(
        ArrayBuilder::new()
            .items(RefOr::T(Schema::Object(part_object)))
            .build(),
    );

    let one_of = OneOfBuilder::new()
        .item(text_variant)
        .item(parts_variant)
        .build();
    Schema::OneOf(one_of)
}
/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
/// Optional name attached to the message.
pub name: Option<String>,
}
/// Stop token configuration for generation
///
/// Untagged: the JSON is either an array of strings or a single string.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Serde default helper: yields `false` for boolean fields missing from the input.
fn default_false() -> bool {
    const DEFAULT: bool = false;
    DEFAULT
}
/// Serde default helper: yields `1` for count fields missing from the input.
fn default_1usize() -> usize {
    1_usize
}
/// Serde default helper: yields the model name `"default"` when none is supplied.
fn default_model() -> String {
    String::from("default")
}
/// Chat completion request following OpenAI's specification
// Of these fields, the `chat_completions` handler in this file reads only `messages`,
// `max_tokens` and `model`; the remaining fields are accepted but not consulted there.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
/// The conversation messages, in order.
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
/// Name of the model; defaults to "default" when omitted.
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
/// Log-probabilities flag; defaults to false when omitted.
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
/// Maximum number of tokens to generate.
#[schema(example = 256)]
pub max_tokens: Option<usize>,
/// Number of choices requested (JSON field `n`); defaults to 1 when omitted.
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
/// Sampling temperature.
#[schema(example = 0.7)]
pub temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[schema(example = 0.9)]
pub top_p: Option<f64>,
/// Whether a streamed response is requested.
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
/// Position of this choice within the response's `choices` list.
pub index: usize,
/// The generated message.
pub message: Message,
/// Why generation ended (the handler in this file always reports "stop").
pub finish_reason: String,
}
/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
/// Response identifier (the handler in this file uses a "chatcmpl-" prefix plus a UUID).
pub id: String,
/// Object type tag (the handler in this file uses "chat.completion").
pub object: String,
/// Creation time in seconds since the Unix epoch.
pub created: u64,
/// Model name, echoed from the request.
pub model: String,
/// The generated choices.
pub choices: Vec<ChatCompletionChoice>,
/// Token usage for the request.
pub usage: Usage,
}
/// Token usage information
// The handler in this file fills these with rough estimates (text length divided by 4),
// not real tokenizer counts.
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
/// Tokens attributed to the prompt.
pub prompt_tokens: usize,
/// Tokens attributed to the generated completion.
pub completion_tokens: usize,
/// Total tokens for the request.
pub total_tokens: usize,
}
// Application state shared between handlers
// `Clone` is cheap for the generator: it sits behind an `Arc`, so clones share one
// `TextGeneration` guarded by an async (tokio) mutex.
#[derive(Clone)]
struct AppState {
// Shared text generator; handlers lock it for the duration of a generation run.
text_generation: Arc<Mutex<TextGeneration>>,
// Model identifier; not read by the `chat_completions` handler in this file.
model_id: String,
}
// Chat completions endpoint handler
async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
use candle_core::{DType, Device, MetalDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde_json::json;
use tokenizers::Tokenizer;
use crate::token_output_stream::TokenOutputStream;
use crate::utilities_lib::device;
// Create the router with the chat completions endpoint
//
// Builds the application's `Router`: one POST route for the OpenAI-compatible
// chat completions endpoint, wrapped in a permissive CORS layer and bound to
// the given shared state.
fn create_router(app_state: AppState) -> Router {
    // CORS layer to allow requests from any origin, with any method and headers
    let cors = CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any);
    // OpenAI compatible endpoints; add more routes here as needed
    let routes = Router::new().route("/v1/chat/completions", post(chat_completions));
    routes.layer(cors).with_state(app_state)
}
// Selects which Gemma checkpoint to load. Exposed on the command line through
// `Args::which`; each `#[value(name = ...)]` is the spelling accepted there.
// `main` maps every variant to a `google/...` model id and to one of the three
// model implementations (`Model::V1`, `Model::V2`, `Model::V3`).
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
// --- Variants loaded through `Model::V1` ---
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
// --- Variants loaded through `Model::V2` ---
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
// --- Variants loaded through `Model::V3` ---
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
// Wrapper over the three model implementations this binary can load, so the
// generation code can hold any of them behind a single type. `main` picks the
// variant from the selected `Which` value.
enum Model {
// Built from `Model1` and a `Config1`.
V1(Model1),
// Built from `Model2` and a `Config2`.
V2(Model2),
// Built from `Model3` and a `Config3`.
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
// Text generation pipeline: a loaded model plus everything needed to turn a
// prompt string into sampled text.
struct TextGeneration {
// The wrapped model used for every forward pass.
model: Model,
// Device on which the input tensors are created.
device: Device,
// Tokenizer wrapped in a `TokenOutputStream`, used both to encode the prompt
// and to decode tokens incrementally as they are produced.
tokenizer: TokenOutputStream,
// Sampler built from the seed, temperature and top-p passed to `new`.
logits_processor: LogitsProcessor,
// Repetition penalty factor; exactly 1.0 disables the penalty.
repeat_penalty: f32,
// Number of most recent tokens the repetition penalty looks at.
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
// Run text generation and print to stdout
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
// Run text generation and write to a buffer
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}
// Command-line arguments, parsed by clap in `main`. The `///` comments below
// are part of the derive input, so they are left as they are; notes added for
// review use `//`.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
port: u16,
/// Prompt for text generation (not used in server mode)
// `main` fails with an error when this is missing and `--server` is not set.
#[arg(long)]
prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
// Explicit model id; when absent, `main` derives the id from `which`.
#[arg(long)]
model_id: Option<String>,
// Repository revision passed to the hub client.
#[arg(long, default_value = "main")]
revision: String,
// Local tokenizer file; when absent, "tokenizer.json" is fetched from the repo.
#[arg(long)]
tokenizer_file: Option<String>,
// Local config file; when absent, "config.json" is fetched from the repo.
#[arg(long)]
config_file: Option<String>,
// Comma-separated list of local weight files; when absent, the weights are
// fetched from the repo.
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
which: Which,
// Forwarded unchanged to the model constructors.
#[arg(long)]
use_flash_attn: bool,
}
// Program entry point.
//
// Steps, in order:
//   1. parse the command line and optionally enable Chrome tracing;
//   2. resolve the model id and obtain tokenizer, config and weight files
//      (local paths when given, otherwise fetched through the hub API);
//   3. choose device and dtype, then load the weights and build the model;
//   4. either serve the OpenAI-compatible HTTP API (`--server`) or generate
//      once for `--prompt` and print the result.
//
// Returns an error when any file retrieval, load or generation step fails, or
// when neither `--server` nor `--prompt` is given.
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
// The guard is kept in a binding for the whole of `main`; it is only created
// when tracing was requested.
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
// An explicit --model-id wins; otherwise the id follows from --which.
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id.clone(),
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
// The 3-1b variants are fetched as a single "model.safetensors" file; every
// other variant is resolved through "model.safetensors.index.json".
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let initial_device = utilities_lib::device(args.cpu)?;
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;
// Use CPU for V3 models on Metal due to missing implementations
let device = if is_v3_model && is_metal {
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
Device::Cpu
} else {
initial_device
};
// BF16 is used only on CUDA; every other device gets F32.
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
// Use the selected device and dtype
// SAFETY (to confirm): the constructor memory-maps the weight files —
// presumably the requirement is that they are not modified while mapped.
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
// Each model family parses the same config file into its own config type.
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("loaded the model in {:?}", start.elapsed());
let pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
if args.server {
// Start the server
println!("Starting server on port {}", args.port);
// Create app state
let app_state = AppState {
text_generation: Arc::new(Mutex::new(pipeline)),
model_id,
};
// Create router
let app = create_router(app_state);
// Run the server
// Binds to all IPv4 interfaces (0.0.0.0) on the requested port.
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
// Use tokio to run the server
// `main` itself is synchronous, so a runtime is built here and blocked on
// until the server stops.
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
})?;
Ok(())
} else {
// Run in CLI mode
if let Some(prompt_text) = &args.prompt {
// Only the 3-1b-it variant gets its prompt wrapped in turn markers; every
// other variant receives the prompt text unchanged.
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => prompt_text.clone(),
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
prompt_text
)
}
};
let mut pipeline = pipeline;
pipeline.run(&prompt, args.sample_len)?;
Ok(())
} else {
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
}
}
}

View File

@@ -0,0 +1,90 @@
use candle_core::Tensor;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// Selects a model checkpoint. Derives `clap::ValueEnum`; each
// `#[value(name = ...)]` gives the string spelling of the variant. The
// variant-to-model-id mapping lives in `Which::to_model_id`.
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
// --- "google/gemma-*" and "google/codegemma-*" ids ---
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
// --- "google/gemma-2-*" ids ---
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
// --- "google/gemma-3-*" ids; `is_v3_model` returns true for these two ---
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
// Wrapper over the three candle Gemma model implementations imported above,
// so callers can hold any of them behind a single type.
pub enum Model {
// `candle_transformers::models::gemma::Model`
V1(Model1),
// `candle_transformers::models::gemma2::Model`
V2(Model2),
// `candle_transformers::models::gemma3::Model`
V3(Model3),
}
impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
impl Which {
    /// Returns the `google/...` model identifier associated with the variant,
    /// as a freshly allocated `String`.
    pub fn to_model_id(&self) -> String {
        let id = match self {
            Self::Base2B => "google/gemma-2b",
            Self::Base7B => "google/gemma-7b",
            Self::Instruct2B => "google/gemma-2b-it",
            Self::Instruct7B => "google/gemma-7b-it",
            Self::InstructV1_1_2B => "google/gemma-1.1-2b-it",
            Self::InstructV1_1_7B => "google/gemma-1.1-7b-it",
            Self::CodeBase2B => "google/codegemma-2b",
            Self::CodeBase7B => "google/codegemma-7b",
            Self::CodeInstruct2B => "google/codegemma-2b-it",
            Self::CodeInstruct7B => "google/codegemma-7b-it",
            Self::BaseV2_2B => "google/gemma-2-2b",
            Self::InstructV2_2B => "google/gemma-2-2b-it",
            Self::BaseV2_9B => "google/gemma-2-9b",
            Self::InstructV2_9B => "google/gemma-2-9b-it",
            Self::BaseV3_1B => "google/gemma-3-1b-pt",
            Self::InstructV3_1B => "google/gemma-3-1b-it",
        };
        id.to_string()
    }

    /// Returns `false` for the base variants listed below and `true` for
    /// every other variant.
    pub fn is_instruct_model(&self) -> bool {
        let is_base = matches!(
            self,
            Self::Base2B
                | Self::Base7B
                | Self::CodeBase2B
                | Self::CodeBase7B
                | Self::BaseV2_2B
                | Self::BaseV2_9B
                | Self::BaseV3_1B
        );
        !is_base
    }

    /// Returns `true` only for the two `V3` variants.
    pub fn is_v3_model(&self) -> bool {
        match self {
            Self::BaseV3_1B | Self::InstructV3_1B => true,
            _ => false,
        }
    }
}

View File

@@ -0,0 +1,167 @@
use either::Either;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use utoipa::ToSchema;
/// Inner content structure for messages that can be either a string or key-value pairs
// (De)serialized untagged: `Either::Left` is a plain string, `Either::Right` a
// map from string keys to string values. The OpenAPI schema is written by hand
// in `message_inner_content_schema` because of the `Either`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
    /// Registers the hand-written schema under the name "MessageInnerContent".
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        let schema = utoipa::openapi::RefOr::T(message_inner_content_schema());
        ("MessageInnerContent", schema)
    }
}
/// Function for MessageInnerContent Schema generation to handle `Either`
///
/// Returns a `oneOf` schema with two alternatives, in this order: a plain
/// string (`Either::Left`) and an object whose additional properties are
/// strings (`Either::Right`).
fn message_inner_content_schema() -> utoipa::openapi::Schema {
    use utoipa::openapi::{ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    // Either::Left - simple string
    let plain_text = Schema::Object(ObjectBuilder::new().schema_type(SchemaType::String).build());

    // Either::Right - object with string values
    let string_value = Schema::Object(ObjectBuilder::new().schema_type(SchemaType::String).build());
    let string_map = Schema::Object(
        ObjectBuilder::new()
            .schema_type(SchemaType::Object)
            .additional_properties(Some(RefOr::T(string_value)))
            .build(),
    );

    Schema::OneOf(OneOfBuilder::new().item(plain_text).item(string_map).build())
}
/// Message content that can be either simple text or complex structured content
// (De)serialized untagged: `Either::Left` is a plain string, `Either::Right` a
// list of maps from string keys to `MessageInnerContent`. The OpenAPI schema
// is written by hand in `message_content_schema` because of the `Either`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
    /// Registers the hand-written schema under the name "MessageContent".
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        let schema = utoipa::openapi::RefOr::T(message_content_schema());
        ("MessageContent", schema)
    }
}
/// Function for MessageContent Schema generation to handle `Either`
///
/// Returns a `oneOf` schema with two alternatives, in this order: a plain
/// string, and an array of objects whose additional properties reference the
/// "MessageInnerContent" schema.
fn message_content_schema() -> utoipa::openapi::Schema {
    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    // First alternative: simple text.
    let plain_text = Schema::Object(ObjectBuilder::new().schema_type(SchemaType::String).build());

    // Second alternative: a list of structured parts.
    let part = ObjectBuilder::new()
        .schema_type(SchemaType::Object)
        .additional_properties(Some(RefOr::Ref(
            utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
        )))
        .build();
    let parts = Schema::Array(
        ArrayBuilder::new()
            .items(RefOr::T(Schema::Object(part)))
            .build(),
    );

    Schema::OneOf(OneOfBuilder::new().item(plain_text).item(parts).build())
}
/// Represents a single message in a conversation
// Used in both directions: parsed from requests and serialized in responses.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
// Optional name attached to the message.
pub name: Option<String>,
}
/// Stop token configuration for generation
// Untagged: serialized as the bare payload, i.e. either a JSON array of
// strings (`Multi`) or a single JSON string (`Single`).
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Default value helper: always returns `false`.
///
/// Referenced by name from `#[serde(default = "default_false")]`, which supplies
/// the value for a field missing from the incoming JSON (used for
/// `ChatCompletionRequest::logprobs`).
pub fn default_false() -> bool {
false
}
/// Serde default provider that always yields `1`.
///
/// Referenced by name from `#[serde(default = "default_1usize")]` (used for the
/// `n` field of `ChatCompletionRequest`).
pub fn default_1usize() -> usize {
    1_usize
}
/// Serde default provider that always yields the string "default".
///
/// Referenced by name from `#[serde(default = "default_model")]` (used for
/// `ChatCompletionRequest::model`).
pub fn default_model() -> String {
    String::from("default")
}
/// Chat completion request following OpenAI's specification
// Fields carrying a `#[serde(default = ...)]` attribute take the value of the
// named helper when missing from the JSON.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
// The conversation so far, in order.
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
// Requested model name; "default" when omitted (see `default_model`).
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
// Whether log probabilities are requested; false when omitted.
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
// Upper bound on the number of tokens to generate.
#[schema(example = 256)]
pub max_tokens: Option<usize>,
// Number of choices requested; carried as "n" in JSON, 1 when omitted.
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
// Sampling temperature.
#[schema(example = 0.7)]
pub temperature: Option<f64>,
// Nucleus sampling cutoff.
#[schema(example = 0.9)]
pub top_p: Option<f64>,
// Whether a streamed response is requested.
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
// One entry of `ChatCompletionResponse::choices`. Serialize-only: it is never
// parsed from a request.
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
// Position of the choice within the response.
pub index: usize,
// The generated message.
pub message: Message,
// Why generation ended, as a free-form string.
pub finish_reason: String,
}
/// Chat completion response
// Serialize-only body of a (non-streaming) chat completion.
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
// Response identifier.
pub id: String,
// Object type tag.
pub object: String,
// Creation time as an unsigned integer timestamp — presumably seconds since
// the Unix epoch; confirm against the handler that fills it.
pub created: u64,
// Model name reported back to the client.
pub model: String,
// Generated alternatives.
pub choices: Vec<ChatCompletionChoice>,
// Token accounting for the request.
pub usage: Usage,
}
/// Token usage information
// Serialize-only; the struct itself does not enforce any relation between the
// three counters.
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
// Tokens attributed to the prompt.
pub prompt_tokens: usize,
// Tokens attributed to the generated text.
pub completion_tokens: usize,
// Tokens attributed to prompt and generated text together.
pub total_tokens: usize,
}

View File

@@ -0,0 +1,128 @@
use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
use crate::text_generation::TextGeneration;
use either::Either;
// Application state shared between handlers
//
// Cloned per request by the router; cloning copies the `Arc` handle and the
// model id string, not the pipeline itself.
#[derive(Clone)]
pub struct AppState {
// The single generation pipeline, guarded by an async mutex so that only one
// request drives the model at a time.
pub text_generation: Arc<Mutex<TextGeneration>>,
// Identifier of the loaded model. The chat completions handler in this file
// does not read it (it echoes the request's `model`).
pub model_id: String,
}
// Chat completions endpoint handler
pub async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
// Create the router with the chat completions endpoint
//
// Builds the application's `Router`: one POST route for the OpenAI-compatible
// chat completions endpoint, wrapped in a permissive CORS layer and bound to
// the given shared state.
//
// Credentials are deliberately NOT enabled on the CORS layer: tower-http
// refuses (with a panic when the layer is applied) to combine
// `allow_credentials(true)` with a wildcard origin, header or method list,
// because browsers reject that combination. If credentialed requests are ever
// needed, the wildcards must be replaced by explicit or mirrored values.
pub fn create_router(app_state: AppState) -> Router {
    // CORS layer to allow requests from any origin
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);
    Router::new()
        // OpenAI compatible endpoints
        .route("/v1/chat/completions", post(chat_completions))
        // Add more endpoints as needed
        .layer(cors)
        .with_state(app_state)
}

View File

@@ -0,0 +1,352 @@
use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
use std::io::Write;
use crate::model::Model;
use crate::token_output_stream::TokenOutputStream;
// Text generation pipeline: a loaded model plus everything needed to turn a
// prompt string into sampled text, with an optional CPU fallback path.
pub struct TextGeneration {
// The wrapped model used for every forward pass.
model: Model,
// Primary device: input tensors are created on it and fallback results are
// moved back to it.
device: Device,
// CPU device for fallback when operations are unsupported on primary device
// (`None` when the primary device is already the CPU).
cpu_device: Option<Device>,
// Flag to indicate if we should try to use the primary device first
// (cleared for good after the first successful CPU fallback).
try_primary_device: bool,
// Tokenizer wrapped in a `TokenOutputStream`, used both to encode the prompt
// and to decode tokens incrementally as they are produced.
tokenizer: TokenOutputStream,
// Sampler built from the seed, temperature and top-p passed to `new`.
logits_processor: LogitsProcessor,
// Repetition penalty factor; exactly 1.0 disables the penalty.
repeat_penalty: f32,
// Number of most recent tokens the repetition penalty looks at.
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
/// Builds a generation pipeline around an already-loaded model.
///
/// `seed`, `temp` and `top_p` configure the sampler; `repeat_penalty` and
/// `repeat_last_n` configure the repetition penalty. `device` is cloned and
/// kept as the primary device. A CPU fallback device is recorded (and the
/// primary device is tried first) only when `device` is not itself the CPU.
pub fn new(
    model: Model,
    tokenizer: Tokenizer,
    seed: u64,
    temp: Option<f64>,
    top_p: Option<f64>,
    repeat_penalty: f32,
    repeat_last_n: usize,
    device: &Device,
) -> Self {
    // A fallback only makes sense when the primary device is not the CPU.
    let needs_fallback = !device.is_cpu();
    let cpu_device = if needs_fallback { Some(Device::Cpu) } else { None };
    Self {
        model,
        device: device.clone(),
        cpu_device,
        try_primary_device: needs_fallback,
        tokenizer: TokenOutputStream::new(tokenizer),
        logits_processor: LogitsProcessor::new(seed, temp, top_p),
        repeat_penalty,
        repeat_last_n,
    }
}
// Helper method for model execution with fallback to CPU for unsupported operations
//
// Behaviour:
//   * while `try_primary_device` is set, the forward pass runs on the input as
//     given; if it fails with a "no metal/cuda implementation for ..." error
//     and a CPU device is recorded, the input is moved to the CPU, the pass is
//     retried, the flag is cleared and the result is moved back to the primary
//     device;
//   * once the flag is cleared, the input is always moved to the CPU first
//     (or, when no CPU device is recorded, the pass runs on the input as given);
//   * any other error is returned unchanged (wrapped into the anyhow error type).
//
// NOTE(review): only the input tensor is moved to the CPU here; whether the
// model's own weights can take part in a CPU pass is not visible from this
// function — confirm.
fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
    // Primary device already given up on: go straight to the CPU if there is one.
    if !self.try_primary_device {
        return match &self.cpu_device {
            Some(cpu) => {
                let moved = input.to_device(cpu).map_err(E::msg)?;
                let out = self.model.forward(&moved, start_pos).map_err(E::msg)?;
                out.to_device(&self.device).map_err(E::msg)
            }
            // No CPU fallback, use primary device
            None => self.model.forward(input, start_pos).map_err(E::msg),
        };
    }

    // Try running on the primary device first
    let err = match self.model.forward(input, start_pos) {
        Ok(result) => return Ok(result),
        Err(err) => err,
    };

    // The error type is only inspected through its message text.
    let err_string = err.to_string();
    let unsupported = err_string.contains("no metal implementation for")
        || err_string.contains("no cuda implementation for");
    let cpu = match &self.cpu_device {
        Some(cpu) if unsupported => cpu,
        // Not an unsupported operation error or no CPU fallback
        _ => return Err(E::msg(err)),
    };

    // Extract operation name for better logging
    let op_name = match err_string.find("for ") {
        Some(idx) => &err_string[(idx + 4)..],
        None => "an operation",
    };
    println!("Warning: The primary device does not support {}. Falling back to CPU.", op_name);

    // Move input to CPU and try again
    let moved = input.to_device(cpu).map_err(E::msg)?;
    let out = self.model.forward(&moved, start_pos).map_err(E::msg)?;

    // Don't try primary device for future operations
    self.try_primary_device = false;
    println!("Successfully executed on CPU. Will use CPU for subsequent operations.");

    // Move result back to original device
    out.to_device(&self.device).map_err(E::msg)
}
// Run text generation and print to stdout
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
// Use execute_with_fallback instead of model.forward
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
// Run text generation and write to a buffer
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
// Use execute_with_fallback instead of model.forward
let mut logits = self.execute_with_fallback(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
// Use execute_with_fallback instead of model.forward
logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
// Use execute_with_fallback instead of model.forward
let logits = self.execute_with_fallback(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}

View File

@@ -0,0 +1,86 @@
use candle_core::Result;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    /// Every token id pushed through `next_token` since the last `clear`.
    tokens: Vec<u32>,
    /// Start of the token window that gets re-decoded on each call.
    prev_index: usize,
    /// End of the part of that window whose text has already been returned.
    current_index: usize,
}

impl TokenOutputStream {
    /// Wraps `tokenizer` in an empty stream.
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    /// Consumes the stream and hands back the wrapped tokenizer.
    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    /// Decodes `tokens` with special tokens skipped, mapping tokenizer
    /// failures to a `candle_core` error.
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    /// Text of the window `tokens[prev_index..current_index]`, i.e. the part of
    /// the current window that has already been handed out. Empty when no
    /// token has been pushed yet.
    fn emitted_text(&self) -> Result<String> {
        if self.tokens.is_empty() {
            Ok(String::new())
        } else {
            self.decode(&self.tokens[self.prev_index..self.current_index])
        }
    }

    /// Pushes `token` and returns the newly decodable text, if any.
    ///
    /// Text is only released when the decoded window grew and now ends in an
    /// alphanumeric character; otherwise `None` is returned and the token stays
    /// buffered until a later call or `decode_rest`.
    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = self.emitted_text()?;
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        let ends_alphanumeric = text.chars().last().map_or(false, char::is_alphanumeric);
        if text.len() > prev_text.len() && ends_alphanumeric {
            let (_, fresh) = text.split_at(prev_text.len());
            // Slide the window: what was "current" becomes context for the next call.
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(fresh.to_string()))
        } else {
            Ok(None)
        }
    }

    /// Returns the text of tokens that were pushed but not yet released by
    /// `next_token`, or `None` when nothing is pending.
    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = self.emitted_text()?;
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let (_, rest) = text.split_at(prev_text.len());
            Ok(Some(rest.to_string()))
        } else {
            Ok(None)
        }
    }

    /// Decodes every token pushed since the last `clear`.
    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

    /// Looks up the id of `token_s`, including added/special tokens.
    ///
    /// Uses the tokenizer's direct lookup rather than materialising the whole
    /// vocabulary map for a single query.
    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token_s)
    }

    /// Borrows the wrapped tokenizer.
    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    /// Forgets all buffered tokens and resets the decoding window.
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}

View File

@@ -0,0 +1,167 @@
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result, Tensor};
/// Selects the compute device.
///
/// Returns the CPU when `cpu` is set; otherwise the first CUDA device if CUDA
/// is available, else the first Metal device if Metal is available, else the
/// CPU after printing a hint about the GPU feature flag for this platform.
pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        return Ok(Device::Cpu);
    }
    if cuda_is_available() {
        return Device::new_cuda(0);
    }
    if metal_is_available() {
        return Device::new_metal(0);
    }
    // No accelerator: tell the user which feature would enable one here.
    let hint = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
        "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
    } else {
        "Running on CPU, to run on GPU, build this example with `--features cuda`"
    };
    println!("{hint}");
    Ok(Device::Cpu)
}
/// Loads an image as a `(3, height, width)` u8 tensor on the CPU.
///
/// When `resize_longest` is given, the image is resized so that its longer
/// side equals that value and the shorter side is scaled proportionally
/// (integer division). Also returns the height and width the image had before
/// any resizing.
pub fn load_image<P: AsRef<std::path::Path>>(
    p: P,
    resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
    let decoded = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?;
    let initial_h = decoded.height() as usize;
    let initial_w = decoded.width() as usize;

    let decoded = if let Some(longest) = resize_longest {
        let longest = longest as u32;
        let (src_h, src_w) = (decoded.height(), decoded.width());
        let (dst_h, dst_w) = if src_h < src_w {
            ((longest * src_h) / src_w, longest)
        } else {
            (longest, (longest * src_w) / src_h)
        };
        decoded.resize_exact(dst_w, dst_h, image::imageops::FilterType::CatmullRom)
    } else {
        decoded
    };

    let height = decoded.height() as usize;
    let width = decoded.width() as usize;
    let pixels = decoded.to_rgb8().into_raw();
    // Raw buffer is wrapped as (h, w, c) and then permuted to channel-first.
    let tensor = Tensor::from_vec(pixels, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    Ok((tensor, initial_h, initial_w))
}
/// Loads an image, resizes/crops it to exactly `width` x `height` and returns
/// it as a `(3, height, width)` u8 tensor on the CPU.
///
/// The raw RGB buffer is row-major (`height` rows of `width` pixels), so it is
/// wrapped as `(height, width, 3)` before being permuted to channel-first.
/// The earlier `(width, height, 3)` shape scrambled the pixels of every
/// non-square image and disagreed with `load_image`; square images are
/// unaffected by the change.
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,
    height: usize,
) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?
        .resize_to_fill(
            width as u32,
            height as u32,
            image::imageops::FilterType::Triangle,
        );
    let data = img.to_rgb8().into_raw();
    Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
}
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle_core::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}

View File

@@ -0,0 +1,3 @@
#!/usr/bin/env sh
# Smoke test: build the legacy inference engine in release mode and ask the
# `3-1b-it` model variant a single question.
cargo run -p legacy-inference-engine --release -- \
    --prompt 'Name the 16th President of the USA.' \
    --which 3-1b-it

View File

@@ -0,0 +1,67 @@
use legacy_inference_engine::model::{Model, Which};
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the hub repository id for one variant of each model family.
    #[test]
    fn test_which_to_model_id() {
        let expected_ids = [
            (Which::Base2B, "google/gemma-2b"),
            (Which::Instruct7B, "google/gemma-7b-it"),
            (Which::InstructV1_1_2B, "google/gemma-1.1-2b-it"),
            (Which::CodeBase2B, "google/codegemma-2b"),
            (Which::BaseV2_2B, "google/gemma-2-2b"),
            (Which::InstructV3_1B, "google/gemma-3-1b-it"),
        ];
        for (which, model_id) in expected_ids {
            assert_eq!(which.to_model_id(), model_id);
        }
    }

    /// Base variants must not be flagged as instruct models; instruct variants must.
    #[test]
    fn test_which_is_instruct_model() {
        let base_variants = [
            Which::Base2B,
            Which::Base7B,
            Which::CodeBase2B,
            Which::CodeBase7B,
            Which::BaseV2_2B,
            Which::BaseV2_9B,
            Which::BaseV3_1B,
        ];
        for (idx, which) in base_variants.into_iter().enumerate() {
            assert!(
                !which.is_instruct_model(),
                "base variant #{idx} was reported as an instruct model"
            );
        }

        let instruct_variants = [
            Which::Instruct2B,
            Which::Instruct7B,
            Which::InstructV1_1_2B,
            Which::InstructV1_1_7B,
            Which::CodeInstruct2B,
            Which::CodeInstruct7B,
            Which::InstructV2_2B,
            Which::InstructV2_9B,
            Which::InstructV3_1B,
        ];
        for (idx, which) in instruct_variants.into_iter().enumerate() {
            assert!(
                which.is_instruct_model(),
                "instruct variant #{idx} was not reported as an instruct model"
            );
        }
    }

    /// Only the gemma-3 variants may be flagged as v3 models.
    #[test]
    fn test_which_is_v3_model() {
        let older_variants = [
            Which::Base2B,
            Which::Base7B,
            Which::Instruct2B,
            Which::Instruct7B,
            Which::InstructV1_1_2B,
            Which::InstructV1_1_7B,
            Which::CodeBase2B,
            Which::CodeBase7B,
            Which::CodeInstruct2B,
            Which::CodeInstruct7B,
            Which::BaseV2_2B,
            Which::InstructV2_2B,
            Which::BaseV2_9B,
            Which::InstructV2_9B,
        ];
        for (idx, which) in older_variants.into_iter().enumerate() {
            assert!(
                !which.is_v3_model(),
                "non-v3 variant #{idx} was reported as a v3 model"
            );
        }

        let v3_variants = [Which::BaseV3_1B, Which::InstructV3_1B];
        for (idx, which) in v3_variants.into_iter().enumerate() {
            assert!(
                which.is_v3_model(),
                "v3 variant #{idx} was not reported as a v3 model"
            );
        }
    }

    // Note: exercising `Model::forward` needs real model instances with loaded
    // weights, so it belongs in integration tests (or behind mocks) rather than
    // in these unit tests.
}

View File

@@ -0,0 +1,101 @@
use anyhow::Result;
use candle_transformers::generation::LogitsProcessor;
use legacy_inference_engine::model::Which;
use legacy_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the `google/gemma-2b` tokenizer via `Tokenizer::from_pretrained`.
    ///
    /// Failures (for example missing network access or credentials) are
    /// returned as errors instead of panicking, so callers using `?` report
    /// them as ordinary test failures.
    fn create_test_tokenizer() -> Result<Tokenizer> {
        Tokenizer::from_pretrained("google/gemma-2b", None).map_err(anyhow::Error::msg)
    }

    // Test the Which enum's to_model_id method
    #[test]
    fn test_which_model_id() {
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
    }

    // Test the Which enum's is_instruct_model method
    #[test]
    fn test_which_is_instruct() {
        assert!(!Which::Base2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
    }

    // Test the Which enum's is_v3_model method
    #[test]
    fn test_which_is_v3() {
        assert!(!Which::Base2B.is_v3_model());
        assert!(Which::BaseV3_1B.is_v3_model());
    }

    /// Pushing an encoded sentence token by token and decoding everything must
    /// reproduce the sentence (modulo surrounding whitespace).
    #[test]
    fn test_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        let text = "Hello, world!";
        let encoded = token_stream
            .tokenizer()
            .encode(text, true)
            .map_err(anyhow::Error::msg)?;

        // Add tokens one by one
        for &token_id in encoded.get_ids() {
            token_stream.next_token(token_id)?;
        }

        // Decode all and check
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), text);
        Ok(())
    }

    /// Construction-only smoke test for `LogitsProcessor`.
    ///
    /// Sampling from a tensor with known values is not covered here; this only
    /// checks that the constructor accepts a seed, temperature and top-p.
    #[test]
    fn test_logits_processor() -> Result<()> {
        let seed = 42;
        let temp = Some(0.8);
        let top_p = Some(0.9);
        let _logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Ok(())
    }

    /// Placeholder for a `TextGeneration::new` test.
    ///
    /// A real test needs a mock `Model`: create the mock and a tokenizer, call
    /// `TextGeneration::new`, then verify the properties of the instance. Until
    /// such a mock exists this test intentionally checks nothing.
    #[test]
    fn test_text_generation_constructor() -> Result<()> {
        Ok(())
    }

    // Note: Testing the actual text generation functionality would require
    // integration tests with real models, which is beyond the scope of these unit tests.
    // The tests above focus on the components that can be tested in isolation.
}

View File

@@ -0,0 +1,129 @@
use legacy_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the `google/gemma-2b` tokenizer via `Tokenizer::from_pretrained`.
    ///
    /// Failures (for example missing network access or credentials) are
    /// returned as errors instead of panicking, so callers using `?` report
    /// them as ordinary test failures.
    fn create_test_tokenizer() -> Result<Tokenizer> {
        Tokenizer::from_pretrained("google/gemma-2b", None).map_err(anyhow::Error::msg)
    }

    /// A freshly built stream exposes a tokenizer with a non-empty vocabulary.
    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);
        assert!(!token_stream.tokenizer().get_vocab(true).is_empty());
        Ok(())
    }

    /// `clear` drops every buffered token, so decoding afterwards yields "".
    #[test]
    fn test_clear() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Add a token
        let token_id = token_stream
            .get_token("<eos>")
            .expect("the gemma tokenizer should define <eos>");
        token_stream.next_token(token_id)?;

        // Clear the stream
        token_stream.clear();

        // Check that the stream is empty by trying to decode all
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded, "");
        Ok(())
    }

    /// `get_token` finds known special tokens and returns `None` for unknown ones.
    #[test]
    fn test_get_token() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get a token that should exist
        assert!(token_stream.get_token("<eos>").is_some());

        // Get a token that shouldn't exist
        assert!(token_stream
            .get_token("<this_token_does_not_exist>")
            .is_none());
        Ok(())
    }

    /// Streaming tokens through `next_token` plus a final `decode_rest` must
    /// reassemble the original text.
    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream
            .tokenizer()
            .encode("Hello world", true)
            .map_err(anyhow::Error::msg)?;

        // Add tokens one by one
        let mut output = String::new();
        for &token_id in hello_tokens.get_ids() {
            if let Some(text) = token_stream.next_token(token_id)? {
                output.push_str(&text);
            }
        }

        // Get any remaining text
        if let Some(rest) = token_stream.decode_rest()? {
            output.push_str(&rest);
        }

        // Check the output
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");
        Ok(())
    }

    /// `decode_all` returns the text of every token pushed so far.
    #[test]
    fn test_decode_all() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream
            .tokenizer()
            .encode("Hello world", true)
            .map_err(anyhow::Error::msg)?;

        // Add tokens one by one
        for &token_id in hello_tokens.get_ids() {
            token_stream.next_token(token_id)?;
        }

        // Decode all and check the output
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), "Hello world");
        Ok(())
    }

    /// `into_inner` hands back a tokenizer that is still usable.
    #[test]
    fn test_into_inner() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get the inner tokenizer
        let inner_tokenizer = token_stream.into_inner();

        // Check that the inner tokenizer works
        let encoded = inner_tokenizer
            .encode("Test", true)
            .map_err(anyhow::Error::msg)?;
        assert!(!encoded.get_ids().is_empty());
        Ok(())
    }
}