mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
- Introduced ServerConfig
for handling deployment modes and services.
- Added HighAvailability mode for proxying requests to external services. - Maintained Local mode for embedded services. - Updated `README.md` and included `SERVER_CONFIG.md` for detailed documentation.
This commit is contained in:
36
Cargo.lock
generated
36
Cargo.lock
generated
@@ -4255,6 +4255,8 @@ dependencies = [
|
||||
"axum",
|
||||
"embeddings-engine",
|
||||
"inference-engine",
|
||||
"reqwest",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
@@ -5035,6 +5037,40 @@ dependencies = [
|
||||
"realfft",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed"
|
||||
version = "8.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "025908b8682a26ba8d12f6f2d66b987584a4a87bc024abc5bbc12553a8cd178a"
|
||||
dependencies = [
|
||||
"rust-embed-impl",
|
||||
"rust-embed-utils",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-impl"
|
||||
version = "8.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6065f1a4392b71819ec1ea1df1120673418bf386f50de1d6f54204d836d4349c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rust-embed-utils",
|
||||
"syn 2.0.106",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-utils"
|
||||
version = "8.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6cc0c81648b20b70c491ff8cce00c1c3b223bb8ed2b5d41f0e54c6c4c0a3594"
|
||||
dependencies = [
|
||||
"sha2",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.26"
|
||||
|
20
README.md
20
README.md
@@ -74,9 +74,27 @@ Environment variables for server configuration:
|
||||
|
||||
- `SERVER_HOST`: Server bind address (default: `0.0.0.0`)
|
||||
- `SERVER_PORT`: Server port (default: `8080`)
|
||||
- `SERVER_CONFIG`: JSON configuration for deployment mode (default: Local mode)
|
||||
- `RUST_LOG`: Logging level configuration
|
||||
|
||||
Example:
|
||||
#### Deployment Modes
|
||||
|
||||
The server supports two deployment modes controlled by `SERVER_CONFIG`:
|
||||
|
||||
**Local Mode (default)**: Runs inference and embeddings services locally
|
||||
```shell
|
||||
./run_server.sh
|
||||
```
|
||||
|
||||
**HighAvailability Mode**: Proxies requests to external services
|
||||
```shell
|
||||
export SERVER_CONFIG='{"serverMode": "HighAvailability"}'
|
||||
./run_server.sh
|
||||
```
|
||||
|
||||
See [docs/SERVER_CONFIG.md](docs/SERVER_CONFIG.md) for complete configuration options, Docker Compose, and Kubernetes examples.
|
||||
|
||||
#### Basic Configuration Example:
|
||||
```shell
|
||||
export SERVER_PORT=3000
|
||||
export RUST_LOG=debug
|
||||
|
@@ -18,6 +18,8 @@ serde_json = "1.0.140"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
rust-embed = "8.7.2"
|
||||
|
||||
# Dependencies for embeddings functionality
|
||||
embeddings-engine = { path = "../embeddings-engine" }
|
||||
@@ -37,3 +39,4 @@ port = 8080
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
replicas = 1
|
||||
port = 8080
|
||||
env = { SERVER_CONFIG = "" }
|
180
crates/predict-otron-9000/src/config.rs
Normal file
180
crates/predict-otron-9000/src/config.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_server_host")]
|
||||
pub server_host: String,
|
||||
#[serde(default = "default_server_port")]
|
||||
pub server_port: u16,
|
||||
pub server_mode: ServerMode,
|
||||
#[serde(default)]
|
||||
pub services: Services,
|
||||
}
|
||||
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 { 8080 }
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum ServerMode {
|
||||
Standalone,
|
||||
HighAvailability,
|
||||
}
|
||||
|
||||
impl Default for ServerMode {
|
||||
fn default() -> Self {
|
||||
Self::Standalone
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Services {
|
||||
#[serde(default = "inference_service_url")]
|
||||
pub inference_url: String,
|
||||
#[serde(default = "embeddings_service_url")]
|
||||
pub embeddings_url: String,
|
||||
}
|
||||
|
||||
impl Default for Services {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inference_url: inference_service_url(),
|
||||
embeddings_url: embeddings_service_url(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_service_url() -> String {
|
||||
"http://inference-service:8080".to_string()
|
||||
}
|
||||
|
||||
fn embeddings_service_url() -> String {
|
||||
"http://embeddings-service:8080".to_string()
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server_host: "127.0.0.1".to_string(),
|
||||
server_port: 8080,
|
||||
server_mode: ServerMode::Standalone,
|
||||
services: Services::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
/// Load configuration from SERVER_CONFIG environment variable
|
||||
/// Falls back to default (Local mode) if not set or invalid
|
||||
pub fn from_env() -> Self {
|
||||
match env::var("SERVER_CONFIG") {
|
||||
Ok(config_str) => {
|
||||
match serde_json::from_str::<ServerConfig>(&config_str) {
|
||||
Ok(config) => {
|
||||
tracing::info!("Loaded server configuration: {:?}", config);
|
||||
config
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
||||
e
|
||||
);
|
||||
ServerConfig::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::info!("SERVER_CONFIG not set, using default Local mode");
|
||||
ServerConfig::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the server should run in high availability mode
|
||||
pub fn is_high_availability(&self) -> bool {
|
||||
self.server_mode == ServerMode::HighAvailability
|
||||
}
|
||||
|
||||
/// Get the inference service URL for proxying
|
||||
pub fn inference_url(&self) -> &str {
|
||||
&self.services.inference_url
|
||||
}
|
||||
|
||||
/// Get the embeddings service URL for proxying
|
||||
pub fn embeddings_url(&self) -> &str {
|
||||
&self.services.embeddings_url
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ServerConfig::default();
|
||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||
assert!(!config.is_high_availability());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_high_availability_config() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://inference-service:8080",
|
||||
"embeddings_url": "http://embeddings-service:8080"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
||||
assert!(config.is_high_availability());
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_local_mode_config() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "Local"
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||
assert!(!config.is_high_availability());
|
||||
// Should use default URLs
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_urls() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://custom-inference:9000",
|
||||
"embeddings_url": "http://custom-embeddings:9001"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.inference_url(), "http://custom-inference:9000");
|
||||
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_minimal_high_availability_config() {
|
||||
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert!(config.is_high_availability());
|
||||
// Should use default URLs
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
}
|
@@ -1,4 +1,6 @@
|
||||
mod middleware;
|
||||
mod config;
|
||||
mod proxy;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
@@ -12,9 +14,9 @@ use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use inference_engine::AppState;
|
||||
use middleware::{MetricsStore, MetricsLoggerFuture, MetricsLayer};
|
||||
use config::ServerConfig;
|
||||
use proxy::create_proxy_router;
|
||||
|
||||
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
|
||||
const DEFAULT_SERVER_PORT: &str = "8080";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@@ -42,27 +44,50 @@ async fn main() {
|
||||
// Spawn the metrics logger in a background task
|
||||
tokio::spawn(metrics_logger);
|
||||
|
||||
// Create unified router by merging embeddings and inference routers
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
// Load server configuration from environment variable
|
||||
let server_config = ServerConfig::from_env();
|
||||
|
||||
// Extract the server_host and server_port before potentially moving server_config
|
||||
let default_host = server_config.server_host.clone();
|
||||
let default_port = server_config.server_port;
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
use inference_engine::server::{PipelineArgs, build_pipeline};
|
||||
use inference_engine::Which;
|
||||
let mut pipeline_args = PipelineArgs::default();
|
||||
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
pipeline_args.which = Which::InstructV3_1B;
|
||||
// Create router based on server mode
|
||||
let service_router = if server_config.clone().is_high_availability() {
|
||||
tracing::info!("Running in HighAvailability mode - proxying to external services");
|
||||
tracing::info!(" Inference service URL: {}", server_config.inference_url());
|
||||
tracing::info!(" Embeddings service URL: {}", server_config.embeddings_url());
|
||||
|
||||
let text_generation = build_pipeline(pipeline_args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
|
||||
model_id: "google/gemma-3-1b-it".to_string(),
|
||||
build_args: pipeline_args,
|
||||
// Use proxy router that forwards requests to external services
|
||||
create_proxy_router(server_config.clone())
|
||||
} else {
|
||||
tracing::info!("Running in Local mode - using embedded services");
|
||||
|
||||
// Create unified router by merging embeddings and inference routers (existing behavior)
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
use inference_engine::server::{PipelineArgs, build_pipeline};
|
||||
use inference_engine::Which;
|
||||
let mut pipeline_args = PipelineArgs::default();
|
||||
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
pipeline_args.which = Which::InstructV3_1B;
|
||||
|
||||
let text_generation = build_pipeline(pipeline_args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
|
||||
model_id: "google/gemma-3-1b-it".to_string(),
|
||||
build_args: pipeline_args,
|
||||
};
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
let inference_router = inference_engine::create_router(app_state);
|
||||
|
||||
// Merge the local routers
|
||||
Router::new()
|
||||
.merge(embeddings_router)
|
||||
.merge(inference_router)
|
||||
};
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
let inference_router = inference_engine::create_router(app_state);
|
||||
|
||||
// Create CORS layer
|
||||
let cors = CorsLayer::new()
|
||||
.allow_headers(Any)
|
||||
@@ -73,21 +98,27 @@ async fn main() {
|
||||
// Create metrics layer
|
||||
let metrics_layer = MetricsLayer::new(metrics_store);
|
||||
|
||||
// Merge the routers and add middleware layers
|
||||
// Merge the service router with base routes and add middleware layers
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
.route("/", get(|| async { "API ready. This can serve the Leptos web app, but it doesn't." }))
|
||||
.route("/health", get(|| async { "ok" }))
|
||||
.merge(embeddings_router)
|
||||
.merge(inference_router)
|
||||
.merge(service_router)
|
||||
.layer(metrics_layer) // Add metrics tracking
|
||||
.layer(cors)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
// Server configuration
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
|
||||
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| {
|
||||
String::from(default_host)
|
||||
});
|
||||
|
||||
let server_port = env::var("SERVER_PORT").map(|v| v.parse::<u16>().unwrap_or(default_port)).unwrap_or_else(|_| {
|
||||
default_port
|
||||
});
|
||||
|
||||
let server_address = format!("{}:{}", server_host, server_port);
|
||||
|
||||
|
||||
let listener = TcpListener::bind(&server_address).await.unwrap();
|
||||
tracing::info!("Unified predict-otron-9000 server listening on {}", listener.local_addr().unwrap());
|
||||
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
|
||||
|
303
crates/predict-otron-9000/src/proxy.rs
Normal file
303
crates/predict-otron-9000/src/proxy.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, Method, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
|
||||
/// HTTP client configured for proxying requests
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyClient {
|
||||
client: Client,
|
||||
config: ServerConfig,
|
||||
}
|
||||
|
||||
impl ProxyClient {
|
||||
pub fn new(config: ServerConfig) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(300)) // 5 minute timeout for long-running inference
|
||||
.build()
|
||||
.expect("Failed to create HTTP client for proxy");
|
||||
|
||||
Self { client, config }
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a router that proxies requests to external services in HighAvailability mode
|
||||
pub fn create_proxy_router(config: ServerConfig) -> Router {
|
||||
let proxy_client = ProxyClient::new(config.clone());
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(proxy_chat_completions))
|
||||
.route("/v1/models", get(proxy_models))
|
||||
.route("/v1/embeddings", post(proxy_embeddings))
|
||||
.with_state(proxy_client)
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/chat/completions
|
||||
async fn proxy_chat_completions(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
|
||||
|
||||
tracing::info!("Proxying chat completions request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a streaming request
|
||||
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
|
||||
if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
|
||||
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming vs non-streaming responses
|
||||
if is_streaming {
|
||||
// For streaming, we need to forward the response as-is
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.header("content-type", "text/plain; charset=utf-8")
|
||||
.header("cache-control", "no-cache")
|
||||
.header("connection", "keep-alive")
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read streaming response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-streaming, forward the JSON response
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy chat completions request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for GET /v1/models
|
||||
async fn proxy_models(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
|
||||
|
||||
tracing::info!("Proxying models request to: {}", target_url);
|
||||
|
||||
let mut req_builder = proxy_client.client.get(&target_url);
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read models response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy models request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/embeddings
|
||||
async fn proxy_embeddings(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
|
||||
|
||||
tracing::info!("Proxying embeddings request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read embeddings response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy embeddings request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a request header should be forwarded to the target service
|
||||
fn should_forward_header(header_name: &str) -> bool {
|
||||
match header_name.to_lowercase().as_str() {
|
||||
"content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
|
||||
"host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
|
||||
_ => true, // Forward other headers by default
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a response header should be forwarded back to the client
|
||||
fn should_forward_response_header(header_name: &str) -> bool {
|
||||
match header_name.to_lowercase().as_str() {
|
||||
"content-type" | "content-length" | "cache-control" | "connection" => true,
|
||||
"server" | "date" => false, // Don't forward server-specific headers
|
||||
_ => true, // Forward other headers by default
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{ServerMode, Services};
|
||||
|
||||
#[test]
|
||||
fn test_should_forward_header() {
|
||||
assert!(should_forward_header("content-type"));
|
||||
assert!(should_forward_header("authorization"));
|
||||
assert!(!should_forward_header("host"));
|
||||
assert!(!should_forward_header("connection"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_forward_response_header() {
|
||||
assert!(should_forward_response_header("content-type"));
|
||||
assert!(should_forward_response_header("cache-control"));
|
||||
assert!(!should_forward_response_header("server"));
|
||||
assert!(!should_forward_response_header("date"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_client_creation() {
|
||||
let config = ServerConfig {
|
||||
server_host: "127.0.0.1".to_string(),
|
||||
server_port: 8080,
|
||||
server_mode: ServerMode::HighAvailability,
|
||||
services: Services {
|
||||
inference_url: "http://test-inference:8080".to_string(),
|
||||
embeddings_url: "http://test-embeddings:8080".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let proxy_client = ProxyClient::new(config);
|
||||
assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
|
||||
assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
|
||||
}
|
||||
}
|
223
docs/SERVER_CONFIG.md
Normal file
223
docs/SERVER_CONFIG.md
Normal file
@@ -0,0 +1,223 @@
|
||||
# Server Configuration Guide
|
||||
|
||||
The predict-otron-9000 server supports two deployment modes controlled by the `SERVER_CONFIG` environment variable:
|
||||
|
||||
1. **Local Mode** (default): Runs inference and embeddings services locally within the main server process
|
||||
2. **HighAvailability Mode**: Proxies requests to external inference and embeddings services
|
||||
|
||||
## Configuration Format
|
||||
|
||||
The `SERVER_CONFIG` environment variable accepts a JSON configuration with the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"serverMode": "Local",
|
||||
"services": {
|
||||
"inference_url": "http://inference-service:8080",
|
||||
"embeddings_url": "http://embeddings-service:8080"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```json
|
||||
{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://inference-service:8080",
|
||||
"embeddings_url": "http://embeddings-service:8080"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Fields:**
|
||||
- `serverMode`: Either `"Local"` or `"HighAvailability"`
|
||||
- `services`: Optional object containing service URLs (uses defaults if not provided)
|
||||
|
||||
## Local Mode (Default)
|
||||
|
||||
If `SERVER_CONFIG` is not set or contains invalid JSON, the server defaults to Local mode.
|
||||
|
||||
### Example: Explicit Local Mode
|
||||
```bash
|
||||
export SERVER_CONFIG='{"serverMode": "Local"}'
|
||||
./run_server.sh
|
||||
```
|
||||
|
||||
In Local mode:
|
||||
- Inference requests are handled by the embedded inference engine
|
||||
- Embeddings requests are handled by the embedded embeddings engine
|
||||
- No external services are required
|
||||
- Supports all existing functionality without changes
|
||||
|
||||
## HighAvailability Mode
|
||||
|
||||
In HighAvailability mode, the server acts as a proxy, forwarding requests to external services.
|
||||
|
||||
### Example: Basic HighAvailability Mode
|
||||
```bash
|
||||
export SERVER_CONFIG='{"serverMode": "HighAvailability"}'
|
||||
./run_server.sh
|
||||
```
|
||||
|
||||
This uses the default service URLs:
|
||||
- Inference service: `http://inference-service:8080`
|
||||
- Embeddings service: `http://embeddings-service:8080`
|
||||
|
||||
### Example: Custom Service URLs
|
||||
```bash
|
||||
export SERVER_CONFIG='{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://custom-inference:9000",
|
||||
"embeddings_url": "http://custom-embeddings:9001"
|
||||
}
|
||||
}'
|
||||
./run_server.sh
|
||||
```
|
||||
|
||||
## Docker Compose Example
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
# Inference service
|
||||
inference-service:
|
||||
image: ghcr.io/geoffsee/inference-service:latest
|
||||
ports:
|
||||
- "8081:8080"
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
|
||||
# Embeddings service
|
||||
embeddings-service:
|
||||
image: ghcr.io/geoffsee/embeddings-service:latest
|
||||
ports:
|
||||
- "8082:8080"
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
|
||||
# Main proxy server
|
||||
predict-otron-9000:
|
||||
image: ghcr.io/geoffsee/predict-otron-9000:latest
|
||||
ports:
|
||||
- "8080:8080"
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
- SERVER_CONFIG={"serverMode":"HighAvailability","services":{"inference_url":"http://inference-service:8080","embeddings_url":"http://embeddings-service:8080"}}
|
||||
depends_on:
|
||||
- inference-service
|
||||
- embeddings-service
|
||||
```
|
||||
|
||||
## Kubernetes Example
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: server-config
|
||||
data:
|
||||
SERVER_CONFIG: |
|
||||
{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://inference-service:8080",
|
||||
"embeddings_url": "http://embeddings-service:8080"
|
||||
}
|
||||
}
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: predict-otron-9000
|
||||
spec:
|
||||
replicas: 3
|
||||
selector:
|
||||
matchLabels:
|
||||
app: predict-otron-9000
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: predict-otron-9000
|
||||
spec:
|
||||
containers:
|
||||
- name: predict-otron-9000
|
||||
image: ghcr.io/geoffsee/predict-otron-9000:latest
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
env:
|
||||
- name: RUST_LOG
|
||||
value: "info"
|
||||
- name: SERVER_CONFIG
|
||||
valueFrom:
|
||||
configMapKeyRef:
|
||||
name: server-config
|
||||
key: SERVER_CONFIG
|
||||
```
|
||||
|
||||
## API Compatibility
|
||||
|
||||
Both modes expose the same OpenAI-compatible API endpoints:
|
||||
|
||||
- `POST /v1/chat/completions` - Chat completions (streaming and non-streaming)
|
||||
- `GET /v1/models` - List available models
|
||||
- `POST /v1/embeddings` - Generate text embeddings
|
||||
- `GET /health` - Health check
|
||||
- `GET /` - Root endpoint
|
||||
|
||||
## Logging
|
||||
|
||||
The server logs the selected mode on startup:
|
||||
|
||||
**Local Mode:**
|
||||
```
|
||||
INFO predict_otron_9000: Running in Local mode - using embedded services
|
||||
```
|
||||
|
||||
**HighAvailability Mode:**
|
||||
```
|
||||
INFO predict_otron_9000: Running in HighAvailability mode - proxying to external services
|
||||
INFO predict_otron_9000: Inference service URL: http://inference-service:8080
|
||||
INFO predict_otron_9000: Embeddings service URL: http://embeddings-service:8080
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Invalid JSON in `SERVER_CONFIG` falls back to Local mode with a warning
|
||||
- Missing `SERVER_CONFIG` defaults to Local mode
|
||||
- Network errors to external services return HTTP 502 (Bad Gateway)
|
||||
- Request/response proxying preserves original HTTP status codes and headers
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
**Local Mode:**
|
||||
- Lower latency (no network overhead)
|
||||
- Higher memory usage (models loaded locally)
|
||||
- Single point of failure
|
||||
|
||||
**HighAvailability Mode:**
|
||||
- Higher latency (network requests)
|
||||
- Lower memory usage (no local models)
|
||||
- Horizontal scaling possible
|
||||
- Network reliability dependent
|
||||
- 5-minute timeout for long-running inference requests
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
1. **Configuration not applied**: Check JSON syntax and restart the server
|
||||
2. **External services unreachable**: Verify service URLs and network connectivity
|
||||
3. **Timeouts**: Check if inference requests exceed the 5-minute timeout limit
|
||||
4. **502 errors**: External services may be down or misconfigured
|
||||
|
||||
## Migration
|
||||
|
||||
To migrate from Local to HighAvailability mode:
|
||||
|
||||
1. Deploy separate inference and embeddings services
|
||||
2. Update `SERVER_CONFIG` to point to the new services
|
||||
3. Restart the predict-otron-9000 server
|
||||
4. Verify endpoints are working with test requests
|
||||
|
||||
The API contract remains identical, ensuring zero-downtime migration possibilities.
|
Reference in New Issue
Block a user