From 8a3c0797c3376ba85669aa47c675623d2100eddf Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Thu, 5 Jun 2025 20:31:59 -0400 Subject: [PATCH] add openai compatible endpoint for chat completions --- .gitignore | 2 +- README.md | 21 +- local_inference_engine/Cargo.lock | 331 +++++++++++++++++--- local_inference_engine/Cargo.toml | 7 + local_inference_engine/README.md | 207 +++++++++++++ local_inference_engine/src/main.rs | 475 +++++++++++++++++++++++++++-- 6 files changed, 965 insertions(+), 78 deletions(-) create mode 100644 local_inference_engine/README.md diff --git a/.gitignore b/.gitignore index 34cebe8..3ac610a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -/target +**/target /node_modules/ /.idea chrome diff --git a/README.md b/README.md index f11a274..836ce8e 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,27 @@ # open-web-agent-rs +A Rust-based web agent with local inference capabilities. + +## Components + +### Local Inference Engine + +The [Local Inference Engine](./local_inference_engine/README.md) provides a way to run large language models locally. It supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API. + +Features: +- Run Gemma models locally (1B, 2B, 7B, 9B variants) +- CLI mode for direct text generation +- Server mode with OpenAI-compatible API +- Support for various model configurations (base, instruction-tuned) +- Metal acceleration on macOS + +See the [Local Inference Engine README](./local_inference_engine/README.md) for detailed usage instructions. + +### Web Server + Server is being converted to MCP. Things are probably broken. ```text bun i bun dev ``` - - diff --git a/local_inference_engine/Cargo.lock b/local_inference_engine/Cargo.lock index 8dff7b8..5da43a1 100644 --- a/local_inference_engine/Cargo.lock +++ b/local_inference_engine/Cargo.lock @@ -205,7 +205,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -313,6 +313,17 @@ version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10d2119f741b79fe9907f5396d19bffcb46568cfcc315e78677d731972ac7085" +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "atoi" version = "2.0.0" @@ -357,6 +368,61 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + 
"rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -405,7 +471,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.101", ] [[package]] @@ -523,7 +589,7 @@ checksum = "7ecc273b49b3205b83d648f0690daa588925572cc5063745bfe547fe7ec8e1a1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -770,7 +836,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1024,7 +1090,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn", + "syn 2.0.101", ] [[package]] @@ -1035,7 +1101,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1052,7 +1118,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1073,7 +1139,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1083,7 +1149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn", + "syn 2.0.101", ] [[package]] @@ -1134,7 +1200,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1161,6 +1227,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -1198,7 +1267,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1332,7 +1401,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1412,7 +1481,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1726,7 +1795,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -1862,6 +1931,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.6.0" @@ -1875,6 +1950,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -2125,6 +2201,7 @@ checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown 0.15.3", + "serde", ] [[package]] @@ -2182,7 +2259,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -2414,6 +2491,7 @@ dependencies = [ "ab_glyph", "accelerate-src", "anyhow", + "axum", "bindgen_cuda", "byteorder", "candle-core", @@ -2426,6 +2504,7 @@ dependencies = [ "cpal", "csv", "cudarc", + "either", "enterpolation", 
"half", "hf-hub", @@ -2446,9 +2525,23 @@ dependencies = [ "symphonia", "tokenizers", "tokio", + "tower 0.4.13", + "tower-http 0.5.2", "tracing", "tracing-chrome", "tracing-subscriber", + "utoipa", + "uuid", +] + +[[package]] +name = "lock_api" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" +dependencies = [ + "autocfg", + "scopeguard", ] [[package]] @@ -2509,6 +2602,12 @@ dependencies = [ "libc", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -2635,7 +2734,7 @@ checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -2779,7 +2878,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -2851,7 +2950,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3012,7 +3111,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3084,7 +3183,30 @@ dependencies = [ "by_address", "proc-macro2", "quote", - "syn", + "syn 2.0.101", +] + +[[package]] +name = "parking_lot" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", ] [[package]] @@ -3183,7 +3305,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3257,7 +3379,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9dee91521343f4c5c6a63edd65e54f31f5c92fe8978c40a4282f8372194c6a7d" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.101", ] [[package]] @@ -3278,6 +3400,30 @@ dependencies = [ "toml_edit", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -3297,7 +3443,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3325,7 +3471,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn", + "syn 2.0.101", ] [[package]] 
@@ -3355,7 +3501,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn", + "syn 2.0.101", "tempfile", ] @@ -3369,7 +3515,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3454,7 +3600,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3467,7 +3613,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -3774,8 +3920,8 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-util", - "tower", - "tower-http", + "tower 0.5.2", + "tower-http 0.6.6", "tower-service", "url", "wasm-bindgen", @@ -3949,6 +4095,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "security-framework" version = "2.11.1" @@ -4001,7 +4153,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -4016,6 +4168,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -4072,6 +4234,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +dependencies = [ + "libc", +] + [[package]] name = "simba" version = "0.8.1" @@ -4200,7 +4371,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn", + "syn 2.0.101", ] [[package]] @@ -4404,6 +4575,16 @@ dependencies = [ "symphonia-metadata", ] +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.101" @@ -4432,7 +4613,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -4553,7 +4734,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -4564,7 +4745,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -4648,7 +4829,9 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -4662,7 +4845,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -4748,6 +4931,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.5.2" @@ -4761,6 +4955,23 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.9.1", + "bytes", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", ] [[package]] @@ -4776,7 +4987,7 @@ dependencies = [ "http-body", "iri-string", "pin-project-lite", - "tower", + "tower 0.5.2", "tower-layer", "tower-service", ] @@ -4799,6 +5010,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4812,7 +5024,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -5035,6 +5247,31 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utoipa" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" +dependencies = [ + "indexmap", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-gen" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn 2.0.101", +] + [[package]] name = "uuid" version = "1.17.0" @@ -5137,7 +5374,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.101", "wasm-bindgen-shared", ] @@ -5172,7 +5409,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -5319,7 +5556,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -5330,7 +5567,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -5721,7 +5958,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", "synstructure", ] @@ -5733,7 +5970,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", "synstructure", ] @@ -5754,7 +5991,7 @@ checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] @@ -5774,7 +6011,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", "synstructure", ] @@ -5814,7 +6051,7 @@ checksum = 
"5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.101", ] [[package]] diff --git a/local_inference_engine/Cargo.toml b/local_inference_engine/Cargo.toml index aa9dd5f..3d4f589 100644 --- a/local_inference_engine/Cargo.toml +++ b/local_inference_engine/Cargo.toml @@ -36,6 +36,13 @@ clap= { version = "4.2.4", features = ["derive"] } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" +axum = { version = "0.7.4", features = ["json"] } +tower = "0.4.13" +tower-http = { version = "0.5.1", features = ["cors"] } +tokio = { version = "1.43.0", features = ["full"] } +either = { version = "1.9.0", features = ["serde"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +uuid = { version = "1.7.0", features = ["v4"] } [dev-dependencies] diff --git a/local_inference_engine/README.md b/local_inference_engine/README.md new file mode 100644 index 0000000..edfad37 --- /dev/null +++ b/local_inference_engine/README.md @@ -0,0 +1,207 @@ +# Local Inference Engine + +A Rust-based inference engine for running large language models locally. This tool supports both CLI mode for direct text generation and server mode with an OpenAI-compatible API. + +## Features + +- Run Gemma models locally (1B, 2B, 7B, 9B variants) +- CLI mode for direct text generation +- Server mode with OpenAI-compatible API +- Support for various model configurations (base, instruction-tuned) +- Metal acceleration on macOS + +## Installation + +### Prerequisites + +- Rust toolchain (install via [rustup](https://rustup.rs/)) +- Cargo package manager +- For GPU acceleration: + - macOS: Metal support + - Linux/Windows: CUDA support (requires appropriate drivers) + +### Building from Source + +1. Clone the repository: + ```bash + git clone https://github.com/seemueller-io/open-web-agent-rs.git + cd open-web-agent-rs + ``` + +2. Build the local inference engine: + ```bash + cd local_inference_engine + cargo build --release + ``` + +## Usage + +### CLI Mode + +Run the inference engine in CLI mode to generate text directly: + +```bash +cargo run --release -- --prompt "Your prompt text here" --which 3-1b-it +``` + +#### CLI Options + +- `--prompt `: The prompt text to generate from +- `--which `: Model variant to use (default: "3-1b-it") +- `--server`: Run OpenAI compatible server +- Available options: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it" +- `--temperature `: Temperature for sampling (higher = more random) +- `--top-p `: Nucleus sampling probability cutoff +- `--sample-len `: Maximum number of tokens to generate (default: 10000) +- `--repeat-penalty `: Penalty for repeating tokens (default: 1.1) +- `--repeat-last-n `: Context size for repeat penalty (default: 64) +- `--cpu`: Run on CPU instead of GPU +- `--tracing`: Enable tracing (generates a trace-timestamp.json file) + +### Server Mode with OpenAI-compatible API + +Run the inference engine in server mode to expose an OpenAI-compatible API: + +```bash +cargo run --release -- --server --port 3000 --which 3-1b-it +``` + +This starts a web server on the specified port (default: 3000) with an OpenAI-compatible chat completions endpoint. 
+
+#### Server Options
+
+- `--server`: Run in server mode
+- `--port <PORT>`: Port to use for the server (default: 3000)
+- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
+- Other model options as described in CLI mode
+
+## API Usage
+
+The server exposes an OpenAI-compatible chat completions endpoint:
+
+### Chat Completions
+
+```
+POST /v1/chat/completions
+```
+
+#### Request Format
+
+```json
+{
+  "model": "gemma-3-1b-it",
+  "messages": [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Hello, how are you?"}
+  ],
+  "temperature": 0.7,
+  "max_tokens": 256,
+  "top_p": 0.9,
+  "stream": false
+}
+```
+
+#### Response Format
+
+```json
+{
+  "id": "chatcmpl-123abc456def789ghi",
+  "object": "chat.completion",
+  "created": 1677858242,
+  "model": "gemma-3-1b-it",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "I'm doing well, thank you for asking! How can I assist you today?"
+      },
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 25,
+    "completion_tokens": 15,
+    "total_tokens": 40
+  }
+}
```
+
+### Example: Using cURL
+
+```bash
+curl -X POST http://localhost:3000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemma-3-1b-it",
+    "messages": [
+      {"role": "user", "content": "What is the capital of France?"}
+    ],
+    "temperature": 0.7,
+    "max_tokens": 100
+  }'
+```
+
+### Example: Using Python with OpenAI Client
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:3000/v1",
+    api_key="dummy"  # API key is not validated but required by the client
+)
+
+response = client.chat.completions.create(
+    model="gemma-3-1b-it",
+    messages=[
+        {"role": "user", "content": "What is the capital of France?"}
+    ],
+    temperature=0.7,
+    max_tokens=100
+)
+
+print(response.choices[0].message.content)
+```
+
+### Example: Using JavaScript/TypeScript with OpenAI SDK
+
+```javascript
+import OpenAI from 'openai';
+
+const openai = new OpenAI({
+  baseURL: 'http://localhost:3000/v1',
+  apiKey: 'dummy', // API key is not validated but required by the client
+});
+
+async function main() {
+  const response = await openai.chat.completions.create({
+    model: 'gemma-3-1b-it',
+    messages: [
+      { role: 'user', content: 'What is the capital of France?' }
+    ],
+    temperature: 0.7,
+    max_tokens: 100,
+  });
+
+  console.log(response.choices[0].message.content);
+}
+
+main();
+```
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from Hugging Face Hub.
+
+2. **Out of memory errors**: Try using a smaller model variant or reducing the maximum generation length (`--sample-len`).
+
+3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.
+
+4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.
+
+## License
+
+This project is licensed under the terms specified in the LICENSE file.
\ No newline at end of file
diff --git a/local_inference_engine/src/main.rs b/local_inference_engine/src/main.rs
index 3ef9dcb..c6ba3e5 100644
--- a/local_inference_engine/src/main.rs
+++ b/local_inference_engine/src/main.rs
@@ -8,12 +8,286 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;

 use anyhow::{Error as E, Result};
+use axum::{
+    extract::State,
+    http::StatusCode,
+    response::IntoResponse,
+    routing::{get, post},
+    Json, Router,
+};
 use clap::Parser;
+use either::Either;
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, net::SocketAddr, sync::Arc};
+use tokio::sync::Mutex;
+use tower_http::cors::{Any, CorsLayer};
+use utoipa::ToSchema;

 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};

+// OpenAI API compatible structs
+
+/// Inner content structure for messages that can be either a string or key-value pairs
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageInnerContent(
+    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
+);
+
+impl ToSchema<'_> for MessageInnerContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>) {
+        (
+            "MessageInnerContent",
+            utoipa::openapi::RefOr::T(message_inner_content_schema()),
+        )
+    }
+}
+
+/// Function for MessageInnerContent Schema generation to handle `Either`
+fn message_inner_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            // Either::Left - simple string
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            // Either::Right - object with string values
+            .item(Schema::Object(
+                ObjectBuilder::new()
+                    .schema_type(SchemaType::Object)
+                    .additional_properties(Some(RefOr::T(Schema::Object(
+                        ObjectBuilder::new().schema_type(SchemaType::String).build(),
+                    ))))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Message content that can be either simple text or complex structured content
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct MessageContent(
+    #[serde(with = "either::serde_untagged")]
+    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
+);
+
+impl ToSchema<'_> for MessageContent {
+    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>) {
+        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    }
+}
+
+/// Function for MessageContent Schema generation to handle `Either`
+fn message_content_schema() -> utoipa::openapi::Schema {
+    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
+
+    Schema::OneOf(
+        OneOfBuilder::new()
+            .item(Schema::Object(
+                ObjectBuilder::new().schema_type(SchemaType::String).build(),
+            ))
+            .item(Schema::Array(
+                ArrayBuilder::new()
+                    .items(RefOr::T(Schema::Object(
+                        ObjectBuilder::new()
+                            .schema_type(SchemaType::Object)
+                            .additional_properties(Some(RefOr::Ref(
+                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
+                            )))
+                            .build(),
+                    )))
+                    .build(),
+            ))
+            .build(),
+    )
+}
+
+/// Represents a single message in a conversation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct Message {
+    /// The message content
+    pub content: Option<MessageContent>,
+    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
+    pub role: String,
+    pub name: Option<String>,
+}
+
+/// Stop token configuration for generation
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(untagged)]
+pub enum StopTokens {
+    /// Multiple possible stop sequences
+    Multi(Vec<String>),
+    /// Single stop sequence
+    Single(String),
+}
+
+/// Default value helper
+fn default_false() -> bool {
+    false
+}
+
+/// Default value helper
+fn default_1usize() -> usize {
+    1
+}
+
+/// Default value helper
+fn default_model() -> String {
+    "default".to_string()
+}
+
+/// Chat completion request following OpenAI's specification
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+pub struct ChatCompletionRequest {
+    #[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
+    pub messages: Vec<Message>,
+    #[schema(example = "gemma-3-1b-it")]
+    #[serde(default = "default_model")]
+    pub model: String,
+    #[serde(default = "default_false")]
+    #[schema(example = false)]
+    pub logprobs: bool,
+    #[schema(example = 256)]
+    pub max_tokens: Option<usize>,
+    #[serde(rename = "n")]
+    #[serde(default = "default_1usize")]
+    #[schema(example = 1)]
+    pub n_choices: usize,
+    #[schema(example = 0.7)]
+    pub temperature: Option<f64>,
+    #[schema(example = 0.9)]
+    pub top_p: Option<f64>,
+    #[schema(example = false)]
+    pub stream: Option<bool>,
+}
+
+/// Chat completion choice
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionChoice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+/// Chat completion response
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionChoice>,
+    pub usage: Usage,
+}
+
+/// Token usage information
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
+
+// Application state shared between handlers
+#[derive(Clone)]
+struct AppState {
+    text_generation: Arc<Mutex<TextGeneration>>,
+    model_id: String,
+}
+
+// Chat completions endpoint handler
+async fn chat_completions(
+    State(state): State<AppState>,
+    Json(request): Json<ChatCompletionRequest>,
+) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
+    let mut prompt = String::new();
+
+    // Convert messages to a prompt string
+    for message in &request.messages {
+        let role = &message.role;
+        let content = match &message.content {
+            Some(content) => match &content.0 {
+                Either::Left(text) => text.clone(),
+                Either::Right(_) => "".to_string(), // Handle complex content if needed
+            },
+            None => "".to_string(),
+        };
+
+        // Format based on role
+        match role.as_str() {
+            "system" => prompt.push_str(&format!("System: {}\n", content)),
+            "user" => prompt.push_str(&format!("User: {}\n", content)),
+            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
+            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
+        }
+    }
+
+    // Add the assistant prefix for the response
+    prompt.push_str("Assistant: ");
+
+    // Capture the output
+    let mut output = Vec::new();
+    {
+        let mut text_gen = state.text_generation.lock().await;
+
+        // Buffer to capture the output
+        let mut buffer = Vec::new();
+
+        // Run text generation
+        let max_tokens = request.max_tokens.unwrap_or(1000);
+        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
+
+        if let Err(e) = result {
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": {
+                        "message": format!("Error generating text: {}", e),
+                        "type": "internal_server_error"
+                    }
+                })),
+            ));
+        }
+
+        // Convert buffer to string
+        if let Ok(text) = String::from_utf8(buffer) {
+            output.push(text);
+        }
+    }
+
+    // Create response
+    let response = ChatCompletionResponse {
+        id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
+        object: "chat.completion".to_string(),
+        created: std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_secs(),
+        model: request.model,
+        choices: vec![ChatCompletionChoice {
+            index: 0,
+            message: Message {
+                role: "assistant".to_string(),
+                content: Some(MessageContent(Either::Left(output.join("")))),
+                name: None,
+            },
+            finish_reason: "stop".to_string(),
+        }],
+        usage: Usage {
+            prompt_tokens: prompt.len() / 4, // Rough estimate
+            completion_tokens: output.join("").len() / 4, // Rough estimate
+            total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
+        },
+    };
+
+    // Return the response as JSON
+    Ok(Json(response))
+}
+
 use candle_core::{DType, Device, MetalDevice, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
@@ -22,6 +296,22 @@ use tokenizers::Tokenizer;

 use crate::token_output_stream::TokenOutputStream;
 use crate::utilities_lib::device;

+// Create the router with the chat completions endpoint
+fn create_router(app_state: AppState) -> Router {
+    // CORS layer to allow requests from any origin
+    let cors = CorsLayer::new()
+        .allow_origin(Any)
+        .allow_methods(Any)
+        .allow_headers(Any);
+
+    Router::new()
+        // OpenAI compatible endpoints
+        .route("/v1/chat/completions", post(chat_completions))
+        // Add more endpoints as needed
+        .layer(cors)
+        .with_state(app_state)
+}
+
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 enum Which {
     #[value(name = "2b")]
@@ -108,6 +398,7 @@ impl TextGeneration {
         }
     }

+    // Run text generation and print to stdout
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         self.tokenizer.clear();
@@ -195,6 +486,90 @@ impl TextGeneration {
         );
         Ok(())
     }
+
+    // Run text generation and write to a buffer
+    fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
+        use std::io::Write;
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        // Write prompt tokens to output
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<eos>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <eos> token"),
+        };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
+                eos_token
+            }
+        };
+
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let start_pos = tokens.len().saturating_sub(context_size);
+            let ctxt = &tokens[start_pos..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, start_pos)?;
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+
+                // Manual implementation of repeat penalty to avoid type conflicts
+                let mut logits_vec = logits.to_vec1::<f32>()?;
+
+                for &token_id in &tokens[start_at..] {
+                    let token_id = token_id as usize;
+                    if token_id < logits_vec.len() {
+                        let score = logits_vec[token_id];
+                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
+                        logits_vec[token_id] = sign * score / self.repeat_penalty;
+                    }
+                }
+
+                // Create a new tensor with the modified logits
+                let device = logits.device().clone();
+                let shape = logits.shape().clone();
+                let new_logits = Tensor::new(&logits_vec[..], &device)?;
+                new_logits.reshape(shape)?
+            };
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token || next_token == eot_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                write!(output, "{}", t)?;
+            }
+        }
+
+        // Write any remaining tokens
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            write!(output, "{}", rest)?;
+        }
+
+        Ok(())
+    }
 }

 #[derive(Parser, Debug)]
@@ -208,8 +583,17 @@ struct Args {
     #[arg(long)]
     tracing: bool,

+    /// Run in server mode with an OpenAI-compatible API
     #[arg(long)]
-    prompt: String,
+    server: bool,
+
+    /// Port to use for the server
+    #[arg(long, default_value_t = 3000)]
+    port: u16,
+
+    /// Prompt for text generation (not used in server mode)
+    #[arg(long)]
+    prompt: Option<String>,

     /// The temperature used to generate samples.
     #[arg(long)]
@@ -308,7 +692,7 @@ fn main() -> Result<()> {
         },
     };
     let repo = api.repo(Repo::with_revision(
-        model_id,
+        model_id.clone(),
        RepoType::Model,
        args.revision,
    ));
@@ -371,7 +755,7 @@

     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
+    let pipeline = TextGeneration::new(
         model,
         tokenizer,
         args.seed,
@@ -382,30 +766,65 @@
         &device,
     );

-    let prompt = match args.which {
-        Which::Base2B
-        | Which::Base7B
-        | Which::Instruct2B
-        | Which::Instruct7B
-        | Which::InstructV1_1_2B
-        | Which::InstructV1_1_7B
-        | Which::CodeBase2B
-        | Which::CodeBase7B
-        | Which::CodeInstruct2B
-        | Which::CodeInstruct7B
-        | Which::BaseV2_2B
-        | Which::InstructV2_2B
-        | Which::BaseV2_9B
-        | Which::InstructV2_9B
-        | Which::BaseV3_1B => args.prompt,
-        Which::InstructV3_1B => {
-            format!(
-                "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
-                args.prompt
-            )
-        }
-    };
+    if args.server {
+        // Start the server
+        println!("Starting server on port {}", args.port);

-    pipeline.run(&prompt, args.sample_len)?;
-    Ok(())
+
+        // Create app state
+        let app_state = AppState {
+            text_generation: Arc::new(Mutex::new(pipeline)),
+            model_id,
+        };
+
+        // Create router
+        let app = create_router(app_state);
+
+        // Run the server
+        let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
+
+        // Use tokio to run the server
+        tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()?
+            .block_on(async {
+                axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Server error: {}", e))
+            })?;
+
+        Ok(())
+    } else {
+        // Run in CLI mode
+        if let Some(prompt_text) = &args.prompt {
+            let prompt = match args.which {
+                Which::Base2B
+                | Which::Base7B
+                | Which::Instruct2B
+                | Which::Instruct7B
+                | Which::InstructV1_1_2B
+                | Which::InstructV1_1_7B
+                | Which::CodeBase2B
+                | Which::CodeBase7B
+                | Which::CodeInstruct2B
+                | Which::CodeInstruct7B
+                | Which::BaseV2_2B
+                | Which::InstructV2_2B
+                | Which::BaseV2_9B
+                | Which::InstructV2_9B
+                | Which::BaseV3_1B => prompt_text.clone(),
+                Which::InstructV3_1B => {
+                    format!(
+                        "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
+                        prompt_text
+                    )
+                }
+            };
+
+            let mut pipeline = pipeline;
+            pipeline.run(&prompt, args.sample_len)?;
+            Ok(())
+        } else {
+            anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
+        }
+    }
 }