add openai compatible endpoint for chat completions

This commit is contained in:
geoffsee
2025-06-05 20:31:59 -04:00
committed by Geoff Seemueller
parent 3b4c8b045a
commit 8a3c0797c3
6 changed files with 965 additions and 78 deletions

.gitignore

@@ -1,4 +1,4 @@
/target
**/target
/node_modules/
/.idea
chrome


@@ -1,10 +1,27 @@
# open-web-agent-rs
A Rust-based web agent with local inference capabilities.
## Components
### Local Inference Engine
The [Local Inference Engine](./local_inference_engine/README.md) provides a way to run large language models locally. It supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
Features:
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
- CLI mode for direct text generation
- Server mode with OpenAI-compatible API
- Support for various model configurations (base, instruction-tuned)
- Metal acceleration on macOS
See the [Local Inference Engine README](./local_inference_engine/README.md) for detailed usage instructions.
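For a quick start (a minimal sketch; see the engine's README for the full option list), build and launch it in server mode, then hit the OpenAI-compatible endpoint:
```bash
cd local_inference_engine
cargo run --release -- --server --port 3000 --which 3-1b-it

# In another terminal:
curl -X POST http://localhost:3000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "gemma-3-1b-it", "messages": [{"role": "user", "content": "Hello"}]}'
```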
### Web Server
The server is currently being converted to MCP, so parts of it are likely broken.
```text
bun i
bun dev
```


@@ -205,7 +205,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -313,6 +313,17 @@ version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10d2119f741b79fe9907f5396d19bffcb46568cfcc315e78677d731972ac7085"
[[package]]
name = "async-trait"
version = "0.1.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "atoi"
version = "2.0.0"
@@ -357,6 +368,61 @@ dependencies = [
"arrayvec",
]
[[package]]
name = "axum"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backtrace"
version = "0.3.75"
@@ -405,7 +471,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -523,7 +589,7 @@ checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -770,7 +836,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1024,7 +1090,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1035,7 +1101,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1052,7 +1118,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1073,7 +1139,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1083,7 +1149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1134,7 +1200,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1161,6 +1227,9 @@ name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
dependencies = [
"serde",
]
[[package]]
name = "encode_unicode"
@@ -1198,7 +1267,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1332,7 +1401,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1412,7 +1481,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1726,7 +1795,7 @@ dependencies = [
"proc-macro-error2",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -1862,6 +1931,12 @@ version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.6.0"
@@ -1875,6 +1950,7 @@ dependencies = [
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@@ -2125,6 +2201,7 @@ checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [
"equivalent",
"hashbrown 0.15.3",
"serde",
]
[[package]]
@@ -2182,7 +2259,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -2414,6 +2491,7 @@ dependencies = [
"ab_glyph",
"accelerate-src",
"anyhow",
"axum",
"bindgen_cuda",
"byteorder",
"candle-core",
@@ -2426,6 +2504,7 @@ dependencies = [
"cpal",
"csv",
"cudarc",
"either",
"enterpolation",
"half",
"hf-hub",
@@ -2446,9 +2525,23 @@ dependencies = [
"symphonia",
"tokenizers",
"tokio",
"tower 0.4.13",
"tower-http 0.5.2",
"tracing",
"tracing-chrome",
"tracing-subscriber",
"utoipa",
"uuid",
]
[[package]]
name = "lock_api"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
@@ -2509,6 +2602,12 @@ dependencies = [
"libc",
]
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
@@ -2635,7 +2734,7 @@ checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -2779,7 +2878,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -2851,7 +2950,7 @@ dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3012,7 +3111,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3084,7 +3183,30 @@ dependencies = [
"by_address",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
name = "parking_lot"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets 0.52.6",
]
[[package]]
@@ -3183,7 +3305,7 @@ dependencies = [
"phf_shared",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3257,7 +3379,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dee91521343f4c5c6a63edd65e54f31f5c92fe8978c40a4282f8372194c6a7d"
dependencies = [
"proc-macro2",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3278,6 +3400,30 @@ dependencies = [
"toml_edit",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro-error-attr2"
version = "2.0.0"
@@ -3297,7 +3443,7 @@ dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3325,7 +3471,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30"
dependencies = [
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3355,7 +3501,7 @@ dependencies = [
"prost",
"prost-types",
"regex",
"syn",
"syn 2.0.101",
"tempfile",
]
@@ -3369,7 +3515,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3454,7 +3600,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3467,7 +3613,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -3774,8 +3920,8 @@ dependencies = [
"tokio",
"tokio-native-tls",
"tokio-util",
"tower",
"tower-http",
"tower 0.5.2",
"tower-http 0.6.6",
"tower-service",
"url",
"wasm-bindgen",
@@ -3949,6 +4095,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
@@ -4001,7 +4153,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4016,6 +4168,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_plain"
version = "1.0.2"
@@ -4072,6 +4234,15 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410"
dependencies = [
"libc",
]
[[package]]
name = "simba"
version = "0.8.1"
@@ -4200,7 +4371,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4404,6 +4575,16 @@ dependencies = [
"symphonia-metadata",
]
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.101"
@@ -4432,7 +4613,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4553,7 +4734,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4564,7 +4745,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4648,7 +4829,9 @@ dependencies = [
"bytes",
"libc",
"mio",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.52.0",
@@ -4662,7 +4845,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -4748,6 +4931,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower"
version = "0.5.2"
@@ -4761,6 +4955,23 @@ dependencies = [
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"bitflags 2.9.1",
"bytes",
"http",
"http-body",
"http-body-util",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
@@ -4776,7 +4987,7 @@ dependencies = [
"http-body",
"iri-string",
"pin-project-lite",
"tower",
"tower 0.5.2",
"tower-layer",
"tower-service",
]
@@ -4799,6 +5010,7 @@ version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@@ -4812,7 +5024,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -5035,6 +5247,31 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "utoipa"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
dependencies = [
"indexmap",
"serde",
"serde_json",
"utoipa-gen",
]
[[package]]
name = "utoipa-gen"
version = "4.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"regex",
"syn 2.0.101",
]
[[package]]
name = "uuid"
version = "1.17.0"
@@ -5137,7 +5374,7 @@ dependencies = [
"log",
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
"wasm-bindgen-shared",
]
@@ -5172,7 +5409,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -5319,7 +5556,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -5330,7 +5567,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -5721,7 +5958,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
"synstructure",
]
@@ -5733,7 +5970,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
"synstructure",
]
@@ -5754,7 +5991,7 @@ checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]
@@ -5774,7 +6011,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
"synstructure",
]
@@ -5814,7 +6051,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.101",
]
[[package]]


@@ -36,6 +36,13 @@ clap= { version = "4.2.4", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = { version = "0.7.4", features = ["json"] }
tower = "0.4.13"
tower-http = { version = "0.5.1", features = ["cors"] }
tokio = { version = "1.43.0", features = ["full"] }
either = { version = "1.9.0", features = ["serde"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
[dev-dependencies]


@@ -0,0 +1,207 @@
# Local Inference Engine
A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API.
## Features
- Run Gemma models locally (1B, 2B, 7B, 9B variants)
- CLI mode for direct text generation
- Server mode with OpenAI-compatible API
- Support for various model configurations (base, instruction-tuned)
- Metal acceleration on macOS
## Installation
### Prerequisites
- Rust toolchain (install via [rustup](https://rustup.rs/))
- Cargo package manager
- For GPU acceleration:
  - macOS: Metal support
  - Linux/Windows: CUDA support (requires appropriate drivers)
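If the Rust toolchain is not installed yet, the standard rustup installer sets up both `rustc` and Cargo (macOS/Linux shown; see rustup.rs for Windows):
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```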
### Building from Source
1. Clone the repository:
```bash
git clone https://github.com/seemueller-io/open-web-agent-rs.git
cd open-web-agent-rs
```
2. Build the local inference engine:
```bash
cd local_inference_engine
cargo build --release
```
## Usage
### CLI Mode
Run the inference engine in CLI mode to generate text directly:
```bash
cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it
```
#### CLI Options
- `--prompt <TEXT>`: The prompt text to generate from
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
  - Available options: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
- `--server`: Run in server mode with an OpenAI-compatible API
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
- `--repeat-last-n <INT>`: Context size for repeat penalty (default: 64)
- `--cpu`: Run on CPU instead of GPU
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)
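For example, a run that combines several of these options might look like the following (the prompt text and sampling values are illustrative):
```bash
cargo run --release -- \
  --prompt "Write a haiku about Rust" \
  --which 2-2b-it \
  --temperature 0.8 \
  --top-p 0.95 \
  --sample-len 256
```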
### Server Mode with OpenAI-compatible API
Run the inference engine in server mode to expose an OpenAI-compatible API:
```bash
cargo run --release -- --server --port 3000 --which 3-1b-it
```
This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint.
#### Server Options
- `--server`: Run in server mode
- `--port <INT>`: Port to use for the server (default: 3000)
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
- Other model options as described in CLI mode
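For example, to serve a different model variant on another port (values are illustrative):
```bash
cargo run --release -- --server --port 8080 --which 2-2b-it
```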
## API Usage
The server exposes an OpenAI-compatible chat completions endpoint:
### Chat Completions
```
POST /v1/chat/completions
```
#### Request Format
```json
{
  "model": "gemma-3-1b-it",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you?"}
  ],
  "temperature": 0.7,
  "max_tokens": 256,
  "top_p": 0.9,
  "stream": false
}
```
#### Response Format
```json
{
  "id": "chatcmpl-123abc456def789ghi",
  "object": "chat.completion",
  "created": 1677858242,
  "model": "gemma-3-1b-it",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I'm doing well, thank you for asking! How can I assist you today?"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 25,
    "completion_tokens": 15,
    "total_tokens": 40
  }
}
```
### Example: Using cURL
```bash
curl -X POST http://localhost:3000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemma-3-1b-it",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ],
    "temperature": 0.7,
    "max_tokens": 100
  }'
```
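To extract only the assistant's reply, the response can be piped through `jq` (assuming `jq` is installed):
```bash
curl -s -X POST http://localhost:3000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "gemma-3-1b-it", "messages": [{"role": "user", "content": "What is the capital of France?"}]}' \
  | jq -r '.choices[0].message.content'
```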
### Example: Using Python with OpenAI Client
```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="dummy"  # API key is not validated but required by the client
)

response = client.chat.completions.create(
    model="gemma-3-1b-it",
    messages=[
        {"role": "user", "content": "What is the capital of France?"}
    ],
    temperature=0.7,
    max_tokens=100
)

print(response.choices[0].message.content)
```
### Example: Using JavaScript/TypeScript with OpenAI SDK
```javascript
import OpenAI from 'openai';

const openai = new OpenAI({
  baseURL: 'http://localhost:3000/v1',
  apiKey: 'dummy', // API key is not validated but required by the client
});

async function main() {
  const response = await openai.chat.completions.create({
    model: 'gemma-3-1b-it',
    messages: [
      { role: 'user', content: 'What is the capital of France?' }
    ],
    temperature: 0.7,
    max_tokens: 100,
  });

  console.log(response.choices[0].message.content);
}

main();
```
## Troubleshooting
### Common Issues
1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.
2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.
3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.
4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU, or fall back to CPU execution (see the example below).
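If GPU initialization keeps failing, a simple workaround is to force CPU execution with the `--cpu` flag (slower, but it avoids the Metal/CUDA path entirely):
```bash
cargo run --release -- --prompt "Hello" --which 3-1b-it --cpu
```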
## License
This project is licensed under the terms specified in the LICENSE file.


@@ -8,12 +8,286 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use either::Either;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use utoipa::ToSchema;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// OpenAI API compatible structs
/// Inner content structure for messages that can be either a string or key-value pairs
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
)
}
}
/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
// Either::Left - simple string
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
// Either::Right - object with string values
.item(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::T(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))))
.build(),
))
.build(),
)
}
/// Message content that can be either simple text or complex structured content
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
}
}
/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
.item(Schema::Array(
ArrayBuilder::new()
.items(RefOr::T(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::Ref(
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
)))
.build(),
)))
.build(),
))
.build(),
)
}
/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
pub name: Option<String>,
}
/// Stop token configuration for generation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Default value helper
fn default_false() -> bool {
false
}
/// Default value helper
fn default_1usize() -> usize {
1
}
/// Default value helper
fn default_model() -> String {
"default".to_string()
}
/// Chat completion request following OpenAI's specification
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
#[schema(example = 256)]
pub max_tokens: Option<usize>,
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
#[schema(example = 0.7)]
pub temperature: Option<f64>,
#[schema(example = 0.9)]
pub top_p: Option<f64>,
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
}
/// Token usage information
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
// Application state shared between handlers
#[derive(Clone)]
struct AppState {
text_generation: Arc<Mutex<TextGeneration>>,
model_id: String,
}
// Chat completions endpoint handler
async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": {
"message": format!("Error generating text: {}", e),
"type": "internal_server_error"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
use candle_core::{DType, Device, MetalDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
@@ -22,6 +296,22 @@ use tokenizers::Tokenizer;
use crate::token_output_stream::TokenOutputStream;
use crate::utilities_lib::device;
// Create the router with the chat completions endpoint
fn create_router(app_state: AppState) -> Router {
// CORS layer to allow requests from any origin
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
// OpenAI compatible endpoints
.route("/v1/chat/completions", post(chat_completions))
// Add more endpoints as needed
.layer(cors)
.with_state(app_state)
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
@@ -108,6 +398,7 @@ impl TextGeneration {
}
}
// Run text generation and print to stdout
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
@@ -195,6 +486,90 @@ impl TextGeneration {
);
Ok(())
}
// Run text generation and write to a buffer
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}
#[derive(Parser, Debug)]
@@ -208,8 +583,17 @@ struct Args {
#[arg(long)]
tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
prompt: String,
server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3000)]
port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
@@ -308,7 +692,7 @@ fn main() -> Result<()> {
},
};
let repo = api.repo(Repo::with_revision(
model_id,
model_id.clone(),
RepoType::Model,
args.revision,
));
@@ -371,7 +755,7 @@ fn main() -> Result<()> {
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
let pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
@@ -382,6 +766,36 @@ fn main() -> Result<()> {
&device,
);
if args.server {
// Start the server
println!("Starting server on port {}", args.port);
// Create app state
let app_state = AppState {
text_generation: Arc::new(Mutex::new(pipeline)),
model_id,
};
// Create router
let app = create_router(app_state);
// Run the server
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
// Use tokio to run the server
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
})?;
Ok(())
} else {
// Run in CLI mode
if let Some(prompt_text) = &args.prompt {
let prompt = match args.which {
Which::Base2B
| Which::Base7B
@@ -397,15 +811,20 @@ fn main() -> Result<()> {
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => args.prompt,
| Which::BaseV3_1B => prompt_text.clone(),
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
args.prompt
prompt_text
)
}
};
let mut pipeline = pipeline;
pipeline.run(&prompt, args.sample_len)?;
Ok(())
} else {
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
}
}
}