update structure to improve testability
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
OPENAI_API_KEY="your-key-goes-here"
|
OPENAI_API_KEY="your-key-goes-here"
|
||||||
OPENAI_API_BASE="http://localhost:3000/v1"
|
OPENAI_API_BASE="http://localhost:3777/v1"
|
||||||
GENAISCRIPT_MODEL_LARGE="gemma-3-1b-it"
|
GENAISCRIPT_MODEL_LARGE="gemma-3-1b-it"
|
||||||
GENAISCRIPT_MODEL_SMALL="gemma-3-1b-it"
|
GENAISCRIPT_MODEL_SMALL="gemma-3-1b-it"
|
||||||
SEARXNG_API_BASE_URL="http://localhost:8080"
|
SEARXNG_API_BASE_URL="http://localhost:8080"
|
||||||
|
4
local_inference_engine/Cargo.lock
generated
4
local_inference_engine/Cargo.lock
generated
@@ -1962,7 +1962,7 @@ dependencies = [
|
|||||||
name = "hyper-rustls"
|
name = "hyper-rustls"
|
||||||
version = "0.27.6"
|
version = "0.27.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d"
|
checksum = "03a01595e11bdcec50946522c32dde3fc6914743777a68b93777965f2f02406d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"hyper",
|
"hyper",
|
||||||
@@ -3834,7 +3834,7 @@ dependencies = [
|
|||||||
name = "reborrow"
|
name = "reborrow"
|
||||||
version = "0.5.5"
|
version = "0.5.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
|
checksum = "03251193777f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
|
@@ -63,15 +63,15 @@ cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it
|
|||||||
Run the inference engine in server mode to expose an OpenAI-compatible API:
|
Run the inference engine in server mode to expose an OpenAI-compatible API:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --release -- --server --port 3000 --which 3-1b-it
|
cargo run --release -- --server --port 3777 --which 3-1b-it
|
||||||
```
|
```
|
||||||
|
|
||||||
This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint.
|
This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint.
|
||||||
|
|
||||||
#### Server Options
|
#### Server Options
|
||||||
|
|
||||||
- `--server`: Run in server mode
|
- `--server`: Run in server mode
|
||||||
- `--port <INT>`: Port to use for the server (default: 3000)
|
- `--port <INT>`: Port to use for the server (default: 3777)
|
||||||
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
||||||
- Other model options as described in CLI mode
|
- Other model options as described in CLI mode
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ POST /v1/chat/completions
|
|||||||
### Example: Using cURL
|
### Example: Using cURL
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -X POST http://localhost:3000/v1/chat/completions \
|
curl -X POST http://localhost:3777/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "gemma-3-1b-it",
|
"model": "gemma-3-1b-it",
|
||||||
@@ -148,7 +148,7 @@ curl -X POST http://localhost:3000/v1/chat/completions \
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
base_url="http://localhost:3000/v1",
|
base_url="http://localhost:3777/v1",
|
||||||
api_key="dummy" # API key is not validated but required by the client
|
api_key="dummy" # API key is not validated but required by the client
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ print(response.choices[0].message.content)
|
|||||||
import OpenAI from 'openai';
|
import OpenAI from 'openai';
|
||||||
|
|
||||||
const openai = new OpenAI({
|
const openai = new OpenAI({
|
||||||
baseURL: 'http://localhost:3000/v1',
|
baseURL: 'http://localhost:3777/v1',
|
||||||
apiKey: 'dummy', // API key is not validated but required by the client
|
apiKey: 'dummy', // API key is not validated but required by the client
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -93,7 +93,7 @@
|
|||||||
<div class="settings">
|
<div class="settings">
|
||||||
<div>
|
<div>
|
||||||
<label for="serverUrl">Server URL:</label>
|
<label for="serverUrl">Server URL:</label>
|
||||||
<input type="text" id="serverUrl" value="http://localhost:3000" />
|
<input type="text" id="serverUrl" value="http://localhost:3777" />
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label for="model">Model:</label>
|
<label for="model">Model:</label>
|
||||||
|
@@ -6,7 +6,7 @@
|
|||||||
(async function testBasicChatCompletion() {
|
(async function testBasicChatCompletion() {
|
||||||
console.log("Test 1: Basic chat completion request");
|
console.log("Test 1: Basic chat completion request");
|
||||||
try {
|
try {
|
||||||
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
const response = await fetch('http://localhost:3777/v1/chat/completions', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -34,7 +34,7 @@
|
|||||||
(async function testMultiTurnConversation() {
|
(async function testMultiTurnConversation() {
|
||||||
console.log("\nTest 2: Multi-turn conversation");
|
console.log("\nTest 2: Multi-turn conversation");
|
||||||
try {
|
try {
|
||||||
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
const response = await fetch('http://localhost:3777/v1/chat/completions', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -74,7 +74,7 @@
|
|||||||
(async function testTemperatureAndTopP() {
|
(async function testTemperatureAndTopP() {
|
||||||
console.log("\nTest 3: Request with temperature and top_p parameters");
|
console.log("\nTest 3: Request with temperature and top_p parameters");
|
||||||
try {
|
try {
|
||||||
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
const response = await fetch('http://localhost:3777/v1/chat/completions', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -104,7 +104,7 @@
|
|||||||
(async function testStreaming() {
|
(async function testStreaming() {
|
||||||
console.log("\nTest 4: Request with streaming enabled");
|
console.log("\nTest 4: Request with streaming enabled");
|
||||||
try {
|
try {
|
||||||
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
const response = await fetch('http://localhost:3777/v1/chat/completions', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -148,7 +148,7 @@
|
|||||||
(async function testDifferentModel() {
|
(async function testDifferentModel() {
|
||||||
console.log("\nTest 5: Request with a different model");
|
console.log("\nTest 5: Request with a different model");
|
||||||
try {
|
try {
|
||||||
const response = await fetch('http://localhost:3000/v1/chat/completions', {
|
const response = await fetch('http://localhost:3777/v1/chat/completions', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
72
local_inference_engine/src/cli.rs
Normal file
72
local_inference_engine/src/cli.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use crate::model::Which;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
pub struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
pub cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
pub tracing: bool,
|
||||||
|
|
||||||
|
/// Run in server mode with OpenAI compatible API
|
||||||
|
#[arg(long)]
|
||||||
|
pub server: bool,
|
||||||
|
|
||||||
|
/// Port to use for the server
|
||||||
|
#[arg(long, default_value_t = 3777)]
|
||||||
|
pub port: u16,
|
||||||
|
|
||||||
|
/// Prompt for text generation (not used in server mode)
|
||||||
|
#[arg(long)]
|
||||||
|
pub prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
pub seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
pub sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pub model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
pub revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pub tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pub config_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pub weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
pub repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
pub repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model to use.
|
||||||
|
#[arg(long, default_value = "3-1b-it")]
|
||||||
|
pub which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
pub use_flash_attn: bool,
|
||||||
|
}
|
@@ -652,7 +652,7 @@ struct Args {
|
|||||||
server: bool,
|
server: bool,
|
||||||
|
|
||||||
/// Port to use for the server
|
/// Port to use for the server
|
||||||
#[arg(long, default_value_t = 3000)]
|
#[arg(long, default_value_t = 3777)]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
/// Prompt for text generation (not used in server mode)
|
/// Prompt for text generation (not used in server mode)
|
||||||
|
90
local_inference_engine/src/model.rs
Normal file
90
local_inference_engine/src/model.rs
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
use candle_core::Tensor;
|
||||||
|
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||||
|
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||||
|
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
pub enum Which {
|
||||||
|
#[value(name = "2b")]
|
||||||
|
Base2B,
|
||||||
|
#[value(name = "7b")]
|
||||||
|
Base7B,
|
||||||
|
#[value(name = "2b-it")]
|
||||||
|
Instruct2B,
|
||||||
|
#[value(name = "7b-it")]
|
||||||
|
Instruct7B,
|
||||||
|
#[value(name = "1.1-2b-it")]
|
||||||
|
InstructV1_1_2B,
|
||||||
|
#[value(name = "1.1-7b-it")]
|
||||||
|
InstructV1_1_7B,
|
||||||
|
#[value(name = "code-2b")]
|
||||||
|
CodeBase2B,
|
||||||
|
#[value(name = "code-7b")]
|
||||||
|
CodeBase7B,
|
||||||
|
#[value(name = "code-2b-it")]
|
||||||
|
CodeInstruct2B,
|
||||||
|
#[value(name = "code-7b-it")]
|
||||||
|
CodeInstruct7B,
|
||||||
|
#[value(name = "2-2b")]
|
||||||
|
BaseV2_2B,
|
||||||
|
#[value(name = "2-2b-it")]
|
||||||
|
InstructV2_2B,
|
||||||
|
#[value(name = "2-9b")]
|
||||||
|
BaseV2_9B,
|
||||||
|
#[value(name = "2-9b-it")]
|
||||||
|
InstructV2_9B,
|
||||||
|
#[value(name = "3-1b")]
|
||||||
|
BaseV3_1B,
|
||||||
|
#[value(name = "3-1b-it")]
|
||||||
|
InstructV3_1B,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Model {
|
||||||
|
V1(Model1),
|
||||||
|
V2(Model2),
|
||||||
|
V3(Model3),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::V1(m) => m.forward(input_ids, pos),
|
||||||
|
Self::V2(m) => m.forward(input_ids, pos),
|
||||||
|
Self::V3(m) => m.forward(input_ids, pos),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
pub fn to_model_id(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||||
|
Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||||
|
Self::Base2B => "google/gemma-2b".to_string(),
|
||||||
|
Self::Base7B => "google/gemma-7b".to_string(),
|
||||||
|
Self::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||||
|
Self::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||||
|
Self::CodeBase2B => "google/codegemma-2b".to_string(),
|
||||||
|
Self::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||||
|
Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||||
|
Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||||
|
Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||||
|
Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||||
|
Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||||
|
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||||
|
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||||
|
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_instruct_model(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_v3_model(&self) -> bool {
|
||||||
|
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
|
||||||
|
}
|
||||||
|
}
|
167
local_inference_engine/src/openai_types.rs
Normal file
167
local_inference_engine/src/openai_types.rs
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
use either::Either;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
/// Inner content structure for messages that can be either a string or key-value pairs
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct MessageInnerContent(
|
||||||
|
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
||||||
|
);
|
||||||
|
|
||||||
|
impl ToSchema<'_> for MessageInnerContent {
|
||||||
|
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||||
|
(
|
||||||
|
"MessageInnerContent",
|
||||||
|
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Function for MessageInnerContent Schema generation to handle `Either`
|
||||||
|
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||||
|
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||||
|
|
||||||
|
Schema::OneOf(
|
||||||
|
OneOfBuilder::new()
|
||||||
|
// Either::Left - simple string
|
||||||
|
.item(Schema::Object(
|
||||||
|
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||||
|
))
|
||||||
|
// Either::Right - object with string values
|
||||||
|
.item(Schema::Object(
|
||||||
|
ObjectBuilder::new()
|
||||||
|
.schema_type(SchemaType::Object)
|
||||||
|
.additional_properties(Some(RefOr::T(Schema::Object(
|
||||||
|
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||||
|
))))
|
||||||
|
.build(),
|
||||||
|
))
|
||||||
|
.build(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message content that can be either simple text or complex structured content
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct MessageContent(
|
||||||
|
#[serde(with = "either::serde_untagged")]
|
||||||
|
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||||
|
);
|
||||||
|
|
||||||
|
impl ToSchema<'_> for MessageContent {
|
||||||
|
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||||
|
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Function for MessageContent Schema generation to handle `Either`
|
||||||
|
fn message_content_schema() -> utoipa::openapi::Schema {
|
||||||
|
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||||
|
|
||||||
|
Schema::OneOf(
|
||||||
|
OneOfBuilder::new()
|
||||||
|
.item(Schema::Object(
|
||||||
|
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||||
|
))
|
||||||
|
.item(Schema::Array(
|
||||||
|
ArrayBuilder::new()
|
||||||
|
.items(RefOr::T(Schema::Object(
|
||||||
|
ObjectBuilder::new()
|
||||||
|
.schema_type(SchemaType::Object)
|
||||||
|
.additional_properties(Some(RefOr::Ref(
|
||||||
|
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
||||||
|
)))
|
||||||
|
.build(),
|
||||||
|
)))
|
||||||
|
.build(),
|
||||||
|
))
|
||||||
|
.build(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a single message in a conversation
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||||
|
pub struct Message {
|
||||||
|
/// The message content
|
||||||
|
pub content: Option<MessageContent>,
|
||||||
|
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
||||||
|
pub role: String,
|
||||||
|
pub name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop token configuration for generation
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum StopTokens {
|
||||||
|
/// Multiple possible stop sequences
|
||||||
|
Multi(Vec<String>),
|
||||||
|
/// Single stop sequence
|
||||||
|
Single(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default value helper
|
||||||
|
pub fn default_false() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default value helper
|
||||||
|
pub fn default_1usize() -> usize {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default value helper
|
||||||
|
pub fn default_model() -> String {
|
||||||
|
"default".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chat completion request following OpenAI's specification
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
#[schema(example = "gemma-3-1b-it")]
|
||||||
|
#[serde(default = "default_model")]
|
||||||
|
pub model: String,
|
||||||
|
#[serde(default = "default_false")]
|
||||||
|
#[schema(example = false)]
|
||||||
|
pub logprobs: bool,
|
||||||
|
#[schema(example = 256)]
|
||||||
|
pub max_tokens: Option<usize>,
|
||||||
|
#[serde(rename = "n")]
|
||||||
|
#[serde(default = "default_1usize")]
|
||||||
|
#[schema(example = 1)]
|
||||||
|
pub n_choices: usize,
|
||||||
|
#[schema(example = 0.7)]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[schema(example = 0.9)]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[schema(example = false)]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chat completion choice
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct ChatCompletionChoice {
|
||||||
|
pub index: usize,
|
||||||
|
pub message: Message,
|
||||||
|
pub finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chat completion response
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token usage information
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: usize,
|
||||||
|
pub completion_tokens: usize,
|
||||||
|
pub total_tokens: usize,
|
||||||
|
}
|
126
local_inference_engine/src/server.rs
Normal file
126
local_inference_engine/src/server.rs
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
routing::{get, post},
|
||||||
|
Json, Router,
|
||||||
|
};
|
||||||
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
|
||||||
|
use crate::text_generation::TextGeneration;
|
||||||
|
use either::Either;
|
||||||
|
|
||||||
|
// Application state shared between handlers
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub text_generation: Arc<Mutex<TextGeneration>>,
|
||||||
|
pub model_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat completions endpoint handler
|
||||||
|
pub async fn chat_completions(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(request): Json<ChatCompletionRequest>,
|
||||||
|
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
// Convert messages to a prompt string
|
||||||
|
for message in &request.messages {
|
||||||
|
let role = &message.role;
|
||||||
|
let content = match &message.content {
|
||||||
|
Some(content) => match &content.0 {
|
||||||
|
Either::Left(text) => text.clone(),
|
||||||
|
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||||
|
},
|
||||||
|
None => "".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Format based on role
|
||||||
|
match role.as_str() {
|
||||||
|
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||||
|
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||||
|
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||||
|
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the assistant prefix for the response
|
||||||
|
prompt.push_str("Assistant: ");
|
||||||
|
|
||||||
|
// Capture the output
|
||||||
|
let mut output = Vec::new();
|
||||||
|
{
|
||||||
|
let mut text_gen = state.text_generation.lock().await;
|
||||||
|
|
||||||
|
// Buffer to capture the output
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
|
||||||
|
// Run text generation
|
||||||
|
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||||
|
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin local_inference_engine -- --prompt \"Your prompt here\"",
|
||||||
|
"type": "unsupported_api"
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert buffer to string
|
||||||
|
if let Ok(text) = String::from_utf8(buffer) {
|
||||||
|
output.push(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response
|
||||||
|
let response = ChatCompletionResponse {
|
||||||
|
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
|
||||||
|
object: "chat.completion".to_string(),
|
||||||
|
created: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
model: request.model,
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
message: Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: Some(MessageContent(Either::Left(output.join("")))),
|
||||||
|
name: None,
|
||||||
|
},
|
||||||
|
finish_reason: "stop".to_string(),
|
||||||
|
}],
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: prompt.len() / 4, // Rough estimate
|
||||||
|
completion_tokens: output.join("").len() / 4, // Rough estimate
|
||||||
|
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return the response as JSON
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the router with the chat completions endpoint
|
||||||
|
pub fn create_router(app_state: AppState) -> Router {
|
||||||
|
// CORS layer to allow requests from any origin
|
||||||
|
let cors = CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_headers(Any);
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
// OpenAI compatible endpoints
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
|
// Add more endpoints as needed
|
||||||
|
.layer(cors)
|
||||||
|
.with_state(app_state)
|
||||||
|
}
|
277
local_inference_engine/src/text_generation.rs
Normal file
277
local_inference_engine/src/text_generation.rs
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
use crate::model::Model;
|
||||||
|
use crate::token_output_stream::TokenOutputStream;
|
||||||
|
|
||||||
|
pub struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run text generation and print to stdout
|
||||||
|
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <eos> token"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => {
|
||||||
|
println!(
|
||||||
|
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||||
|
);
|
||||||
|
eos_token
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
|
||||||
|
// Manual implementation of repeat penalty to avoid type conflicts
|
||||||
|
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
for &token_id in &tokens[start_at..] {
|
||||||
|
let token_id = token_id as usize;
|
||||||
|
if token_id < logits_vec.len() {
|
||||||
|
let score = logits_vec[token_id];
|
||||||
|
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||||
|
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new tensor with the modified logits
|
||||||
|
let device = logits.device().clone();
|
||||||
|
let shape = logits.shape().clone();
|
||||||
|
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||||
|
new_logits.reshape(shape)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token || next_token == eot_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run text generation and write to a buffer
|
||||||
|
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
// Write prompt tokens to output
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
write!(output, "{}", t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <eos> token"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => {
|
||||||
|
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
||||||
|
eos_token
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Determine if we're using a Model3 (gemma-3) variant
|
||||||
|
let is_model3 = match &self.model {
|
||||||
|
Model::V3(_) => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// For Model3, we need to use a different approach
|
||||||
|
if is_model3 {
|
||||||
|
// For gemma-3 models, we'll generate one token at a time with the full context
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
|
||||||
|
// Initial generation with the full prompt
|
||||||
|
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||||
|
let mut logits = self.model.forward(&input, 0)?;
|
||||||
|
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
|
for _ in 0..sample_len {
|
||||||
|
// Apply repeat penalty if needed
|
||||||
|
let current_logits = if self.repeat_penalty == 1. {
|
||||||
|
logits.clone()
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
|
||||||
|
// Manual implementation of repeat penalty to avoid type conflicts
|
||||||
|
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
for &token_id in &tokens[start_at..] {
|
||||||
|
let token_id = token_id as usize;
|
||||||
|
if token_id < logits_vec.len() {
|
||||||
|
let score = logits_vec[token_id];
|
||||||
|
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||||
|
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new tensor with the modified logits
|
||||||
|
let device = logits.device().clone();
|
||||||
|
let shape = logits.shape().clone();
|
||||||
|
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||||
|
new_logits.reshape(shape)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(¤t_logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
|
||||||
|
if next_token == eos_token || next_token == eot_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
write!(output, "{}", t)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the next iteration, just use the new token
|
||||||
|
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
||||||
|
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
||||||
|
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard approach for other models
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
|
||||||
|
// Manual implementation of repeat penalty to avoid type conflicts
|
||||||
|
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
|
for &token_id in &tokens[start_at..] {
|
||||||
|
let token_id = token_id as usize;
|
||||||
|
if token_id < logits_vec.len() {
|
||||||
|
let score = logits_vec[token_id];
|
||||||
|
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||||
|
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new tensor with the modified logits
|
||||||
|
let device = logits.device().clone();
|
||||||
|
let shape = logits.shape().clone();
|
||||||
|
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||||
|
new_logits.reshape(shape)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token || next_token == eot_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
write!(output, "{}", t)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write any remaining tokens
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
write!(output, "{}", rest)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@@ -47,7 +47,7 @@ const {text} = await host.fetchText(new URL(url).toString());
|
|||||||
// browser: getBrowser(),
|
// browser: getBrowser(),
|
||||||
// headless: true,
|
// headless: true,
|
||||||
// javaScriptEnabled: browser !== "chromium",
|
// javaScriptEnabled: browser !== "chromium",
|
||||||
// // timeout: 3000,
|
// // timeout: 3777,
|
||||||
// // bypassCSP: true,
|
// // bypassCSP: true,
|
||||||
// // baseUrl: new URL(url).origin,
|
// // baseUrl: new URL(url).origin,
|
||||||
// });
|
// });
|
||||||
|
@@ -145,7 +145,7 @@ ui:
|
|||||||
# Note: since commit af77ec3, morty accepts a base64 encoded key.
|
# Note: since commit af77ec3, morty accepts a base64 encoded key.
|
||||||
#
|
#
|
||||||
# result_proxy:
|
# result_proxy:
|
||||||
# url: http://127.0.0.1:3000/
|
# url: http://127.0.0.1:3777/
|
||||||
# # the key is a base64 encoded string, the YAML !!binary prefix is optional
|
# # the key is a base64 encoded string, the YAML !!binary prefix is optional
|
||||||
# key: !!binary "your_morty_proxy_key"
|
# key: !!binary "your_morty_proxy_key"
|
||||||
# # [true|false] enable the "proxy" button next to each result
|
# # [true|false] enable the "proxy" button next to each result
|
||||||
|
Reference in New Issue
Block a user