add openai compatible endpoint for chat completions
This commit is contained in:

committed by
Geoff Seemueller

parent
3b4c8b045a
commit
8a3c0797c3
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,4 @@
|
||||
/target
|
||||
**/target
|
||||
/node_modules/
|
||||
/.idea
|
||||
chrome
|
||||
|
21
README.md
21
README.md
@@ -1,10 +1,27 @@
|
||||
# open-web-agent-rs
|
||||
|
||||
A Rust-based web agent with local inference capabilities.
|
||||
|
||||
## Components
|
||||
|
||||
### Local Inference Engine
|
||||
|
||||
The [Local Inference Engine](./local_inference_engine/README.md) provides a way to run large language models locally. It supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
|
||||
|
||||
Features:
|
||||
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
|
||||
- CLI mode for direct text generation
|
||||
- Server mode with OpenAI-compatible API
|
||||
- Support for various model configurations (base, instruction-tuned)
|
||||
- Metal acceleration on macOS
|
||||
|
||||
See the [Local Inference Engine README](./local_inference_engine/README.md) for detailed usage instructions.
|
||||
|
||||
### Web Server
|
||||
|
||||
Server is being converted to MCP. Things are probably broken.
|
||||
|
||||
```text
|
||||
bun i
|
||||
bun dev
|
||||
```
|
||||
|
||||
|
||||
|
331
local_inference_engine/Cargo.lock
generated
331
local_inference_engine/Cargo.lock
generated
@@ -205,7 +205,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -313,6 +313,17 @@ version = "1.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10d2119f741b79fe9907f5396d19bffcb46568cfcc315e78677d731972ac7085"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atoi"
|
||||
version = "2.0.0"
|
||||
@@ -357,6 +368,61 @@ dependencies = [
|
||||
"arrayvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"sync_wrapper",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.75"
|
||||
@@ -405,7 +471,7 @@ dependencies = [
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -523,7 +589,7 @@ checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -770,7 +836,7 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1024,7 +1090,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1035,7 +1101,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1052,7 +1118,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1073,7 +1139,7 @@ dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1083,7 +1149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
|
||||
dependencies = [
|
||||
"derive_builder_core",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1134,7 +1200,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1161,6 +1227,9 @@ name = "either"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
@@ -1198,7 +1267,7 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1332,7 +1401,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1412,7 +1481,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1726,7 +1795,7 @@ dependencies = [
|
||||
"proc-macro-error2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1862,6 +1931,12 @@ version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "httpdate"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.6.0"
|
||||
@@ -1875,6 +1950,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
@@ -2125,6 +2201,7 @@ checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.15.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2182,7 +2259,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2414,6 +2491,7 @@ dependencies = [
|
||||
"ab_glyph",
|
||||
"accelerate-src",
|
||||
"anyhow",
|
||||
"axum",
|
||||
"bindgen_cuda",
|
||||
"byteorder",
|
||||
"candle-core",
|
||||
@@ -2426,6 +2504,7 @@ dependencies = [
|
||||
"cpal",
|
||||
"csv",
|
||||
"cudarc",
|
||||
"either",
|
||||
"enterpolation",
|
||||
"half",
|
||||
"hf-hub",
|
||||
@@ -2446,9 +2525,23 @@ dependencies = [
|
||||
"symphonia",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tower 0.4.13",
|
||||
"tower-http 0.5.2",
|
||||
"tracing",
|
||||
"tracing-chrome",
|
||||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2509,6 +2602,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.10"
|
||||
@@ -2635,7 +2734,7 @@ checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2779,7 +2878,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2851,7 +2950,7 @@ dependencies = [
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3012,7 +3111,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3084,7 +3183,30 @@ dependencies = [
|
||||
"by_address",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3183,7 +3305,7 @@ dependencies = [
|
||||
"phf_shared",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3257,7 +3379,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dee91521343f4c5c6a63edd65e54f31f5c92fe8978c40a4282f8372194c6a7d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3278,6 +3400,30 @@ dependencies = [
|
||||
"toml_edit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr2"
|
||||
version = "2.0.0"
|
||||
@@ -3297,7 +3443,7 @@ dependencies = [
|
||||
"proc-macro-error-attr2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3325,7 +3471,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3355,7 +3501,7 @@ dependencies = [
|
||||
"prost",
|
||||
"prost-types",
|
||||
"regex",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
@@ -3369,7 +3515,7 @@ dependencies = [
|
||||
"itertools 0.12.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3454,7 +3600,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3467,7 +3613,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3774,8 +3920,8 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower 0.5.2",
|
||||
"tower-http 0.6.6",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
@@ -3949,6 +4095,12 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
@@ -4001,7 +4153,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4016,6 +4168,16 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_plain"
|
||||
version = "1.0.2"
|
||||
@@ -4072,6 +4234,15 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.1"
|
||||
@@ -4200,7 +4371,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4404,6 +4575,16 @@ dependencies = [
|
||||
"symphonia-metadata",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.101"
|
||||
@@ -4432,7 +4613,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4553,7 +4734,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4564,7 +4745,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4648,7 +4829,9 @@ dependencies = [
|
||||
"bytes",
|
||||
"libc",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.52.0",
|
||||
@@ -4662,7 +4845,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4748,6 +4931,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
||||
dependencies = [
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.2"
|
||||
@@ -4761,6 +4955,23 @@ dependencies = [
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"bytes",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4776,7 +4987,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"iri-string",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
@@ -4799,6 +5010,7 @@ version = "0.1.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
@@ -4812,7 +5024,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5035,6 +5247,31 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "utoipa"
|
||||
version = "4.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa-gen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-gen"
|
||||
version = "4.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.17.0"
|
||||
@@ -5137,7 +5374,7 @@ dependencies = [
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@@ -5172,7 +5409,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@@ -5319,7 +5556,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5330,7 +5567,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5721,7 +5958,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -5733,7 +5970,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -5754,7 +5991,7 @@ checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5774,7 +6011,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
@@ -5814,7 +6051,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@@ -36,6 +36,13 @@ clap= { version = "4.2.4", features = ["derive"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
axum = { version = "0.7.4", features = ["json"] }
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
either = { version = "1.9.0", features = ["serde"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
|
207
local_inference_engine/README.md
Normal file
207
local_inference_engine/README.md
Normal file
@@ -0,0 +1,207 @@
|
||||
# Local Inference Engine
|
||||
|
||||
A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
|
||||
|
||||
## Features
|
||||
|
||||
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
|
||||
- CLI mode for direct text generation
|
||||
- Server mode with OpenAI-compatible API
|
||||
- Support for various model configurations (base, instruction-tuned)
|
||||
- Metal acceleration on macOS
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust toolchain (install via [rustup](https://rustup.rs/))
|
||||
- Cargo package manager
|
||||
- For GPU acceleration:
|
||||
- macOS: Metal support
|
||||
- Linux/Windows: CUDA support (requires appropriate drivers)
|
||||
|
||||
### Building from Source
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/seemueller-io/open-web-agent-rs.git
|
||||
cd open-web-agent-rs
|
||||
```
|
||||
|
||||
2. Build the local inference engine:
|
||||
```bash
|
||||
cd local_inference_engine
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI Mode
|
||||
|
||||
Run the inference engine in CLI mode to generate text directly:
|
||||
|
||||
```bash
|
||||
cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it
|
||||
```
|
||||
|
||||
#### CLI Options
|
||||
|
||||
- `--prompt <TEXT>`: The prompt text to generate from
|
||||
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
||||
- `--server`: Run OpenAI compatible server
|
||||
- Available options: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
|
||||
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
|
||||
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
|
||||
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
|
||||
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
|
||||
- `--repeat-last-n <INT>`: Context size for repeat penalty (default: 64)
|
||||
- `--cpu`: Run on CPU instead of GPU
|
||||
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)
|
||||
|
||||
### Server Mode with OpenAI-compatible API
|
||||
|
||||
Run the inference engine in server mode to expose an OpenAI-compatible API:
|
||||
|
||||
```bash
|
||||
cargo run --release -- --server --port 3000 --which 3-1b-it
|
||||
```
|
||||
|
||||
This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint.
|
||||
|
||||
#### Server Options
|
||||
|
||||
- `--server`: Run in server mode
|
||||
- `--port <INT>`: Port to use for the server (default: 3000)
|
||||
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
|
||||
- Other model options as described in CLI mode
|
||||
|
||||
## API Usage
|
||||
|
||||
The server exposes an OpenAI-compatible chat completions endpoint:
|
||||
|
||||
### Chat Completions
|
||||
|
||||
```
|
||||
POST /v1/chat/completions
|
||||
```
|
||||
|
||||
#### Request Format
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemma-3-1b-it",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 256,
|
||||
"top_p": 0.9,
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
#### Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-123abc456def789ghi",
|
||||
"object": "chat.completion",
|
||||
"created": 1677858242,
|
||||
"model": "gemma-3-1b-it",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking! How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 25,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 40
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: Using cURL
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:3000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemma-3-1b-it",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Using Python with OpenAI Client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:3000/v1",
|
||||
api_key="dummy" # API key is not validated but required by the client
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemma-3-1b-it",
|
||||
messages=[
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
### Example: Using JavaScript/TypeScript with OpenAI SDK
|
||||
|
||||
```javascript
|
||||
import OpenAI from 'openai';
|
||||
|
||||
const openai = new OpenAI({
|
||||
baseURL: 'http://localhost:3000/v1',
|
||||
apiKey: 'dummy', // API key is not validated but required by the client
|
||||
});
|
||||
|
||||
async function main() {
|
||||
const response = await openai.chat.completions.create({
|
||||
model: 'gemma-3-1b-it',
|
||||
messages: [
|
||||
{ role: 'user', content: 'What is the capital of France?' }
|
||||
],
|
||||
temperature: 0.7,
|
||||
max_tokens: 100,
|
||||
});
|
||||
|
||||
console.log(response.choices[0].message.content);
|
||||
}
|
||||
|
||||
main();
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.
|
||||
|
||||
2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.
|
||||
|
||||
3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.
|
||||
|
||||
4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the terms specified in the LICENSE file.
|
@@ -8,12 +8,286 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
use either::Either;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
// OpenAI API compatible structs
|
||||
|
||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageInnerContent(
|
||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageInnerContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
(
|
||||
"MessageInnerContent",
|
||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
// Either::Left - simple string
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
// Either::Right - object with string values
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Message content that can be either simple text or complex structured content
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContent(
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageContent Schema generation to handle `Either`
|
||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
.item(Schema::Array(
|
||||
ArrayBuilder::new()
|
||||
.items(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::Ref(
|
||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
||||
)))
|
||||
.build(),
|
||||
)))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Represents a single message in a conversation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Message {
|
||||
/// The message content
|
||||
pub content: Option<MessageContent>,
|
||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
||||
pub role: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Stop token configuration for generation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum StopTokens {
|
||||
/// Multiple possible stop sequences
|
||||
Multi(Vec<String>),
|
||||
/// Single stop sequence
|
||||
Single(String),
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_false() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_1usize() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_model() -> String {
|
||||
"default".to_string()
|
||||
}
|
||||
|
||||
/// Chat completion request following OpenAI's specification
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionRequest {
|
||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
||||
pub messages: Vec<Message>,
|
||||
#[schema(example = "gemma-3-1b-it")]
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_false")]
|
||||
#[schema(example = false)]
|
||||
pub logprobs: bool,
|
||||
#[schema(example = 256)]
|
||||
pub max_tokens: Option<usize>,
|
||||
#[serde(rename = "n")]
|
||||
#[serde(default = "default_1usize")]
|
||||
#[schema(example = 1)]
|
||||
pub n_choices: usize,
|
||||
#[schema(example = 0.7)]
|
||||
pub temperature: Option<f64>,
|
||||
#[schema(example = 0.9)]
|
||||
pub top_p: Option<f64>,
|
||||
#[schema(example = false)]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
/// Chat completion choice
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionChoice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
/// Chat completion response
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Token usage information
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
// Application state shared between handlers
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
text_generation: Arc<Mutex<TextGeneration>>,
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
// Chat completions endpoint handler
|
||||
async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Convert messages to a prompt string
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
// Format based on role
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
}
|
||||
|
||||
// Add the assistant prefix for the response
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Capture the output
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
|
||||
// Buffer to capture the output
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Run text generation
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Error generating text: {}", e),
|
||||
"type": "internal_server_error"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// Convert buffer to string
|
||||
if let Ok(text) = String::from_utf8(buffer) {
|
||||
output.push(text);
|
||||
}
|
||||
}
|
||||
|
||||
// Create response
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
model: request.model,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent(Either::Left(output.join("")))),
|
||||
name: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: prompt.len() / 4, // Rough estimate
|
||||
completion_tokens: output.join("").len() / 4, // Rough estimate
|
||||
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
||||
},
|
||||
};
|
||||
|
||||
// Return the response as JSON
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
use candle_core::{DType, Device, MetalDevice, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
@@ -22,6 +296,22 @@ use tokenizers::Tokenizer;
|
||||
use crate::token_output_stream::TokenOutputStream;
|
||||
use crate::utilities_lib::device;
|
||||
|
||||
// Create the router with the chat completions endpoint
|
||||
fn create_router(app_state: AppState) -> Router {
|
||||
// CORS layer to allow requests from any origin
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
// OpenAI compatible endpoints
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
// Add more endpoints as needed
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "2b")]
|
||||
@@ -108,6 +398,7 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
// Run text generation and print to stdout
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
@@ -195,6 +486,90 @@ impl TextGeneration {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run text generation and write to a buffer
|
||||
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Write prompt tokens to output
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Write any remaining tokens
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
write!(output, "{}", rest)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -208,8 +583,17 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Run in server mode with OpenAI compatible API
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
server: bool,
|
||||
|
||||
/// Port to use for the server
|
||||
#[arg(long, default_value_t = 3000)]
|
||||
port: u16,
|
||||
|
||||
/// Prompt for text generation (not used in server mode)
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
@@ -308,7 +692,7 @@ fn main() -> Result<()> {
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
model_id.clone(),
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
@@ -371,7 +755,7 @@ fn main() -> Result<()> {
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
let pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
@@ -382,6 +766,36 @@ fn main() -> Result<()> {
|
||||
&device,
|
||||
);
|
||||
|
||||
if args.server {
|
||||
// Start the server
|
||||
println!("Starting server on port {}", args.port);
|
||||
|
||||
// Create app state
|
||||
let app_state = AppState {
|
||||
text_generation: Arc::new(Mutex::new(pipeline)),
|
||||
model_id,
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = create_router(app_state);
|
||||
|
||||
// Run the server
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
|
||||
|
||||
// Use tokio to run the server
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async {
|
||||
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
// Run in CLI mode
|
||||
if let Some(prompt_text) = &args.prompt {
|
||||
let prompt = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
@@ -397,15 +811,20 @@ fn main() -> Result<()> {
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B => args.prompt,
|
||||
| Which::BaseV3_1B => prompt_text.clone(),
|
||||
Which::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||
args.prompt
|
||||
prompt_text
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut pipeline = pipeline;
|
||||
pipeline.run(&prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user