add openai compatible endpoint for chat completions
committed by Geoff Seemueller
parent 3b4c8b045a
commit 8a3c0797c3
.gitignore (vendored, 2 lines changed)
@@ -1,4 +1,4 @@
-/target
+**/target
 /node_modules/
 /.idea
 chrome
README.md (21 lines changed)
@@ -1,10 +1,27 @@
 # open-web-agent-rs
+
+A Rust-based web agent with local inference capabilities.
+
+## Components
+
+### Local Inference Engine
+
+The [Local Inference Engine](./local_inference_engine/README.md) provides a way to run large language models locally. It supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
+
+Features:
+- Run Gemma models locally (1B, 2B, 7B, 9B variants)
+- CLI mode for direct text generation
+- Server mode with OpenAI-compatible API
+- Support for various model configurations (base, instruction-tuned)
+- Metal acceleration on macOS
+
+See the [Local Inference Engine README](./local_inference_engine/README.md) for detailed usage instructions.
+
+### Web Server
+
 Server is being converted to MCP. Things are probably broken.

 ```text
 bun i
 bun dev
 ```
local_inference_engine/Cargo.lock (generated, 331 lines changed; summarized)

New packages enter the dependency graph: async-trait 0.1.88, axum 0.7.9, axum-core 0.4.5, httpdate 1.0.3, lock_api 0.4.13, matchit 0.7.3, parking_lot 0.12.4, parking_lot_core 0.9.11, proc-macro-error 1.0.4, proc-macro-error-attr 1.0.4, scopeguard 1.2.0, serde_path_to_error 0.1.17, signal-hook-registry 1.4.5, syn 1.0.109, tower 0.4.13, tower-http 0.5.2, utoipa 4.2.3, and utoipa-gen 4.3.1.

The local_inference_engine package itself gains dependencies on axum, either, tower 0.4.13, tower-http 0.5.2, utoipa, and uuid. Elsewhere in the graph, tokio gains parking_lot and signal-hook-registry, hyper gains httpdate, tracing gains log, tower 0.5.2 gains tracing, and either and indexmap gain serde support; reqwest's entries are pinned to tower 0.5.2 and tower-http 0.6.6. Because two versions of syn now coexist, every bare "syn" dependency reference is disambiguated to "syn 2.0.101".
@@ -36,6 +36,13 @@ clap= { version = "4.2.4", features = ["derive"] }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
+axum = { version = "0.7.4", features = ["json"] }
+tower = "0.4.13"
+tower-http = { version = "0.5.1", features = ["cors"] }
+tokio = { version = "1.43.0", features = ["full"] }
+either = { version = "1.9.0", features = ["serde"] }
+utoipa = { version = "4.2.0", features = ["axum_extras"] }
+uuid = { version = "1.7.0", features = ["v4"] }


 [dev-dependencies]
local_inference_engine/README.md (new file, 207 lines)

@@ -0,0 +1,207 @@
# Local Inference Engine

A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.

## Features

- Run Gemma models locally (1B, 2B, 7B, 9B variants)
- CLI mode for direct text generation
- Server mode with OpenAI-compatible API
- Support for various model configurations (base, instruction-tuned)
- Metal acceleration on macOS

## Installation

### Prerequisites

- Rust toolchain (install via [rustup](https://rustup.rs/))
- Cargo package manager
- For GPU acceleration:
  - macOS: Metal support
  - Linux/Windows: CUDA support (requires appropriate drivers)

### Building from Source

1. Clone the repository:

```bash
git clone https://github.com/seemueller-io/open-web-agent-rs.git
cd open-web-agent-rs
```

2. Build the local inference engine:

```bash
cd local_inference_engine
cargo build --release
```

## Usage

### CLI Mode

Run the inference engine in CLI mode to generate text directly:

```bash
cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it
```

#### CLI Options

- `--prompt <TEXT>`: The prompt text to generate from
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
  - Available options: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
- `--server`: Run OpenAI compatible server
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
- `--repeat-last-n <INT>`: Context size for repeat penalty (default: 64)
- `--cpu`: Run on CPU instead of GPU
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)

### Server Mode with OpenAI-compatible API

Run the inference engine in server mode to expose an OpenAI-compatible API:

```bash
cargo run --release -- --server --port 3000 --which 3-1b-it
```

This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint.

#### Server Options

- `--server`: Run in server mode
- `--port <INT>`: Port to use for the server (default: 3000)
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
  - Other model options as described in CLI mode

## API Usage

The server exposes an OpenAI-compatible chat completions endpoint:

### Chat Completions

```
POST /v1/chat/completions
```

#### Request Format

```json
{
  "model": "gemma-3-1b-it",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you?"}
  ],
  "temperature": 0.7,
  "max_tokens": 256,
  "top_p": 0.9,
  "stream": false
}
```

#### Response Format

```json
{
  "id": "chatcmpl-123abc456def789ghi",
  "object": "chat.completion",
  "created": 1677858242,
  "model": "gemma-3-1b-it",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I'm doing well, thank you for asking! How can I assist you today?"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 25,
    "completion_tokens": 15,
    "total_tokens": 40
  }
}
```

### Example: Using cURL

```bash
curl -X POST http://localhost:3000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemma-3-1b-it",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ],
    "temperature": 0.7,
    "max_tokens": 100
  }'
```

### Example: Using Python with OpenAI Client

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="dummy"  # API key is not validated but required by the client
)

response = client.chat.completions.create(
    model="gemma-3-1b-it",
    messages=[
        {"role": "user", "content": "What is the capital of France?"}
    ],
    temperature=0.7,
    max_tokens=100
)

print(response.choices[0].message.content)
```

### Example: Using JavaScript/TypeScript with OpenAI SDK

```javascript
import OpenAI from 'openai';

const openai = new OpenAI({
  baseURL: 'http://localhost:3000/v1',
  apiKey: 'dummy', // API key is not validated but required by the client
});

async function main() {
  const response = await openai.chat.completions.create({
    model: 'gemma-3-1b-it',
    messages: [
      { role: 'user', content: 'What is the capital of France?' }
    ],
    temperature: 0.7,
    max_tokens: 100,
  });

  console.log(response.choices[0].message.content);
}

main();
```

## Troubleshooting

### Common Issues

1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.

2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.

3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.

4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.

## License

This project is licensed under the terms specified in the LICENSE file.
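The README above covers cURL, Python, and JavaScript clients. For a Rust caller the same request can be made with reqwest; the sketch below is not part of this commit and assumes a separate scratch crate with `reqwest` (with the `blocking` and `json` features) and `serde_json` as dependencies. The endpoint path, default port, and response shape come from the files added in this commit.

```rust
// Hypothetical standalone client; in a scratch crate add:
//   reqwest = { version = "0.12", features = ["blocking", "json"] }
//   serde_json = "1"
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();

    // Same request shape as the README's "Request Format" section.
    let body = json!({
        "model": "gemma-3-1b-it",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"}
        ],
        "temperature": 0.7,
        "max_tokens": 100
    });

    let response: serde_json::Value = client
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()?
        .json()?;

    // The assistant text lives at choices[0].message.content.
    println!("{}", response["choices"][0]["message"]["content"]);
    Ok(())
}
```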
@@ -8,12 +8,286 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;

 use anyhow::{Error as E, Result};
+use axum::{
+    extract::State,
+    http::StatusCode,
+    response::IntoResponse,
+    routing::{get, post},
+    Json, Router,
+};
 use clap::Parser;
+use either::Either;
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, net::SocketAddr, sync::Arc};
+use tokio::sync::Mutex;
+use tower_http::cors::{Any, CorsLayer};
+use utoipa::ToSchema;
+
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
+
+// OpenAI API compatible structs
+
+/// Inner content structure for messages that can be either a string or key-value pairs
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageInnerContent(
+    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
+);
+
+impl ToSchema<'_> for MessageInnerContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+        (
+            "MessageInnerContent",
+            utoipa::openapi::RefOr::T(message_inner_content_schema()),
+        )
+    }
+}
+
+/// Function for MessageInnerContent Schema generation to handle `Either`
+fn message_inner_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            // Either::Left - simple string
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            // Either::Right - object with string values
+            .item(Schema::Object(
+                ObjectBuilder::new()
+                    .schema_type(SchemaType::Object)
+                    .additional_properties(Some(RefOr::T(Schema::Object(
+                        ObjectBuilder::new().schema_type(SchemaType::String).build(),
+                    ))))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Message content that can be either simple text or complex structured content
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageContent(
+    #[serde(with = "either::serde_untagged")]
+    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
+);
+
+impl ToSchema<'_> for MessageContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    }
+}
+
+/// Function for MessageContent Schema generation to handle `Either`
+fn message_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            .item(Schema::Array(
+                ArrayBuilder::new()
+                    .items(RefOr::T(Schema::Object(
+                        ObjectBuilder::new()
+                            .schema_type(SchemaType::Object)
+                            .additional_properties(Some(RefOr::Ref(
+                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
+                            )))
+                            .build(),
+                    )))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Represents a single message in a conversation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct Message {
+    /// The message content
+    pub content: Option<MessageContent>,
+    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
+    pub role: String,
+    pub name: Option<String>,
+}
+
+/// Stop token configuration for generation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(untagged)]
+pub enum StopTokens {
+    /// Multiple possible stop sequences
+    Multi(Vec<String>),
+    /// Single stop sequence
+    Single(String),
+}
+
+/// Default value helper
+fn default_false() -> bool {
+    false
+}
+
+/// Default value helper
+fn default_1usize() -> usize {
+    1
+}
+
+/// Default value helper
+fn default_model() -> String {
+    "default".to_string()
+}
+
+/// Chat completion request following OpenAI's specification
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct ChatCompletionRequest {
+    #[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
+    pub messages: Vec<Message>,
+    #[schema(example = "gemma-3-1b-it")]
+    #[serde(default = "default_model")]
+    pub model: String,
+    #[serde(default = "default_false")]
+    #[schema(example = false)]
+    pub logprobs: bool,
+    #[schema(example = 256)]
+    pub max_tokens: Option<usize>,
+    #[serde(rename = "n")]
+    #[serde(default = "default_1usize")]
+    #[schema(example = 1)]
+    pub n_choices: usize,
+    #[schema(example = 0.7)]
+    pub temperature: Option<f64>,
+    #[schema(example = 0.9)]
+    pub top_p: Option<f64>,
+    #[schema(example = false)]
+    pub stream: Option<bool>,
+}
+
+/// Chat completion choice
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionChoice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+/// Chat completion response
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionChoice>,
+    pub usage: Usage,
+}
+
+/// Token usage information
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
+
+// Application state shared between handlers
+#[derive(Clone)]
+struct AppState {
+    text_generation: Arc<Mutex<TextGeneration>>,
+    model_id: String,
+}
+
+// Chat completions endpoint handler
+async fn chat_completions(
+    State(state): State<AppState>,
+    Json(request): Json<ChatCompletionRequest>,
+) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
+    let mut prompt = String::new();
+
+    // Convert messages to a prompt string
+    for message in &request.messages {
+        let role = &message.role;
+        let content = match &message.content {
+            Some(content) => match &content.0 {
+                Either::Left(text) => text.clone(),
+                Either::Right(_) => "".to_string(), // Handle complex content if needed
+            },
+            None => "".to_string(),
+        };
+
+        // Format based on role
+        match role.as_str() {
+            "system" => prompt.push_str(&format!("System: {}\n", content)),
+            "user" => prompt.push_str(&format!("User: {}\n", content)),
+            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
+        }
+    }
+
+    // Add the assistant prefix for the response
+    prompt.push_str("Assistant: ");
+
+    // Capture the output
+    let mut output = Vec::new();
+    {
+        let mut text_gen = state.text_generation.lock().await;
+
+        // Buffer to capture the output
+        let mut buffer = Vec::new();
+
+        // Run text generation
+        let max_tokens = request.max_tokens.unwrap_or(1000);
+        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
+
+        if let Err(e) = result {
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": {
+                        "message": format!("Error generating text: {}", e),
+                        "type": "internal_server_error"
+                    }
+                })),
+            ));
+        }
+
+        // Convert buffer to string
+        if let Ok(text) = String::from_utf8(buffer) {
+            output.push(text);
+        }
+    }
+
+    // Create response
+    let response = ChatCompletionResponse {
+        id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
+        object: "chat.completion".to_string(),
+        created: std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs(),
+        model: request.model,
+        choices: vec![ChatCompletionChoice {
+            index: 0,
+            message: Message {
+                role: "assistant".to_string(),
+                content: Some(MessageContent(Either::Left(output.join("")))),
+                name: None,
+            },
+            finish_reason: "stop".to_string(),
+        }],
+        usage: Usage {
+            prompt_tokens: prompt.len() / 4, // Rough estimate
+            completion_tokens: output.join("").len() / 4, // Rough estimate
+            total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
+        },
+    };
+
+    // Return the response as JSON
+    Ok(Json(response))
+}
+
 use candle_core::{DType, Device, MetalDevice, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
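To make the handler above easier to follow: `chat_completions` does not apply a model-specific chat template; it flattens the incoming messages into a role-prefixed transcript and appends `Assistant: ` before generation. The following is a dependency-free sketch of that flattening, with plain `(role, content)` pairs standing in for the commit's `Message`/`MessageContent` types; it is an illustration, not code from the commit.

```rust
// Standalone sketch of the prompt flattening performed by chat_completions above.
fn flatten_messages(messages: &[(&str, &str)]) -> String {
    let mut prompt = String::new();
    for (role, content) in messages {
        match *role {
            "system" => prompt.push_str(&format!("System: {}\n", content)),
            "user" => prompt.push_str(&format!("User: {}\n", content)),
            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
            other => prompt.push_str(&format!("{}: {}\n", other, content)),
        }
    }
    // The handler appends this prefix so the model continues as the assistant.
    prompt.push_str("Assistant: ");
    prompt
}

fn main() {
    let messages = [
        ("system", "You are a helpful assistant."),
        ("user", "Hello, how are you?"),
    ];
    let prompt = flatten_messages(&messages);
    assert_eq!(
        prompt,
        "System: You are a helpful assistant.\nUser: Hello, how are you?\nAssistant: "
    );
    println!("{prompt}");
}
```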
@@ -22,6 +296,22 @@ use tokenizers::Tokenizer;
 use crate::token_output_stream::TokenOutputStream;
 use crate::utilities_lib::device;
+
+// Create the router with the chat completions endpoint
+fn create_router(app_state: AppState) -> Router {
+    // CORS layer to allow requests from any origin
+    let cors = CorsLayer::new()
+        .allow_origin(Any)
+        .allow_methods(Any)
+        .allow_headers(Any);
+
+    Router::new()
+        // OpenAI compatible endpoints
+        .route("/v1/chat/completions", post(chat_completions))
+        // Add more endpoints as needed
+        .layer(cors)
+        .with_state(app_state)
+}

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 enum Which {
     #[value(name = "2b")]
@@ -108,6 +398,7 @@ impl TextGeneration {
         }
     }

+    // Run text generation and print to stdout
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
@@ -195,6 +486,90 @@ impl TextGeneration {
         );
         Ok(())
     }
+
+    // Run text generation and write to a buffer
+    fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
+        use std::io::Write;
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        // Write prompt tokens to output
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
+                eos_token
+            }
+        };
+
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos)?;
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+                // Manual implementation of repeat penalty to avoid type conflicts
+                let mut logits_vec = logits.to_vec1::<f32>()?;
+
+                for &token_id in &tokens[start_at..] {
+                    let token_id = token_id as usize;
+                    if token_id < logits_vec.len() {
+                        let score = logits_vec[token_id];
+                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
+                        logits_vec[token_id] = sign * score / self.repeat_penalty;
+                    }
+                }
+
+                // Create a new tensor with the modified logits
+                let device = logits.device().clone();
+                let shape = logits.shape().clone();
+                let new_logits = Tensor::new(&logits_vec[..], &device)?;
+                new_logits.reshape(shape)?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token || next_token == eot_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        // Write any remaining tokens
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            write!(output, "{}", rest)?;
+        }
+
+        Ok(())
+    }
 }

 #[derive(Parser, Debug)]
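The repeat-penalty step inside `run_with_output` above is done by hand on a plain `Vec<f32>` rather than through a library helper. The sketch below isolates just that arithmetic on toy values (the free function and the numbers are illustrative, not from the commit). Note that for negative scores this formula behaves differently from the more common repetition-penalty formulation, which multiplies negative scores by the penalty instead of dividing their magnitude.

```rust
// Dependency-free sketch of the repeat-penalty step in run_with_output above:
// every token id seen in the recent context window gets its logit rescaled.
fn apply_repeat_penalty(logits: &mut [f32], recent_tokens: &[u32], penalty: f32) {
    for &token_id in recent_tokens {
        let token_id = token_id as usize;
        if token_id < logits.len() {
            let score = logits[token_id];
            let sign = if score < 0.0 { -1.0 } else { 1.0 };
            // Mirrors the diff. For a negative score this yields a positive value,
            // unlike e.g. Hugging Face's repetition penalty, which multiplies
            // negative scores by the penalty.
            logits[token_id] = sign * score / penalty;
        }
    }
}

fn main() {
    let mut logits = vec![2.0_f32, -1.0, 0.5, 3.0];
    // Suppose tokens 0 and 3 appeared in the last `repeat_last_n` positions.
    apply_repeat_penalty(&mut logits, &[0, 3], 1.1);
    // Token 0: 2.0 / 1.1 ≈ 1.818; token 3: 3.0 / 1.1 ≈ 2.727; others untouched.
    println!("{logits:?}");
}
```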
@@ -208,8 +583,17 @@ struct Args {
     #[arg(long)]
     tracing: bool,

+    /// Run in server mode with OpenAI compatible API
     #[arg(long)]
-    prompt: String,
+    server: bool,
+
+    /// Port to use for the server
+    #[arg(long, default_value_t = 3000)]
+    port: u16,
+
+    /// Prompt for text generation (not used in server mode)
+    #[arg(long)]
+    prompt: Option<String>,

     /// The temperature used to generate samples.
     #[arg(long)]
@@ -308,7 +692,7 @@ fn main() -> Result<()> {
         },
     };
     let repo = api.repo(Repo::with_revision(
-        model_id,
+        model_id.clone(),
        RepoType::Model,
        args.revision,
    ));
@@ -371,7 +755,7 @@ fn main() -> Result<()> {

    println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
+    let pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
@@ -382,6 +766,36 @@ fn main() -> Result<()> {
        &device,
    );
+
+    if args.server {
+        // Start the server
+        println!("Starting server on port {}", args.port);
+
+        // Create app state
+        let app_state = AppState {
+            text_generation: Arc::new(Mutex::new(pipeline)),
+            model_id,
+        };
+
+        // Create router
+        let app = create_router(app_state);
+
+        // Run the server
+        let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
+
+        // Use tokio to run the server
+        tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()?
+            .block_on(async {
+                axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Server error: {}", e))
+            })?;
+
+        Ok(())
+    } else {
+        // Run in CLI mode
+        if let Some(prompt_text) = &args.prompt {
    let prompt = match args.which {
        Which::Base2B
        | Which::Base7B
@@ -397,15 +811,20 @@ fn main() -> Result<()> {
        | Which::InstructV2_2B
        | Which::BaseV2_9B
        | Which::InstructV2_9B
-        | Which::BaseV3_1B => args.prompt,
+        | Which::BaseV3_1B => prompt_text.clone(),
        Which::InstructV3_1B => {
            format!(
                "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
-                args.prompt
+                prompt_text
            )
        }
    };
+
+            let mut pipeline = pipeline;
            pipeline.run(&prompt, args.sample_len)?;
            Ok(())
+        } else {
+            anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
+        }
+    }
 }