reorg + update docs with new paths

geoffsee
2025-09-04 12:27:13 -04:00
parent 400c70f17d
commit ff55d882c7
43 changed files with 493 additions and 182 deletions

.gitignore vendored
View File

@@ -77,3 +77,4 @@ venv/
!/scripts/cli.ts
/**/.*.bun-build
/AGENTS.md
.claude

Cargo.lock generated
View File

@@ -2905,6 +2905,7 @@ dependencies = [
"clap",
"cpal",
"either",
"embeddings-engine",
"futures-util",
"gemma-runner",
"imageproc 0.24.0",

View File

@@ -3,12 +3,12 @@ members = [
"crates/predict-otron-9000",
"crates/inference-engine",
"crates/embeddings-engine",
"crates/helm-chart-tool",
"crates/llama-runner",
"crates/gemma-runner",
"crates/cli",
"integration/helm-chart-tool",
"integration/llama-runner",
"integration/gemma-runner",
"integration/cli",
"crates/chat-ui"
, "crates/utils"]
, "integration/utils"]
default-members = ["crates/predict-otron-9000"]
resolver = "2"

View File

@@ -53,14 +53,17 @@ The project uses a 9-crate Rust workspace plus TypeScript components:
crates/
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
└── chat-ui/ # WASM web frontend (Rust 2021)
integration/
├── cli/ # CLI client crate (Rust 2024)
│ └── package/
│ └── cli.ts # TypeScript/Bun CLI client
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
├── chat-ui/ # WASM web frontend (Rust 2021)
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
└── cli/ # CLI client crate (Rust 2024)
└── package/
└── cli.ts # TypeScript/Bun CLI client
└── utils/ # Shared utilities (Rust 2021)
```
### Service Architecture
@@ -160,16 +163,16 @@ cd crates/chat-ui
#### TypeScript CLI Client
```bash
# List available models
cd crates/cli/package && bun run cli.ts --list-models
cd integration/cli/package && bun run cli.ts --list-models
# Chat completion
cd crates/cli/package && bun run cli.ts "What is the capital of France?"
cd integration/cli/package && bun run cli.ts "What is the capital of France?"
# With specific model
cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
# Show help
cd crates/cli/package && bun run cli.ts --help
cd integration/cli/package && bun run cli.ts --help
```
## API Usage
@@ -464,7 +467,7 @@ curl -s http://localhost:8080/v1/models | jq
**CLI client test:**
```bash
cd crates/cli/package && bun run cli.ts "What is 2+2?"
cd integration/cli/package && bun run cli.ts "What is 2+2?"
```
**Web frontend:**

View File

@@ -1,43 +1,183 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tower_http::trace::TraceLayer;
use tracing;
// Persistent model instance (singleton pattern)
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)");
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> = Lazy::new(|| {
RwLock::new(HashMap::new())
});
#[derive(Serialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub owned_by: String,
pub description: String,
pub dimensions: usize,
}
#[derive(Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
"BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
"BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
"BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
match model {
EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
EmbeddingModel::MultilingualE5Small => 384,
EmbeddingModel::MultilingualE5Base => 768,
EmbeddingModel::MultilingualE5Large => 1024,
EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
EmbeddingModel::ClipVitB32 => 512,
EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
}
}
// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
)
.expect("Failed to initialize persistent embedding model");
.map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
let model_init_time = model_start_time.elapsed();
tracing::info!(
"Persistent embedding model initialized in {:.2?}",
"Embedding model {:?} initialized in {:.2?}",
embedding_model,
model_init_time
);
model
});
let model_arc = Arc::new(model);
cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
Ok(model_arc)
}
pub async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
// Start timing the entire process
let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance
// Phase 1: Parse and get the embedding model
let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let embedding_model = match parse_embedding_model(&payload.model) {
Ok(model) => model,
Err(e) => {
tracing::error!("Invalid model requested: {}", e);
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
}
};
let model = match get_or_create_model(embedding_model.clone()) {
Ok(model) => model,
Err(e) => {
tracing::error!("Failed to get/create model: {}", e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e)));
}
};
let model_access_time = model_start_time.elapsed();
tracing::debug!(
"Persistent model access completed in {:.2?}",
"Model access/creation completed in {:.2?}",
model_access_time
);
@@ -65,9 +205,12 @@ pub async fn embeddings_create(
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
.map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e))
})?;
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!(
@@ -117,8 +260,9 @@ pub async fn embeddings_create(
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
let expected_dimensions = get_model_dimensions(&embedding_model);
let mut random_embedding = Vec::with_capacity(expected_dimensions);
for _ in 0..expected_dimensions {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
@@ -138,18 +282,19 @@ pub async fn embeddings_create(
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
let padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
// Use the actual model dimensions instead of hardcoded 768
let actual_dimensions = padded_embedding.len();
let expected_dimensions = get_model_dimensions(&embedding_model);
if actual_dimensions != expected_dimensions {
tracing::warn!(
"Model {:?} produced {} dimensions but expected {}",
embedding_model,
actual_dimensions,
expected_dimensions
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
@@ -203,11 +348,232 @@ pub async fn embeddings_create(
postprocessing_time
);
ResponseJson(response)
Ok(ResponseJson(response))
}
pub async fn models_list() -> ResponseJson<ModelsResponse> {
let models = vec![
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the small Chinese model".to_string(),
dimensions: 512,
},
ModelInfo {
id: "BAAI/bge-large-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large Chinese model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(),
dimensions: 768,
},
ModelInfo {
id: "lightonai/modernbert-embed-large".to_string(),
object: "model".to_string(),
owned_by: "lightonai".to_string(),
description: "Large model of ModernBert Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "intfloat/multilingual-e5-small".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Small model of multilingual E5 Text Embeddings".to_string(),
dimensions: 384,
},
ModelInfo {
id: "intfloat/multilingual-e5-base".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Base model of multilingual E5 Text Embeddings".to_string(),
dimensions: 768,
},
ModelInfo {
id: "intfloat/multilingual-e5-large".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Large model of multilingual E5 Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Quantized Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Qdrant/clip-ViT-B-32-text".to_string(),
object: "model".to_string(),
owned_by: "Qdrant".to_string(),
description: "CLIP text encoder based on ViT-B/32".to_string(),
dimensions: 512,
},
ModelInfo {
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
object: "model".to_string(),
owned_by: "jinaai".to_string(),
description: "Jina embeddings v2 base code".to_string(),
dimensions: 768,
},
];
ResponseJson(ModelsResponse {
object: "list".to_string(),
data: models,
})
}
pub fn create_embeddings_router() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
// .route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
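
The model cache above follows the classic double-checked locking pattern: a cheap read lock on the hot path, then a write lock with a re-check before the expensive initialization. A minimal, self-contained sketch of the same pattern (a plain `String` key and a stub struct stand in for `EmbeddingModel` and `TextEmbedding`):

```rust
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Stand-in for an expensive-to-initialize model handle.
struct Model(String);

static CACHE: Lazy<RwLock<HashMap<String, Arc<Model>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

fn get_or_create(key: &str) -> Result<Arc<Model>, String> {
    // Fast path: shared read lock; concurrent lookups don't block each other.
    if let Some(m) = CACHE.read().map_err(|e| e.to_string())?.get(key) {
        return Ok(Arc::clone(m));
    }
    // Slow path: exclusive write lock.
    let mut cache = CACHE.write().map_err(|e| e.to_string())?;
    // Double-check: another thread may have inserted while we waited.
    if let Some(m) = cache.get(key) {
        return Ok(Arc::clone(m));
    }
    let model = Arc::new(Model(key.to_string())); // expensive init would go here
    cache.insert(key.to_string(), Arc::clone(&model));
    Ok(model)
}

fn main() {
    let a = get_or_create("nomic-embed-text-v1.5").unwrap();
    let b = get_or_create("nomic-embed-text-v1.5").unwrap();
    assert!(Arc::ptr_eq(&a, &b)); // the second call hits the cache
}
```

The re-check after taking the write lock is what prevents two threads that both missed on the read lock from initializing the same model twice.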

View File

@@ -4,8 +4,6 @@ use axum::{
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use std::env;
use tower_http::trace::TraceLayer;
use tracing;
@@ -13,127 +11,30 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
use embeddings_engine;
async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
)
.expect("Failed to initialize model");
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
panic!("Integer array input not supported for text embeddings");
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
match embeddings_engine::embeddings_create(Json(payload)).await {
Ok(response) => Ok(response),
Err((status_code, message)) => {
Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap())
}
EmbeddingInput::ArrayOfIntegerArray(_) => {
panic!("Array of integer arrays not supported for text embeddings");
}
};
}
}
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
if all_zeros {
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
val = rng.gen_range(-1.0..1.0);
}
random_embedding.push(val);
}
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
}
};
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": final_embedding
}
],
"model": payload.model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
});
ResponseJson(response)
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
embeddings_engine::models_list().await
}
fn create_app() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
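
The wrapper above hand-builds an `axum::response::Response` from the engine's `(StatusCode, String)` error tuple. For reference, axum already implements `IntoResponse` for that tuple, so the same adaptation can be written without `Response::builder()`; a sketch under that assumption, with a hypothetical stub standing in for the engine call:

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};

// Hypothetical stand-in for embeddings_engine::embeddings_create.
async fn engine_call() -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    Err((StatusCode::BAD_REQUEST, "Invalid model".to_string()))
}

async fn embeddings_create() -> Response {
    match engine_call().await {
        Ok(json) => json.into_response(),
        // (StatusCode, String) implements IntoResponse, so no manual
        // Response::builder() plumbing is required.
        Err(err) => err.into_response(),
    }
}
```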

View File

@@ -31,8 +31,9 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner", features = ["metal"] }
llama-runner = { path = "../llama-runner", features = ["metal"]}
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
llama-runner = { path = "../../integration/llama-runner", features = ["metal"]}
embeddings-engine = { path = "../embeddings-engine" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

View File

@@ -19,6 +19,7 @@ use crate::openai_types::{
};
use crate::Which;
use either::Either;
use embeddings_engine::models_list;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
@@ -530,7 +531,9 @@ pub async fn list_models() -> Json<ModelListResponse> {
Which::Llama32_3BInstruct,
];
let models: Vec<Model> = which_variants.into_iter().map(|which| {
let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
let meta = which.meta();
let model_id = match which {
Which::Base2B => "gemma-2b",
@@ -566,11 +569,25 @@ pub async fn list_models() -> Json<ModelListResponse> {
Model {
id: model_id.to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
created: 1686935002,
owned_by: owned_by.to_string(),
}
}).collect();
// Get embeddings models and convert them to inference Model format
let embeddings_response = models_list().await;
let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
Model {
id: embedding_model.id,
object: embedding_model.object,
created: 1686935002,
owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
}
}).collect();
// Add embeddings models to the main models list
models.extend(embeddings_models);
Json(ModelListResponse {
object: "list".to_string(),
data: models,

View File

@@ -144,6 +144,7 @@ async fn main() {
tracing::info!("Available endpoints:");
tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API");

View File

@@ -61,20 +61,22 @@ graph TD
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
end
subgraph "AI Services"
subgraph "AI Services (crates/)"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end
subgraph "Frontend"
subgraph "Frontend (crates/)"
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
end
subgraph "Tooling"
subgraph "Integration Tools (integration/)"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
O[utils<br/>Edition: 2021<br/>Shared utilities]
end
end
@@ -82,10 +84,10 @@ graph TD
A --> B
A --> C
A --> D
B --> J
B --> K
J -.-> F[Candle 0.9.1]
K -.-> F
B --> M
B --> N
M -.-> F[Candle 0.9.1]
N -.-> F
C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+]
@@ -93,12 +95,13 @@ graph TD
style A fill:#e1f5fe
style B fill:#f3e5f5
style J fill:#f3e5f5
style K fill:#f3e5f5
style C fill:#e8f5e8
style D fill:#fff3e0
style E fill:#fce4ec
style L fill:#fff9c4
style M fill:#f3e5f5
style N fill:#f3e5f5
style O fill:#fff9c4
```
## Deployment Configurations

View File

@@ -14,7 +14,7 @@ Options:
--help Show this help message
Examples:
cd crates/cli/package
cd integration/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"

View File

@@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
}
}

View File

@@ -18,7 +18,7 @@ serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = {path = "../utils"}
utils = {path = "../utils" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

View File

@@ -105,7 +105,7 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") {
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info);
}
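
For context, the discovery loop walks the workspace with `walkdir` and skips the root manifest so only member crates are parsed. A minimal sketch of that traversal (assuming the `walkdir` crate; the root-manifest comparison is simplified to a direct join on the workspace path rather than the relative path used in the change above):

```rust
use std::path::{Path, PathBuf};
use walkdir::WalkDir;

// Collect every member Cargo.toml under the workspace,
// excluding the workspace root manifest itself.
fn find_member_manifests(workspace: &Path) -> Vec<PathBuf> {
    let root_manifest = workspace.join("Cargo.toml");
    WalkDir::new(workspace)
        .into_iter()
        .filter_map(|entry| entry.ok())
        .filter(|entry| entry.file_name() == "Cargo.toml" && entry.path() != root_manifest)
        .map(|entry| entry.path().to_path_buf())
        .collect()
}
```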

View File

@@ -1,8 +1,8 @@
{
"name": "predict-otron-9000",
"workspaces": ["crates/cli/package"],
"workspaces": ["integration/cli/package"],
"scripts": {
"# WORKSPACE ALIASES": "#",
"cli": "bun --filter crates/cli/package"
"cli": "bun --filter integration/cli/package"
}
}