Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00.
Supports small Llama and Gemma models.

Commit: Refactor inference into dedicated crates for Llama and Gemma inferencing (not yet integrated).

Cargo.lock (generated, 369 lines changed)
@@ -686,6 +686,15 @@ dependencies = [
  "generic-array",
 ]
 
+[[package]]
+name = "block2"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "340d2f0bdb2a43c1d3cd40513185b2bd7def0aa1052f956455114bc98f82dcf2"
+dependencies = [
+ "objc2",
+]
+
 [[package]]
 name = "brotli"
 version = "3.5.0"
@@ -786,8 +795,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a9f51e2ecf6efe9737af8f993433c839f956d2b6ed4fd2dd4a7c6d8b0fa667ff"
 dependencies = [
  "byteorder",
- "candle-kernels",
- "candle-metal-kernels",
+ "candle-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "cudarc",
  "gemm 0.17.1",
  "half",
@@ -807,6 +816,35 @@ dependencies = [
  "zip",
 ]
 
+[[package]]
+name = "candle-core"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "byteorder",
+ "candle-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "cudarc",
+ "float8",
+ "gemm 0.17.1",
+ "half",
+ "memmap2",
+ "num-traits",
+ "num_cpus",
+ "objc2-foundation",
+ "objc2-metal",
+ "rand 0.9.2",
+ "rand_distr 0.5.1",
+ "rayon",
+ "safetensors",
+ "thiserror 1.0.69",
+ "ug",
+ "ug-cuda",
+ "ug-metal",
+ "yoke 0.7.5",
+ "zip",
+]
+
 [[package]]
 name = "candle-datasets"
 version = "0.9.1"
@@ -814,15 +852,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a0a7c351dd50cda83f00f17c4412e35c69d840e453edf06064974de1cc59343d"
 dependencies = [
  "byteorder",
- "candle-core",
- "candle-nn",
- "hf-hub",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "hf-hub 0.4.3",
  "image",
  "memmap2",
  "parquet",
  "rand 0.9.2",
  "thiserror 1.0.69",
- "tokenizers",
+ "tokenizers 0.21.4",
+]
+
+[[package]]
+name = "candle-examples"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "anyhow",
+ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "csv",
+ "hf-hub 0.4.3",
+ "image",
+ "num-traits",
+ "rayon",
+ "safetensors",
+ "serde",
+ "serde_json",
+ "tokenizers 0.21.4",
 ]
 
 [[package]]
@@ -833,7 +891,7 @@ checksum = "fb38a5bfae09c4ae73fd00039e5eaf97a7d6d9400cc35ee8e603fc4a5f9cb0a3"
 dependencies = [
  "anyhow",
  "bindgen_cuda",
- "candle-core",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "half",
 ]
 
@@ -846,6 +904,14 @@ dependencies = [
  "bindgen_cuda",
 ]
 
+[[package]]
+name = "candle-kernels"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "bindgen_cuda",
+]
+
 [[package]]
 name = "candle-metal-kernels"
 version = "0.9.1"
@@ -859,13 +925,27 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "candle-metal-kernels"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "half",
+ "objc2",
+ "objc2-foundation",
+ "objc2-metal",
+ "once_cell",
+ "thiserror 1.0.69",
+ "tracing",
+]
+
 [[package]]
 name = "candle-nn"
 version = "0.9.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c1980d53280c8f9e2c6cbe1785855d7ff8010208b46e21252b978badf13ad69d"
 dependencies = [
- "candle-core",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "half",
  "num-traits",
  "rayon",
@@ -874,14 +954,30 @@ dependencies = [
  "thiserror 1.0.69",
 ]
 
+[[package]]
+name = "candle-nn"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "half",
+ "num-traits",
+ "objc2-metal",
+ "rayon",
+ "safetensors",
+ "serde",
+ "thiserror 1.0.69",
+]
+
 [[package]]
 name = "candle-onnx"
 version = "0.9.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a8fa227a8176fd9b8fb58d63c908c08ad3af1503ee6fcd058be072a598044d2"
 dependencies = [
- "candle-core",
- "candle-nn",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "prost",
  "prost-build",
 ]
@@ -893,8 +989,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "186cb80045dbe47e0b387ea6d3e906f02fb3056297080d9922984c90e90a72b0"
 dependencies = [
  "byteorder",
- "candle-core",
- "candle-nn",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "fancy-regex",
+ "num-traits",
+ "rand 0.9.2",
+ "rayon",
+ "serde",
+ "serde_json",
+ "serde_plain",
+ "tracing",
+]
+
+[[package]]
+name = "candle-transformers"
+version = "0.9.1"
+source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
+dependencies = [
+ "byteorder",
+ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
  "fancy-regex",
  "num-traits",
  "rand 0.9.2",
@@ -1523,6 +1637,15 @@ dependencies = [
  "dirs-sys 0.4.1",
 ]
 
+[[package]]
+name = "dirs"
+version = "5.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
+dependencies = [
+ "dirs-sys 0.4.1",
+]
+
 [[package]]
 name = "dirs"
 version = "6.0.0"
@@ -1556,6 +1679,16 @@ dependencies = [
  "windows-sys 0.60.2",
 ]
 
+[[package]]
+name = "dispatch2"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
+dependencies = [
+ "bitflags 2.9.2",
+ "objc2",
+]
+
 [[package]]
 name = "displaydoc"
 version = "0.2.5"
@@ -1715,6 +1848,9 @@ name = "esaxx-rs"
 version = "0.1.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
+dependencies = [
+ "cc",
+]
 
 [[package]]
 name = "event-listener"
@@ -1793,14 +1929,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "04c269a76bfc6cea69553b7d040acb16c793119cebd97c756d21e08d0f075ff8"
 dependencies = [
  "anyhow",
- "hf-hub",
+ "hf-hub 0.4.3",
  "image",
  "ndarray",
  "ort",
  "ort-sys",
  "rayon",
  "serde_json",
- "tokenizers",
+ "tokenizers 0.21.4",
 ]
 
 [[package]]
@@ -1856,6 +1992,18 @@ dependencies = [
  "miniz_oxide",
 ]
 
+[[package]]
+name = "float8"
+version = "0.2.1"
+source = "git+https://github.com/zackangelo/float8?branch=cudarc_0_16#03c1f5fe7cdb2f9cb690823fdd40593be57c408f"
+dependencies = [
+ "cudarc",
+ "half",
+ "num-traits",
+ "rand 0.9.2",
+ "rand_distr 0.5.1",
+]
+
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -2246,6 +2394,24 @@ dependencies = [
  "seq-macro",
 ]
 
+[[package]]
+name = "gemma-runner"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-examples",
+ "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "clap",
+ "hf-hub 0.4.3",
+ "serde_json",
+ "tokenizers 0.21.4",
+ "tracing",
+ "tracing-chrome",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "generic-array"
 version = "0.14.7"
@@ -2421,19 +2587,48 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
+[[package]]
+name = "helm-chart-tool"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "serde",
+ "serde_json",
+ "toml 0.8.23",
+ "walkdir",
+]
+
 [[package]]
 name = "hermit-abi"
 version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
 
+[[package]]
+name = "hf-hub"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
+dependencies = [
+ "dirs 5.0.1",
+ "indicatif",
+ "log",
+ "native-tls",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "thiserror 1.0.69",
+ "ureq",
+]
+
 [[package]]
 name = "hf-hub"
 version = "0.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
 dependencies = [
- "dirs",
+ "dirs 6.0.0",
  "futures",
  "http",
  "indicatif",
@@ -2842,12 +3037,12 @@ dependencies = [
  "axum",
  "bindgen_cuda",
  "byteorder",
- "candle-core",
+ "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "candle-datasets",
  "candle-flash-attn",
- "candle-nn",
+ "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "candle-onnx",
- "candle-transformers",
+ "candle-transformers 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "clap",
  "cpal",
  "csv",
@@ -2855,11 +3050,13 @@ dependencies = [
  "either",
  "enterpolation",
  "futures-util",
+ "gemma-runner",
  "half",
- "hf-hub",
+ "hf-hub 0.4.3",
  "image",
  "imageproc",
  "intel-mkl-src",
+ "llama-runner",
  "memmap2",
  "num-traits",
  "palette",
@@ -2873,7 +3070,7 @@ dependencies = [
  "serde",
  "serde_json",
  "symphonia",
- "tokenizers",
+ "tokenizers 0.21.4",
  "tokio",
  "tokio-stream",
  "tower",
@@ -2981,6 +3178,15 @@ version = "1.70.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
 
+[[package]]
+name = "itertools"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itertools"
 version = "0.12.1"
@@ -3405,7 +3611,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
 dependencies = [
  "cfg-if",
- "windows-targets 0.53.3",
+ "windows-targets 0.48.5",
 ]
 
 [[package]]
@@ -3443,6 +3649,20 @@ version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
 
+[[package]]
+name = "llama-runner"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
+ "clap",
+ "hf-hub 0.3.2",
+ "serde_json",
+ "tokenizers 0.20.4",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.13"
@@ -3965,6 +4185,59 @@ dependencies = [
  "objc_exception",
 ]
 
+[[package]]
+name = "objc2"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc"
+dependencies = [
+ "objc2-encode",
+]
+
+[[package]]
+name = "objc2-core-foundation"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166"
+dependencies = [
+ "bitflags 2.9.2",
+ "dispatch2",
+ "objc2",
+]
+
+[[package]]
+name = "objc2-encode"
+version = "4.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
+
+[[package]]
+name = "objc2-foundation"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
+dependencies = [
+ "bitflags 2.9.2",
+ "block2",
+ "libc",
+ "objc2",
+ "objc2-core-foundation",
+]
+
+[[package]]
+name = "objc2-metal"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874"
+dependencies = [
+ "bitflags 2.9.2",
+ "block2",
+ "dispatch2",
+ "objc2",
+ "objc2-core-foundation",
+ "objc2-foundation",
+]
+
 [[package]]
 name = "objc_exception"
 version = "0.1.2"
@@ -4803,7 +5076,7 @@ dependencies = [
  "once_cell",
  "socket2 0.5.10",
  "tracing",
- "windows-sys 0.59.0",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -5006,6 +5279,17 @@ dependencies = [
  "rayon-core",
 ]
 
+[[package]]
+name = "rayon-cond"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
+dependencies = [
+ "either",
+ "itertools 0.11.0",
+ "rayon",
+]
+
 [[package]]
 name = "rayon-cond"
 version = "0.4.0"
@@ -6267,7 +6551,7 @@ dependencies = [
  "getrandom 0.3.3",
  "once_cell",
  "rustix",
- "windows-sys 0.59.0",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -6384,6 +6668,38 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
+[[package]]
+name = "tokenizers"
+version = "0.20.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
+dependencies = [
+ "aho-corasick",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom 0.2.16",
+ "indicatif",
+ "itertools 0.12.1",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand 0.8.5",
+ "rayon",
+ "rayon-cond 0.3.0",
+ "regex",
+ "regex-syntax 0.8.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror 1.0.69",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokenizers"
 version = "0.21.4"
@@ -6397,7 +6713,8 @@ dependencies = [
 "derive_builder",
 "esaxx-rs",
 "getrandom 0.3.3",
- "hf-hub",
+ "hf-hub 0.4.3",
+ "indicatif",
 "itertools 0.14.0",
 "log",
 "macro_rules_attribute",
@@ -6406,7 +6723,7 @@ dependencies = [
 "paste",
 "rand 0.9.2",
 "rayon",
- "rayon-cond",
+ "rayon-cond 0.4.0",
 "regex",
 "regex-syntax 0.8.5",
 "serde",
@@ -7260,7 +7577,7 @@ version = "0.1.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
 dependencies = [
- "windows-sys 0.59.0",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]

Cargo.toml
@@ -4,7 +4,9 @@ members = [
  "crates/inference-engine",
  "crates/embeddings-engine",
  "crates/leptos-app",
- "crates/helm-chart-tool"
+ "crates/helm-chart-tool",
+ "crates/llama-runner",
+ "crates/gemma-runner"
 ]
 default-members = ["crates/predict-otron-9000"]
 resolver = "2"

README.md (26 lines changed)
@@ -10,7 +10,7 @@ Powerful local AI inference with OpenAI-compatible APIs
 
 The predict-otron-9000 is a flexible AI platform that provides:
 
-- **Local LLM Inference**: Run Gemma models locally with CPU or GPU acceleration
+- **Local LLM Inference**: Run Gemma and Llama models locally with CPU or GPU acceleration
 - **Embeddings Generation**: Create text embeddings with FastEmbed
 - **Web Interface**: Interact with models through a Leptos WASM chat interface
 - **TypeScript CLI**: Command-line client for testing and automation
@@ -22,7 +22,7 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
 
 - **OpenAI Compatible**: API endpoints match OpenAI's format for easy integration
 - **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
-- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma models (1B, 2B, 7B variants including instruction-tuned models)
+- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
 - **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
 - **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction
 - **Flexible Deployment**: Run as monolithic service or microservices architecture
@@ -31,15 +31,19 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
 
 ### Workspace Structure
 
-The project uses a 4-crate Rust workspace plus TypeScript components:
+The project uses a 7-crate Rust workspace plus TypeScript components:
 
 ```
 crates/
 ├── predict-otron-9000/   # Main orchestration server (Rust 2024)
-├── inference-engine/     # Gemma inference via Candle (Rust 2021)
+├── inference-engine/     # Multi-model inference orchestrator (Rust 2021)
+├── gemma-runner/         # Gemma model inference via Candle (Rust 2021)
+├── llama-runner/         # Llama model inference via Candle (Rust 2021)
 ├── embeddings-engine/    # FastEmbed embeddings service (Rust 2024)
-└── leptos-app/           # WASM web frontend (Rust 2021)
-cli.ts                    # TypeScript/Bun CLI client
+├── leptos-app/           # WASM web frontend (Rust 2021)
+├── helm-chart-tool/      # Kubernetes deployment tooling (Rust 2024)
+└── scripts/
+    └── cli.ts            # TypeScript/Bun CLI client
 ```
 
 ### Service Architecture
@@ -149,16 +153,16 @@ cd crates/leptos-app
 #### TypeScript CLI Client
 ```bash
 # List available models
-bun run cli.ts --list-models
+bun run scripts/cli.ts --list-models
 
 # Chat completion
-bun run cli.ts "What is the capital of France?"
+bun run scripts/cli.ts "What is the capital of France?"
 
 # With specific model
-bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
 
 # Show help
-bun run cli.ts --help
+bun run scripts/cli.ts --help
 ```
 
 ## API Usage
@@ -454,7 +458,7 @@ curl -s http://localhost:8080/v1/models | jq
 
 **CLI client test:**
 ```bash
-bun run cli.ts "What is 2+2?"
+bun run scripts/cli.ts "What is 2+2?"
 ```
 
 **Web frontend:**

crates/gemma-runner/Cargo.toml (new file, 28 lines)
@@ -0,0 +1,28 @@
[package]
name = "gemma-runner"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-examples = { git = "https://github.com/huggingface/candle.git" }

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
hf-hub = "0.4"
tokenizers = "0.21"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

crates/gemma-runner/README.md (new file, 137 lines)
@@ -0,0 +1,137 @@
# Gemma Runner

Fast Gemma inference with Candle framework in Rust.

## Features

- Support for multiple Gemma model versions (v1, v2, v3)
- GPU acceleration with CUDA and Metal
- Configurable sampling parameters
- Multiple model variants including instruct and code models

## Supported Models

### Gemma v1
- `gemma-2b` - Base 2B model
- `gemma-7b` - Base 7B model
- `gemma-2b-it` - Instruct 2B model
- `gemma-7b-it` - Instruct 7B model
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model

### CodeGemma
- `codegemma-2b` - Code base 2B model
- `codegemma-7b` - Code base 7B model
- `codegemma-2b-it` - Code instruct 2B model
- `codegemma-7b-it` - Code instruct 7B model

### Gemma v2
- `gemma-2-2b` - Base 2B v2 model (default)
- `gemma-2-2b-it` - Instruct 2B v2 model
- `gemma-2-9b` - Base 9B v2 model
- `gemma-2-9b-it` - Instruct 9B v2 model

### Gemma v3
- `gemma-3-1b` - Base 1B v3 model
- `gemma-3-1b-it` - Instruct 1B v3 model

## Installation

```bash
cd gemma-runner
cargo build --release
```

For GPU support:
```bash
# CUDA
cargo build --release --features cuda

# Metal (macOS)
cargo build --release --features metal
```

## Usage

### Basic Usage

```bash
# Run with default model (gemma-2-2b)
cargo run -- --prompt "The capital of France is"

# Specify a different model
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"

# Generate more tokens
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
```

### Advanced Options

```bash
# Use CPU instead of GPU
cargo run -- --cpu --prompt "Hello world"

# Adjust sampling parameters
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"

# Use custom model from HuggingFace Hub
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"

# Enable tracing for performance analysis
cargo run -- --tracing --prompt "Explain machine learning"
```

### Command Line Arguments

- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
- `--model, -m` - The model to use (default: "gemma-2-2b")
- `--cpu` - Run on CPU rather than GPU
- `--temperature, -t` - Sampling temperature (optional)
- `--top-p` - Nucleus sampling probability cutoff (optional)
- `--seed` - Random seed (default: 299792458)
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
- `--model-id` - Custom model ID from HuggingFace Hub
- `--revision` - Model revision (default: "main")
- `--use-flash-attn` - Use flash attention
- `--repeat-penalty` - Repetition penalty (default: 1.1)
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
- `--dtype` - Data type (f16, bf16, f32)
- `--tracing` - Enable performance tracing

## Examples

### Text Generation
```bash
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
```

### Code Generation
```bash
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
```

### Creative Writing
```bash
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
```

### Chat with Gemma 3 (Instruct format)
```bash
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
```

## Performance Notes

- GPU acceleration is automatically detected and used when available
- BF16 precision is used on CUDA for better performance
- F32 precision is used on CPU
- Flash attention can be enabled with `--use-flash-attn` for supported models
- Model files are cached locally after first download

## Requirements

- Rust 1.70+
- CUDA toolkit (for CUDA support)
- Metal (automatically available on macOS)
- Internet connection for first-time model download

crates/gemma-runner/src/gemma_api.rs (new file, 389 lines)
@@ -0,0 +1,389 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use clap::ValueEnum;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};

// Removed gemma_cli import as it's not needed for the API
use candle_core::{utils, DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use tokenizers::Tokenizer;

use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
pub enum WhichModel {
    #[value(name = "gemma-2b")]
    Base2B,
    #[value(name = "gemma-7b")]
    Base7B,
    #[value(name = "gemma-2b-it")]
    Instruct2B,
    #[value(name = "gemma-7b-it")]
    Instruct7B,
    #[value(name = "gemma-1.1-2b-it")]
    InstructV1_1_2B,
    #[value(name = "gemma-1.1-7b-it")]
    InstructV1_1_7B,
    #[value(name = "codegemma-2b")]
    CodeBase2B,
    #[value(name = "codegemma-7b")]
    CodeBase7B,
    #[value(name = "codegemma-2b-it")]
    CodeInstruct2B,
    #[value(name = "codegemma-7b-it")]
    CodeInstruct7B,
    #[value(name = "gemma-2-2b")]
    BaseV2_2B,
    #[value(name = "gemma-2-2b-it")]
    InstructV2_2B,
    #[value(name = "gemma-2-9b")]
    BaseV2_9B,
    #[value(name = "gemma-2-9b-it")]
    InstructV2_9B,
    #[value(name = "gemma-3-1b")]
    BaseV3_1B,
    #[value(name = "gemma-3-1b-it")]
    InstructV3_1B,
}

enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}

impl Model {
    fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
        match self {
            Self::V1(m) => m.forward(input_ids, pos),
            Self::V2(m) => m.forward(input_ids, pos),
            Self::V3(m) => m.forward(input_ids, pos),
        }
    }
}

pub struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    /// Stream-only generation: sends freshly generated token strings over `tx`.
    /// (Does not send the prompt tokens; only newly generated model tokens.)
    fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
        self.tokenizer.clear();

        // Encode prompt (context only; do not emit prompt tokens to the stream).
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        // Warm the tokenizer's internal state with prompt tokens (so merges are correct),
        // but do not send them to the receiver.
        for &t in tokens.iter() {
            let _ = self.tokenizer.next_token(t)?;
        }
        // Make sure stdout isn't holding anything (if caller also prints).
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;

        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };
        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
                eos_token
            }
        };

        let start_gen = std::time::Instant::now();

        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];

            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;

            if next_token == eos_token || next_token == eot_token {
                break;
            }

            if let Some(t) = self.tokenizer.next_token(next_token)? {
                // Best-effort send; ignore if receiver dropped.
                let _ = tx.send(Ok(t));
            }
        }

        let _dt = start_gen.elapsed();

        // Flush any remaining buffered bytes as one final chunk.
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            let _ = tx.send(Ok(rest));
        }

        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct GemmaInferenceConfig {
    pub tracing: bool,
    pub prompt: String,
    pub model: WhichModel,
    pub cpu: bool,
    pub dtype: Option<String>,
    pub model_id: Option<String>,
    pub revision: String,
    pub use_flash_attn: bool,
    pub seed: u64,
    pub temperature: f64,
    pub top_p: Option<f64>,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
    pub max_tokens: usize,
}

impl Default for GemmaInferenceConfig {
    fn default() -> Self {
        Self {
            tracing: false,
            prompt: "Hello".to_string(),
            model: WhichModel::InstructV2_2B,
            cpu: false,
            dtype: None,
            model_id: None,
            revision: "main".to_string(),
            use_flash_attn: false,
            seed: 299792458,
            temperature: 0.8,
            top_p: None,
            repeat_penalty: 1.1,
            repeat_last_n: 128,
            max_tokens: 100,
        }
    }
}

// Removed From<Args> implementation as Args is not available and not needed for API usage

/// Builds the model and returns a channel that streams generated token strings.
/// If model setup fails, the `Result` is returned immediately.
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let _guard = if cfg.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        utils::with_avx(),
        utils::with_neon(),
        utils::with_simd128(),
        utils::with_f16c()
    );

    let device = device(cfg.cpu)?;
    println!("Device: {:?}", device);

    let dtype = match cfg.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
        None => {
            if device.is_cuda() {
                DType::BF16
            } else {
                DType::F16
            }
        }
    };
    println!("Using dtype: {:?}", dtype);

    let start = std::time::Instant::now();
    let api = Api::new()?;

    let model_id = cfg.model_id.unwrap_or_else(|| {
        match cfg.model {
            WhichModel::Base2B => "google/gemma-2b",
            WhichModel::Base7B => "google/gemma-7b",
            WhichModel::Instruct2B => "google/gemma-2b-it",
            WhichModel::Instruct7B => "google/gemma-7b-it",
            WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
            WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
            WhichModel::CodeBase2B => "google/codegemma-2b",
            WhichModel::CodeBase7B => "google/codegemma-7b",
            WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
            WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
            WhichModel::BaseV2_2B => "google/gemma-2-2b",
            WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
            WhichModel::BaseV2_9B => "google/gemma-2-9b",
            WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
            WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
            WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
        }
        .to_string()
    });

    println!("Loading model: {}", &model_id);

    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let config_filename = repo.get("config.json")?;
    let filenames = match cfg.model {
        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
        _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("Retrieved files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };

    let model: Model = match cfg.model {
        WhichModel::Base2B
        | WhichModel::Base7B
        | WhichModel::Instruct2B
        | WhichModel::Instruct7B
        | WhichModel::InstructV1_1_2B
        | WhichModel::InstructV1_1_7B
        | WhichModel::CodeBase2B
        | WhichModel::CodeBase7B
        | WhichModel::CodeInstruct2B
        | WhichModel::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
            Model::V1(model)
        }
        WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
            Model::V2(model)
        }
        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
            Model::V3(model)
        }
    };
    println!("Loaded model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        cfg.seed,
        cfg.temperature.into(),
        cfg.top_p,
        cfg.repeat_penalty,
        cfg.repeat_last_n,
        &device,
    );

    let prompt = match cfg.model {
        WhichModel::InstructV3_1B => {
            format!(
                "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
                cfg.prompt
            )
        }
        _ => cfg.prompt,
    };

    println!("Starting inference...");

    // Create the channel after successful setup.
    let (tx, rx) = mpsc::channel::<Result<String>>();

    // Spawn generation thread; send tokens to the channel.
    thread::spawn(move || {
        // If generation fails, forward the error once.
        if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
            let _ = tx.send(Err(e));
        }
        // Channel closes when tx is dropped.
    });

    Ok(rx)
}

crates/gemma-runner/src/gemma_cli.rs (new file, 97 lines)
@@ -0,0 +1,97 @@
use std::io::Write;
use clap::Parser;
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};

#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
pub struct Args {
    /// The prompt to generate text from
    #[arg(short, long, default_value = "The capital of France is")]
    pub(crate) prompt: String,

    /// The model to use
    #[arg(short, long, default_value = "gemma-2-2b")]
    pub(crate) model: WhichModel,

    /// Run on CPU rather than GPU
    #[arg(long)]
    pub(crate) cpu: bool,

    /// The temperature used to generate samples
    #[arg(short, long)]
    pub(crate) temperature: Option<f64>,

    /// Nucleus sampling probability cutoff
    #[arg(long)]
    pub(crate) top_p: Option<f64>,

    /// The seed to use when generating random samples
    #[arg(long, default_value_t = 299792458)]
    pub(crate) seed: u64,

    /// The length of the sample to generate (in tokens)
    #[arg(short = 'n', long, default_value_t = 100)]
    pub(crate) max_tokens: usize,

    /// Use different dtype than default
    #[arg(long)]
    pub(crate) dtype: Option<String>,

    /// Custom model ID from HuggingFace Hub
    #[arg(long)]
    pub(crate) model_id: Option<String>,

    /// Model revision
    #[arg(long, default_value = "main")]
    pub(crate) revision: String,

    /// Use flash attention
    #[arg(long)]
    pub(crate) use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty
    #[arg(long, default_value_t = 1.1)]
    pub(crate) repeat_penalty: f32,

    /// The context size to consider for the repeat penalty
    #[arg(long, default_value_t = 64)]
    pub(crate) repeat_last_n: usize,

    /// Enable tracing
    #[arg(long)]
    pub(crate) tracing: bool,
}

pub fn run_cli() -> anyhow::Result<()> {
    let args = Args::parse();
    let cfg = GemmaInferenceConfig {
        tracing: args.tracing,
        prompt: args.prompt,
        model: args.model,
        cpu: args.cpu,
        dtype: args.dtype,
        model_id: args.model_id,
        revision: args.revision,
        use_flash_attn: args.use_flash_attn,
        seed: args.seed,
        temperature: args.temperature.unwrap_or(0.8),
        top_p: args.top_p,
        repeat_penalty: args.repeat_penalty,
        repeat_last_n: args.repeat_last_n,
        max_tokens: args.max_tokens,
    };
    let rx = run_gemma_api(cfg)?;
    for msg in rx {
        match msg {
            Ok(tok) => {
                print!("{tok}");
                let _ = std::io::stdout().flush(); // <- force it out now
            }
            Err(e) => {
                eprintln!("generation error: {e}");
                break;
            }
        }
    }
    Ok(())
}

crates/gemma-runner/src/lib.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
pub mod gemma_api;

pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
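
The re-exports above are the crate's public streaming API. As an illustrative sketch only (not part of this commit; it assumes a `gemma-runner` path dependency plus the `anyhow` crate, and mirrors the loop in `gemma_cli.rs` below), a consumer from another crate could look like this:

```rust
use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    // Configure a small instruct model; every other field keeps its default.
    let cfg = GemmaInferenceConfig {
        prompt: "Explain ownership in Rust in one paragraph.".to_string(),
        model: WhichModel::InstructV3_1B,
        max_tokens: 128,
        ..Default::default()
    };
    // run_gemma_api loads the model, then streams generated tokens
    // over an mpsc channel from a background thread.
    let rx = run_gemma_api(cfg)?;
    for tok in rx {
        print!("{}", tok?);
    }
    Ok(())
}
```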

crates/gemma-runner/src/main.rs (new file, 17 lines)
@@ -0,0 +1,17 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_cli;
mod gemma_api;

use anyhow::Error;
use clap::{Parser, ValueEnum};

use crate::gemma_cli::run_cli;
use std::io::Write;

/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
    run_cli()
}

crates/helm-chart-tool/Cargo.toml
@@ -3,8 +3,6 @@ name = "helm-chart-tool"
 version = "0.1.0"
 edition = "2021"
 
-[workspace]
-
 [[bin]]
 name = "helm-chart-tool"
 path = "src/main.rs"

crates/inference-engine/Cargo.toml
@@ -3,9 +3,16 @@ name = "inference-engine"
 version = "0.1.0"
 edition = "2021"
 
 
 [[bin]]
-name="cli"
-path = "src/cli_main.rs"
+name="gemma_inference"
+path = "src/gemma_inference.rs"
+required-features = ["bin"]
+
+[[bin]]
+name="llama_inference"
+path = "src/llama_inference.rs"
+required-features = ["bin"]
+
 
 [dependencies]
@@ -50,6 +57,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
|||||||
uuid = { version = "1.7.0", features = ["v4"] }
|
uuid = { version = "1.7.0", features = ["v4"] }
|
||||||
reborrow = "0.5.5"
|
reborrow = "0.5.5"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
|
gemma-runner = { path = "../gemma-runner" }
|
||||||
|
llama-runner = { path = "../llama-runner" }
|
||||||
|
|
||||||
# --- Add this section for conditional compilation ---
|
# --- Add this section for conditional compilation ---
|
||||||
[target.'cfg(target_os = "macos")'.dependencies]
|
[target.'cfg(target_os = "macos")'.dependencies]
|
||||||
@@ -83,6 +92,9 @@ tokio = "1.43.0"
|
|||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
bin = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[package.metadata.compose]
|
[package.metadata.compose]
|
||||||
|
@@ -1,72 +0,0 @@
use clap::Parser;
use crate::model::Which;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    pub cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    pub tracing: bool,

    /// Run in server mode with OpenAI compatible API
    #[arg(long)]
    pub server: bool,

    /// Port to use for the server
    #[arg(long, default_value_t = 3777)]
    pub port: u16,

    /// Prompt for text generation (not used in server mode)
    #[arg(long)]
    pub prompt: Option<String>,

    /// The temperature used to generate samples.
    #[arg(long)]
    pub temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    pub top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    pub seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    pub sample_len: usize,

    #[arg(long)]
    pub model_id: Option<String>,

    #[arg(long, default_value = "main")]
    pub revision: String,

    #[arg(long)]
    pub tokenizer_file: Option<String>,

    #[arg(long)]
    pub config_file: Option<String>,

    #[arg(long)]
    pub weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    pub repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    pub repeat_last_n: usize,

    /// The model to use.
    #[arg(long, default_value = "3-1b-it")]
    pub which: Which,

    #[arg(long)]
    pub use_flash_attn: bool,
}
@@ -1,912 +0,0 @@
|
|||||||
mod token_output_stream;
|
|
||||||
mod utilities_lib;
|
|
||||||
|
|
||||||
#[cfg(feature = "intel-mkl-src")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate-src")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
extern crate metal_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use axum::{
|
|
||||||
extract::State,
|
|
||||||
http::StatusCode,
|
|
||||||
response::IntoResponse,
|
|
||||||
routing::{get, post},
|
|
||||||
Json, Router,
|
|
||||||
};
|
|
||||||
use clap::Parser;
|
|
||||||
use either::Either;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
|
||||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
|
||||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
|
||||||
|
|
||||||
// OpenAI API compatible structs
|
|
||||||
|
|
||||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct MessageInnerContent(
|
|
||||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
impl ToSchema<'_> for MessageInnerContent {
|
|
||||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
|
||||||
(
|
|
||||||
"MessageInnerContent",
|
|
||||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
|
||||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
|
||||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
|
||||||
|
|
||||||
Schema::OneOf(
|
|
||||||
OneOfBuilder::new()
|
|
||||||
// Either::Left - simple string
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))
|
|
||||||
// Either::Right - object with string values
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new()
|
|
||||||
.schema_type(SchemaType::Object)
|
|
||||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))))
|
|
||||||
.build(),
|
|
||||||
))
|
|
||||||
.build(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Message content that can be either simple text or complex structured content
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct MessageContent(
|
|
||||||
#[serde(with = "either::serde_untagged")]
|
|
||||||
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
impl ToSchema<'_> for MessageContent {
|
|
||||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
|
||||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Function for MessageContent Schema generation to handle `Either`
|
|
||||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
|
||||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
|
||||||
|
|
||||||
Schema::OneOf(
|
|
||||||
OneOfBuilder::new()
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))
|
|
||||||
.item(Schema::Array(
|
|
||||||
ArrayBuilder::new()
|
|
||||||
.items(RefOr::T(Schema::Object(
|
|
||||||
ObjectBuilder::new()
|
|
||||||
.schema_type(SchemaType::Object)
|
|
||||||
.additional_properties(Some(RefOr::Ref(
|
|
||||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
|
||||||
)))
|
|
||||||
.build(),
|
|
||||||
)))
|
|
||||||
.build(),
|
|
||||||
))
|
|
||||||
.build(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Represents a single message in a conversation
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
pub struct Message {
|
|
||||||
/// The message content
|
|
||||||
pub content: Option<MessageContent>,
|
|
||||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
|
||||||
pub role: String,
|
|
||||||
pub name: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stop token configuration for generation
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum StopTokens {
|
|
||||||
/// Multiple possible stop sequences
|
|
||||||
Multi(Vec<String>),
|
|
||||||
/// Single stop sequence
|
|
||||||
Single(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_false() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_1usize() -> usize {
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_model() -> String {
|
|
||||||
"default".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion request following OpenAI's specification
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
|
||||||
pub messages: Vec<Message>,
|
|
||||||
#[schema(example = "gemma-3-1b-it")]
|
|
||||||
#[serde(default = "default_model")]
|
|
||||||
pub model: String,
|
|
||||||
#[serde(default = "default_false")]
|
|
||||||
#[schema(example = false)]
|
|
||||||
pub logprobs: bool,
|
|
||||||
#[schema(example = 256)]
|
|
||||||
pub max_tokens: Option<usize>,
|
|
||||||
#[serde(rename = "n")]
|
|
||||||
#[serde(default = "default_1usize")]
|
|
||||||
#[schema(example = 1)]
|
|
||||||
pub n_choices: usize,
|
|
||||||
#[schema(example = 0.7)]
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
#[schema(example = 0.9)]
|
|
||||||
pub top_p: Option<f64>,
|
|
||||||
#[schema(example = false)]
|
|
||||||
pub stream: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion choice
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionChoice {
|
|
||||||
pub index: usize,
|
|
||||||
pub message: Message,
|
|
||||||
pub finish_reason: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion response
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatCompletionChoice>,
|
|
||||||
pub usage: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Token usage information
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: usize,
|
|
||||||
pub completion_tokens: usize,
|
|
||||||
pub total_tokens: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Application state shared between handlers
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct AppState {
|
|
||||||
text_generation: Arc<Mutex<TextGeneration>>,
|
|
||||||
model_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Chat completions endpoint handler
|
|
||||||
async fn chat_completions(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Json(request): Json<ChatCompletionRequest>,
|
|
||||||
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
||||||
let mut prompt = String::new();
|
|
||||||
|
|
||||||
// Convert messages to a prompt string
|
|
||||||
for message in &request.messages {
|
|
||||||
let role = &message.role;
|
|
||||||
let content = match &message.content {
|
|
||||||
Some(content) => match &content.0 {
|
|
||||||
Either::Left(text) => text.clone(),
|
|
||||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
|
||||||
},
|
|
||||||
None => "".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Format based on role
|
|
||||||
match role.as_str() {
|
|
||||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
|
||||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
|
||||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
|
||||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the assistant prefix for the response
|
|
||||||
prompt.push_str("Assistant: ");
|
|
||||||
|
|
||||||
// Capture the output
|
|
||||||
let mut output = Vec::new();
|
|
||||||
{
|
|
||||||
let mut text_gen = state.text_generation.lock().await;
|
|
||||||
|
|
||||||
// Buffer to capture the output
|
|
||||||
let mut buffer = Vec::new();
|
|
||||||
|
|
||||||
// Run text generation
|
|
||||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
|
||||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
|
||||||
|
|
||||||
if let Err(e) = result {
|
|
||||||
return Err((
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
Json(serde_json::json!({
|
|
||||||
"error": {
|
|
||||||
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
|
|
||||||
"type": "unsupported_api"
|
|
||||||
}
|
|
||||||
})),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert buffer to string
|
|
||||||
if let Ok(text) = String::from_utf8(buffer) {
|
|
||||||
output.push(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create response
|
|
||||||
let response = ChatCompletionResponse {
|
|
||||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
|
||||||
object: "chat.completion".to_string(),
|
|
||||||
created: std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs(),
|
|
||||||
model: request.model,
|
|
||||||
choices: vec![ChatCompletionChoice {
|
|
||||||
index: 0,
|
|
||||||
message: Message {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: Some(MessageContent(Either::Left(output.join("")))),
|
|
||||||
name: None,
|
|
||||||
},
|
|
||||||
finish_reason: "stop".to_string(),
|
|
||||||
}],
|
|
||||||
usage: Usage {
|
|
||||||
prompt_tokens: prompt.len() / 4, // Rough estimate
|
|
||||||
completion_tokens: output.join("").len() / 4, // Rough estimate
|
|
||||||
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
// Return the response as JSON
|
|
||||||
Ok(Json(response))
|
|
||||||
}
|
|
||||||
|
|
||||||
use candle_core::{DType, Device, MetalDevice, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{Repo, RepoType, api::sync::Api};
|
|
||||||
use serde_json::json;
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use crate::token_output_stream::TokenOutputStream;
|
|
||||||
use crate::utilities_lib::device;
|
|
||||||
|
|
||||||
// Create the router with the chat completions endpoint
|
|
||||||
fn create_router(app_state: AppState) -> Router {
|
|
||||||
// CORS layer to allow requests from any origin
|
|
||||||
let cors = CorsLayer::new()
|
|
||||||
.allow_origin(Any)
|
|
||||||
.allow_methods(Any)
|
|
||||||
.allow_headers(Any);
|
|
||||||
|
|
||||||
Router::new()
|
|
||||||
// OpenAI compatible endpoints
|
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
|
||||||
// Add more endpoints as needed
|
|
||||||
.layer(cors)
|
|
||||||
.with_state(app_state)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
#[value(name = "2b")]
|
|
||||||
Base2B,
|
|
||||||
#[value(name = "7b")]
|
|
||||||
Base7B,
|
|
||||||
#[value(name = "2b-it")]
|
|
||||||
Instruct2B,
|
|
||||||
#[value(name = "7b-it")]
|
|
||||||
Instruct7B,
|
|
||||||
#[value(name = "1.1-2b-it")]
|
|
||||||
InstructV1_1_2B,
|
|
||||||
#[value(name = "1.1-7b-it")]
|
|
||||||
InstructV1_1_7B,
|
|
||||||
#[value(name = "code-2b")]
|
|
||||||
CodeBase2B,
|
|
||||||
#[value(name = "code-7b")]
|
|
||||||
CodeBase7B,
|
|
||||||
#[value(name = "code-2b-it")]
|
|
||||||
CodeInstruct2B,
|
|
||||||
#[value(name = "code-7b-it")]
|
|
||||||
CodeInstruct7B,
|
|
||||||
#[value(name = "2-2b")]
|
|
||||||
BaseV2_2B,
|
|
||||||
#[value(name = "2-2b-it")]
|
|
||||||
InstructV2_2B,
|
|
||||||
#[value(name = "2-9b")]
|
|
||||||
BaseV2_9B,
|
|
||||||
#[value(name = "2-9b-it")]
|
|
||||||
InstructV2_9B,
|
|
||||||
#[value(name = "3-1b")]
|
|
||||||
BaseV3_1B,
|
|
||||||
#[value(name = "3-1b-it")]
|
|
||||||
InstructV3_1B,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum Model {
|
|
||||||
V1(Model1),
|
|
||||||
V2(Model2),
|
|
||||||
V3(Model3),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
|
||||||
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
|
||||||
match self {
|
|
||||||
Self::V1(m) => m.forward(input_ids, pos),
|
|
||||||
Self::V2(m) => m.forward(input_ids, pos),
|
|
||||||
Self::V3(m) => m.forward(input_ids, pos),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run text generation and print to stdout
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the <eos> token"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => {
|
|
||||||
println!(
|
|
||||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
|
||||||
);
|
|
||||||
eos_token
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
|
|
||||||
// Manual implementation of repeat penalty to avoid type conflicts
|
|
||||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
|
||||||
|
|
||||||
for &token_id in &tokens[start_at..] {
|
|
||||||
let token_id = token_id as usize;
|
|
||||||
if token_id < logits_vec.len() {
|
|
||||||
let score = logits_vec[token_id];
|
|
||||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
|
||||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new tensor with the modified logits
|
|
||||||
let device = logits.device().clone();
|
|
||||||
let shape = logits.shape().clone();
|
|
||||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
|
||||||
new_logits.reshape(shape)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token || next_token == eot_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run text generation and write to a buffer
|
|
||||||
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
|
|
||||||
// Write prompt tokens to output
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
write!(output, "{}", t)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the <eos> token"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => {
|
|
||||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
|
||||||
eos_token
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Determine if we're using a Model3 (gemma-3) variant
|
|
||||||
let is_model3 = match &self.model {
|
|
||||||
Model::V3(_) => true,
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
|
|
||||||
// For Model3, we need to use a different approach
|
|
||||||
if is_model3 {
|
|
||||||
// For gemma-3 models, we'll generate one token at a time with the full context
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
|
|
||||||
// Initial generation with the full prompt
|
|
||||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
|
||||||
let mut logits = self.model.forward(&input, 0)?;
|
|
||||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
|
|
||||||
for _ in 0..sample_len {
|
|
||||||
// Apply repeat penalty if needed
|
|
||||||
let current_logits = if self.repeat_penalty == 1. {
|
|
||||||
logits.clone()
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
|
|
||||||
// Manual implementation of repeat penalty to avoid type conflicts
|
|
||||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
|
||||||
|
|
||||||
for &token_id in &tokens[start_at..] {
|
|
||||||
let token_id = token_id as usize;
|
|
||||||
if token_id < logits_vec.len() {
|
|
||||||
let score = logits_vec[token_id];
|
|
||||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
|
||||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new tensor with the modified logits
|
|
||||||
let device = logits.device().clone();
|
|
||||||
let shape = logits.shape().clone();
|
|
||||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
|
||||||
new_logits.reshape(shape)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(¤t_logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
|
|
||||||
if next_token == eos_token || next_token == eot_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
write!(output, "{}", t)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For the next iteration, just use the new token
|
|
||||||
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
|
||||||
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
|
||||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Standard approach for other models
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
|
|
||||||
// Manual implementation of repeat penalty to avoid type conflicts
|
|
||||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
|
||||||
|
|
||||||
for &token_id in &tokens[start_at..] {
|
|
||||||
let token_id = token_id as usize;
|
|
||||||
if token_id < logits_vec.len() {
|
|
||||||
let score = logits_vec[token_id];
|
|
||||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
|
||||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new tensor with the modified logits
|
|
||||||
let device = logits.device().clone();
|
|
||||||
let shape = logits.shape().clone();
|
|
||||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
|
||||||
new_logits.reshape(shape)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token || next_token == eot_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
write!(output, "{}", t)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write any remaining tokens
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
write!(output, "{}", rest)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// Run in server mode with OpenAI compatible API
|
|
||||||
#[arg(long)]
|
|
||||||
server: bool,
|
|
||||||
|
|
||||||
/// Port to use for the server
|
|
||||||
#[arg(long, default_value_t = 3777)]
|
|
||||||
port: u16,
|
|
||||||
|
|
||||||
/// Prompt for text generation (not used in server mode)
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: Option<String>,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
|
||||||
revision: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
|
|
||||||
/// The model to use.
|
|
||||||
#[arg(long, default_value = "3-1b-it")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle_core::utils::with_avx(),
|
|
||||||
candle_core::utils::with_neon(),
|
|
||||||
candle_core::utils::with_simd128(),
|
|
||||||
candle_core::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match &args.model_id {
|
|
||||||
Some(model_id) => model_id.to_string(),
|
|
||||||
None => match args.which {
|
|
||||||
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
|
||||||
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
|
||||||
Which::Base2B => "google/gemma-2b".to_string(),
|
|
||||||
Which::Base7B => "google/gemma-7b".to_string(),
|
|
||||||
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
|
||||||
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
|
||||||
Which::CodeBase2B => "google/codegemma-2b".to_string(),
|
|
||||||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
|
||||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
|
||||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
|
||||||
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
|
||||||
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
|
||||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
|
||||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
|
||||||
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
|
||||||
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
model_id.clone(),
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision,
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let config_filename = match args.config_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => match args.which {
|
|
||||||
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
|
||||||
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let initial_device = utilities_lib::device(args.cpu)?;
|
|
||||||
|
|
||||||
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
|
|
||||||
let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
|
|
||||||
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;
|
|
||||||
|
|
||||||
// Use CPU for V3 models on Metal due to missing implementations
|
|
||||||
let device = if is_v3_model && is_metal {
|
|
||||||
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
initial_device
|
|
||||||
};
|
|
||||||
|
|
||||||
let dtype = if device.is_cuda() {
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
|
|
||||||
// Use the selected device and dtype
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = match args.which {
|
|
||||||
Which::Base2B
|
|
||||||
| Which::Base7B
|
|
||||||
| Which::Instruct2B
|
|
||||||
| Which::Instruct7B
|
|
||||||
| Which::InstructV1_1_2B
|
|
||||||
| Which::InstructV1_1_7B
|
|
||||||
| Which::CodeBase2B
|
|
||||||
| Which::CodeBase7B
|
|
||||||
| Which::CodeInstruct2B
|
|
||||||
| Which::CodeInstruct7B => {
|
|
||||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
|
||||||
let model = Model1::new(args.use_flash_attn, &config, vb)?;
|
|
||||||
Model::V1(model)
|
|
||||||
}
|
|
||||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
|
||||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
|
||||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
|
||||||
Model::V2(model)
|
|
||||||
}
|
|
||||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
|
||||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
|
||||||
let model = Model3::new(args.use_flash_attn, &config, vb)?;
|
|
||||||
Model::V3(model)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
|
|
||||||
if args.server {
|
|
||||||
// Start the server
|
|
||||||
println!("Starting server on port {}", args.port);
|
|
||||||
|
|
||||||
// Create app state
|
|
||||||
let app_state = AppState {
|
|
||||||
text_generation: Arc::new(Mutex::new(pipeline)),
|
|
||||||
model_id,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create router
|
|
||||||
let app = create_router(app_state);
|
|
||||||
|
|
||||||
// Run the server
|
|
||||||
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
|
|
||||||
|
|
||||||
// Use tokio to run the server
|
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()?
|
|
||||||
.block_on(async {
|
|
||||||
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
// Run in CLI mode
|
|
||||||
if let Some(prompt_text) = &args.prompt {
|
|
||||||
let prompt = match args.which {
|
|
||||||
Which::Base2B
|
|
||||||
| Which::Base7B
|
|
||||||
| Which::Instruct2B
|
|
||||||
| Which::Instruct7B
|
|
||||||
| Which::InstructV1_1_2B
|
|
||||||
| Which::InstructV1_1_7B
|
|
||||||
| Which::CodeBase2B
|
|
||||||
| Which::CodeBase7B
|
|
||||||
| Which::CodeInstruct2B
|
|
||||||
| Which::CodeInstruct7B
|
|
||||||
| Which::BaseV2_2B
|
|
||||||
| Which::InstructV2_2B
|
|
||||||
| Which::BaseV2_9B
|
|
||||||
| Which::InstructV2_9B
|
|
||||||
| Which::BaseV3_1B => prompt_text.clone(),
|
|
||||||
Which::InstructV3_1B => {
|
|
||||||
format!(
|
|
||||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
|
||||||
prompt_text
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut pipeline = pipeline;
|
|
||||||
pipeline.run(&prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
crates/inference-engine/src/inference.rs (new file, 33 lines)
@@ -0,0 +1,33 @@
use anyhow::Result;
use candle_core::Tensor;

/// ModelInference trait defines the common interface for model inference operations
///
/// This trait serves as an abstraction for different model implementations (Gemma and Llama)
/// to provide a unified interface for the inference engine.
pub trait ModelInference {
    /// Perform model inference for the given input tensor starting at the specified position
    ///
    /// # Arguments
    ///
    /// * `input_ids` - The input tensor containing token IDs
    /// * `pos` - The position to start generation from
    ///
    /// # Returns
    ///
    /// A tensor containing the logits for the next token prediction
    fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor>;

    /// Reset the model's internal state, if applicable
    ///
    /// This method can be used to clear any cached state between inference requests
    fn reset_state(&mut self) -> Result<()>;

    /// Get the model type name
    ///
    /// Returns a string identifier for the model type (e.g., "Gemma", "Llama")
    fn model_type(&self) -> &'static str;
}

/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
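The commit adds the trait but no implementor is visible in this section. A sketch of how an adapter over the crate's existing Model enum might satisfy it; the adapter type and its no-op reset are assumptions, only the trait and the enum come from this diff.

use anyhow::Result;
use candle_core::Tensor;

use crate::{Model, ModelInference};

// Hypothetical adapter: wraps the Gemma/Llama `Model` enum from model.rs
// behind the new trait so callers stay model-agnostic.
pub struct CandleModelAdapter {
    model: Model,
}

impl ModelInference for CandleModelAdapter {
    fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor> {
        // Delegate to the enum's forward; `?` converts the candle error into anyhow's.
        Ok(self.model.forward(input_ids, pos)?)
    }

    fn reset_state(&mut self) -> Result<()> {
        // The wrapped candle models manage their KV caches internally; this
        // sketch leaves reset as a no-op.
        Ok(())
    }

    fn model_type(&self) -> &'static str {
        match self.model {
            Model::Llama(_) => "Llama",
            _ => "Gemma",
        }
    }
}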
@@ -4,14 +4,16 @@ pub mod model;
 pub mod text_generation;
 pub mod utilities_lib;
 pub mod openai_types;
-pub mod cli;
+// pub mod cli;
 pub mod server;
+pub mod inference;
 
 // Re-export key components for easier access
 pub use model::{Model, Which};
 pub use text_generation::TextGeneration;
 pub use token_output_stream::TokenOutputStream;
 pub use server::{AppState, create_router};
+pub use inference::ModelInference;
 
 use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -2,6 +2,7 @@
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
+use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum Which {
@@ -37,12 +38,17 @@ pub enum Which {
     BaseV3_1B,
     #[value(name = "3-1b-it")]
     InstructV3_1B,
+    #[value(name = "llama-3.2-1b-it")]
+    LlamaInstruct3_2_1B,
+    #[value(name = "llama-3.2-3b-it")]
+    LlamaInstruct3_2_3B,
 }
 
 pub enum Model {
     V1(Model1),
     V2(Model2),
     V3(Model3),
+    Llama(LlamaModel),
 }
 
 impl Model {
@@ -51,6 +57,7 @@ impl Model {
             Self::V1(m) => m.forward(input_ids, pos),
             Self::V2(m) => m.forward(input_ids, pos),
             Self::V3(m) => m.forward(input_ids, pos),
+            Self::Llama(m) => m.forward(input_ids, pos),
         }
     }
 }
@@ -74,6 +81,8 @@ impl Which {
             Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
             Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
             Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
+            Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
+            Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
         }
     }
@@ -87,4 +96,8 @@ impl Which {
     pub fn is_v3_model(&self) -> bool {
         matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
     }
+
+    pub fn is_llama_model(&self) -> bool {
+        matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
+    }
 }
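One practical consequence of folding the Llama variants into Which is that callers can branch on the new predicate instead of string-matching model IDs. An illustrative helper, not part of the commit; the inference_engine crate name is assumed from the package name.

use inference_engine::Which;

// Hypothetical dispatch: pick the runner crate for a given model selection.
fn runner_crate(which: Which) -> &'static str {
    if which.is_llama_model() {
        "llama-runner"
    } else {
        "gemma-runner"
    }
}

fn main() {
    assert_eq!(runner_crate(Which::LlamaInstruct3_2_1B), "llama-runner");
    assert_eq!(runner_crate(Which::InstructV3_1B), "gemma-runner");
}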
@@ -5,299 +5,80 @@ use axum::{
     routing::{get, post},
     Json, Router,
 };
-use candle_core::DType;
-use candle_nn::VarBuilder;
 use futures_util::stream::{self, Stream};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use std::convert::Infallible;
-use std::{path::PathBuf, sync::Arc};
+use std::sync::Arc;
 use tokio::sync::{Mutex, mpsc};
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
 
 use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
-use crate::text_generation::TextGeneration;
-use crate::{utilities_lib, Model as GemmaModel, Which};
+use crate::Which;
 use either::Either;
-use hf_hub::api::sync::{Api, ApiError};
-use hf_hub::{Repo, RepoType};
-use tokenizers::Tokenizer;
-
-use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
-use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
-use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
 use serde_json::Value;
+use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
+use llama_runner::{run_llama_inference, LlamaInferenceConfig};
 // -------------------------
 // Shared app state
 // -------------------------
 
+#[derive(Clone, Debug)]
+pub enum ModelType {
+    Gemma,
+    Llama,
+}
+
 #[derive(Clone)]
 pub struct AppState {
-    pub text_generation: Arc<Mutex<TextGeneration>>,
+    pub model_type: ModelType,
     pub model_id: String,
-    // Store build args to recreate TextGeneration when needed
-    pub build_args: PipelineArgs,
+    pub gemma_config: Option<GemmaInferenceConfig>,
+    pub llama_config: Option<LlamaInferenceConfig>,
 }
 
 impl Default for AppState {
     fn default() -> Self {
-        let args = PipelineArgs::default();
-        let text_generation = build_pipeline(args.clone());
+        let gemma_config = GemmaInferenceConfig {
+            model: gemma_runner::WhichModel::InstructV3_1B,
+            ..Default::default()
+        };
         Self {
-            text_generation: Arc::new(Mutex::new(text_generation)),
-            model_id: args.model_id.clone(),
-            build_args: args,
+            model_type: ModelType::Gemma,
+            model_id: "gemma-3-1b-it".to_string(),
+            gemma_config: Some(gemma_config),
+            llama_config: None,
         }
     }
 }
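The Default impl above always selects the Gemma path. For comparison, a Llama-backed state would presumably be assembled the same way; this is a sketch only, written as if inside server.rs, and it assumes LlamaInferenceConfig also implements Default (only its prompt and max_tokens fields are visible in this diff).

// Hypothetical constructor for the Llama path.
fn llama_app_state() -> AppState {
    AppState {
        model_type: ModelType::Llama,
        model_id: "llama-3.2-1b-it".to_string(),
        gemma_config: None,
        llama_config: Some(LlamaInferenceConfig::default()), // assumes a Default impl
    }
}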
 // -------------------------
-// Pipeline configuration
+// Helper functions
 // -------------------------
 
-#[derive(Debug, Clone)]
-pub struct PipelineArgs {
-    pub model_id: String,
-    pub which: Which,
-    pub revision: Option<String>,
-    pub tokenizer_path: Option<PathBuf>,
-    pub config_path: Option<PathBuf>,
-    pub weight_paths: Vec<PathBuf>,
-    pub use_flash_attn: bool,
-    pub force_cpu: bool,
-    pub seed: u64,
-    pub temperature: Option<f64>,
-    pub top_p: Option<f64>,
-    pub repeat_penalty: f32,
-    pub repeat_last_n: usize,
-}
-
-impl Default for PipelineArgs {
-    fn default() -> Self {
-        Self {
-            model_id: Which::InstructV3_1B.to_model_id().to_string(),
-            which: Which::InstructV3_1B,
-            revision: None,
-            tokenizer_path: None,
-            config_path: None,
-            weight_paths: Vec::new(),
-            use_flash_attn: false,
-            force_cpu: false,
-            seed: 299792458, // Speed of light in vacuum (m/s)
-            temperature: Some(0.8), // Good balance between creativity and coherence
-            top_p: Some(0.9), // Keep diverse but reasonable options
-            repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
-            repeat_last_n: 64, // Consider last 64 tokens for repetition
-        }
-    }
-}
-
 fn normalize_model_id(model_id: &str) -> String {
-    if model_id.contains('/') {
-        model_id.to_string()
-    } else {
-        format!("google/{}", model_id)
-    }
+    model_id.to_lowercase().replace("_", "-")
 }
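The replacement above makes normalize_model_id purely a canonicalizer (lowercase, underscores to hyphens) instead of guessing a google/ prefix. A small illustrative check, not part of the commit:

#[cfg(test)]
mod normalize_tests {
    use super::normalize_model_id;

    #[test]
    fn lowercases_and_hyphenates() {
        assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
        // Slashes are left untouched, so fully qualified repo ids still compare cleanly.
        assert_eq!(normalize_model_id("google/Gemma-3-1b-it"), "google/gemma-3-1b-it");
    }
}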
-fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
-    let repo = api.repo(Repo::with_revision(
-        model_id.to_string(),
-        RepoType::Model,
-        revision.to_string(),
-    ));
-    match repo.get("config.json") {
-        Ok(_) => Ok(()),
-        Err(e) => match e {
-            ApiError::RequestError(resp) => {
-                let error_str = resp.to_string();
-                if error_str.contains("404") {
-                    anyhow::bail!(
-                        "Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
-                    )
-                }
-                Err(anyhow::Error::new(ApiError::RequestError(resp)))
-            }
-            other => Err(anyhow::Error::new(other)),
-        },
-    }
-}
-
-// -------------------------
-// Pipeline builder
-// -------------------------
-
-pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
-    println!(
-        "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle_core::utils::with_avx(),
-        candle_core::utils::with_neon(),
-        candle_core::utils::with_simd128(),
-        candle_core::utils::with_f16c()
-    );
-
-    let start = std::time::Instant::now();
-    let api = Api::new().unwrap();
-    let revision = args.revision.as_deref().unwrap_or("main");
-
-    if args.model_id.trim().is_empty() {
-        panic!("No model ID specified.");
-    }
-    args.model_id = normalize_model_id(&args.model_id);
-
-    match ensure_repo_exists(&api, &args.model_id, revision) {
-        Ok(_) => {}
-        Err(e) => panic!("{}", e),
-    };
-
-    let repo = api.repo(Repo::with_revision(
-        args.model_id.clone(),
-        RepoType::Model,
-        revision.to_string(),
-    ));
-
-    let tokenizer_path = args
-        .tokenizer_path
-        .unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
-    let config_path = args
-        .config_path
-        .unwrap_or_else(|| repo.get("config.json").unwrap());
-
-    if !matches!(
-        args.which,
-        Which::Base2B
-            | Which::Base7B
-            | Which::Instruct2B
-            | Which::Instruct7B
-            | Which::InstructV1_1_2B
-            | Which::InstructV1_1_7B
-            | Which::CodeBase2B
-            | Which::CodeBase7B
-            | Which::CodeInstruct2B
-            | Which::CodeInstruct7B
-            | Which::BaseV2_2B
-            | Which::InstructV2_2B
-            | Which::BaseV2_9B
-            | Which::InstructV2_9B
-            | Which::BaseV3_1B
-            | Which::InstructV3_1B
-    ) {
-        if args.model_id.contains("gemma-2-2b-it") {
-            args.which = Which::InstructV2_2B;
-        } else if args.model_id.contains("gemma-3-1b-it") {
-            args.which = Which::InstructV3_1B;
-        } else if let Ok(file) = std::fs::File::open(config_path.clone()) {
-            if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
-                if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
-                    if model_type.contains("gemma3") {
-                        args.which = Which::InstructV3_1B;
-                    } else if model_type.contains("gemma2") {
-                        args.which = Which::InstructV2_2B;
-                    } else {
-                        args.which = Which::Instruct2B;
-                    }
-                }
-            }
-        }
-    }
-
-    let weight_paths = if !args.weight_paths.is_empty() {
-        args.weight_paths
-    } else {
-        match repo.get("model.safetensors") {
-            Ok(single) => vec![single],
-            Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
-                Ok(paths) => paths,
-                Err(e) => {
-                    panic!("Unable to locate model weights: {}", e);
-                }
-            },
-        }
-    };
-
-    println!("retrieved the files in {:?}", start.elapsed());
-
-    let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
-
-    let initial_device = utilities_lib::device(args.force_cpu).unwrap();
-    let is_v3_model = args.which.is_v3_model();
-    let is_metal = !initial_device.is_cpu()
-        && candle_core::utils::metal_is_available()
-        && !args.force_cpu;
-
-    let device = if is_v3_model && is_metal {
-        candle_core::Device::Cpu
-    } else {
-        initial_device
-    };
-
-    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
-
-    let model = match args.which {
-        Which::Base2B
-        | Which::Base7B
-        | Which::Instruct2B
-        | Which::Instruct7B
-        | Which::InstructV1_1_2B
-        | Which::InstructV1_1_7B
-        | Which::CodeBase2B
-        | Which::CodeBase7B
-        | Which::CodeInstruct2B
-        | Which::CodeInstruct7B => {
-            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
-            GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
-        }
-        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
-            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
-            GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
-        }
-        Which::BaseV3_1B | Which::InstructV3_1B => {
-            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
-            GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
-        }
-    };
-
-    TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        &device,
-    )
 }
 
 fn build_gemma_prompt(messages: &[Message]) -> String {
     let mut prompt = String::new();
-    let mut system_prompt: Option<String> = None;

     for message in messages {
-        let content = match &message.content {
-            Some(content) => match &content.0 {
-                Either::Left(text) => text.clone(),
-                Either::Right(_) => "".to_string(),
-            },
-            None => "".to_string(),
-        };
-
         match message.role.as_str() {
-            "system" => system_prompt = Some(content),
-            "user" => {
-                prompt.push_str("<start_of_turn>user\n");
-                if let Some(sys_prompt) = system_prompt.take() {
-                    prompt.push_str(&sys_prompt);
-                    prompt.push_str("\n\n");
+            "system" => {
+                if let Some(MessageContent(Either::Left(content))) = &message.content {
+                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
+                }
+            }
+            "user" => {
+                if let Some(MessageContent(Either::Left(content))) = &message.content {
+                    prompt.push_str(&format!("<start_of_turn>user\n{}<end_of_turn>\n", content));
                 }
-                prompt.push_str(&content);
-                prompt.push_str("<end_of_turn>\n");
             }
             "assistant" => {
-                prompt.push_str("<start_of_turn>model\n");
-                prompt.push_str(&content);
-                prompt.push_str("<end_of_turn>\n");
+                if let Some(MessageContent(Either::Left(content))) = &message.content {
+                    prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
+                }
             }
             _ => {}
         }
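For reference, a minimal sketch of the turn template the rewritten `build_gemma_prompt` emits. It uses plain `(role, text)` pairs instead of the server's `Message`/`MessageContent` types, so it illustrates the format rather than the handler code itself:

```rust
// Standalone sketch of the Gemma turn template produced above.
// Assumption: plain (role, text) pairs stand in for the server's request types.
fn gemma_prompt(turns: &[(&str, &str)]) -> String {
    let mut prompt = String::new();
    for &(role, text) in turns {
        // "assistant" turns are tagged "model"; "system" and "user" keep their names.
        let tag = if role == "assistant" { "model" } else { role };
        prompt.push_str(&format!("<start_of_turn>{tag}\n{text}<end_of_turn>\n"));
    }
    prompt
}

fn main() {
    let prompt = gemma_prompt(&[("system", "Be brief."), ("user", "Hi!")]);
    assert_eq!(
        prompt,
        "<start_of_turn>system\nBe brief.<end_of_turn>\n<start_of_turn>user\nHi!<end_of_turn>\n"
    );
    print!("{prompt}");
}
```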
@@ -325,14 +106,13 @@ pub async fn chat_completions_non_streaming_proxy(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
-    let prompt = build_gemma_prompt(&request.messages);

     // Enforce model selection behavior: reject if a different model is requested
-    let configured_model = state.build_args.model_id.clone();
+    let configured_model = state.model_id.clone();
     let requested_model = request.model.clone();
     if requested_model.to_lowercase() != "default" {
         let normalized_requested = normalize_model_id(&requested_model);
-        if normalized_requested != configured_model {
+        let normalized_configured = normalize_model_id(&configured_model);
+        if normalized_requested != normalized_configured {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
@@ -349,14 +129,71 @@ pub async fn chat_completions_non_streaming_proxy(
     }

     let model_id = state.model_id.clone();

-    let mut buffer = Vec::new();
-    {
-        let mut text_gen = state.text_generation.lock().await;
-        // Reset per-request state without rebuilding the whole pipeline
-        text_gen.reset_state();
        let max_tokens = request.max_tokens.unwrap_or(1000);
-        if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
+    // Build prompt based on model type
+    let prompt = match state.model_type {
+        ModelType::Gemma => build_gemma_prompt(&request.messages),
+        ModelType::Llama => {
+            // For Llama, just use the last user message for now
+            request.messages.last()
+                .and_then(|m| m.content.as_ref())
+                .and_then(|c| match c {
+                    MessageContent(Either::Left(text)) => Some(text.clone()),
+                    _ => None,
+                })
+                .unwrap_or_default()
+        }
+    };
+
+    // Get streaming receiver based on model type
+    let rx = match state.model_type {
+        ModelType::Gemma => {
+            if let Some(mut config) = state.gemma_config {
+                config.prompt = prompt.clone();
+                config.max_tokens = max_tokens;
+                run_gemma_api(config).map_err(|e| (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                    }))
+                ))?
+            } else {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": "Gemma configuration not available" }
+                    }))
+                ));
+            }
+        }
+        ModelType::Llama => {
+            if let Some(mut config) = state.llama_config {
+                config.prompt = prompt.clone();
+                config.max_tokens = max_tokens;
+                run_llama_inference(config).map_err(|e| (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Llama model: {}", e) }
+                    }))
+                ))?
+            } else {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": "Llama configuration not available" }
+                    }))
+                ));
+            }
+        }
+    };
+
+    // Collect all tokens from the stream
+    let mut completion = String::new();
+    while let Ok(token_result) = rx.recv() {
+        match token_result {
+            Ok(token) => completion.push_str(&token),
+            Err(e) => {
                 return Err((
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
@@ -365,18 +202,7 @@ pub async fn chat_completions_non_streaming_proxy(
                 ));
             }
         }

-    let completion = match String::from_utf8(buffer) {
-        Ok(s) => s,
-        Err(e) => {
-            return Err((
-                StatusCode::INTERNAL_SERVER_ERROR,
-                Json(serde_json::json!({
-                    "error": { "message": format!("UTF-8 conversion error: {}", e) }
-                })),
-            ));
-        }
     }
-    };

     let response = ChatCompletionResponse {
         id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
@@ -420,11 +246,12 @@ async fn handle_streaming_request(
     request: ChatCompletionRequest,
 ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
     // Validate requested model vs configured model
-    let configured_model = state.build_args.model_id.clone();
+    let configured_model = state.model_id.clone();
     let requested_model = request.model.clone();
     if requested_model.to_lowercase() != "default" {
         let normalized_requested = normalize_model_id(&requested_model);
-        if normalized_requested != configured_model {
+        let normalized_configured = normalize_model_id(&configured_model);
+        if normalized_requested != normalized_configured {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
@@ -447,9 +274,22 @@ async fn handle_streaming_request(
         .unwrap_or_default()
         .as_secs();
     let model_id = state.model_id.clone();
+    let max_tokens = request.max_tokens.unwrap_or(1000);

-    // Build prompt
-    let prompt = build_gemma_prompt(&request.messages);
+    // Build prompt based on model type
+    let prompt = match state.model_type {
+        ModelType::Gemma => build_gemma_prompt(&request.messages),
+        ModelType::Llama => {
+            // For Llama, just use the last user message for now
+            request.messages.last()
+                .and_then(|m| m.content.as_ref())
+                .and_then(|c| match c {
+                    MessageContent(Either::Left(text)) => Some(text.clone()),
+                    _ => None,
+                })
+                .unwrap_or_default()
+        }
+    };
     tracing::debug!("Formatted prompt: {}", prompt);

     // Channel for streaming SSE events
@@ -471,32 +311,78 @@ async fn handle_streaming_request(
         let _ = tx.send(Ok(Event::default().data(json)));
     }

-    // Spawn generation task that streams tokens as they are generated
-    let state_clone = state.clone();
-    let response_id_clone = response_id.clone();
-    tokio::spawn(async move {
-        let max_tokens = request.max_tokens.unwrap_or(1000);
-        let mut text_gen = state_clone.text_generation.lock().await;
-        text_gen.reset_state();
+    // Get streaming receiver based on model type
+    let model_rx = match state.model_type {
+        ModelType::Gemma => {
+            if let Some(mut config) = state.gemma_config {
+                config.prompt = prompt.clone();
+                config.max_tokens = max_tokens;
+                match run_gemma_api(config) {
+                    Ok(rx) => rx,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(serde_json::json!({
+                                "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                            }))
+                        ));
+                    }
+                }
+            } else {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": "Gemma configuration not available" }
+                    }))
+                ));
+            }
+        }
+        ModelType::Llama => {
+            if let Some(mut config) = state.llama_config {
+                config.prompt = prompt.clone();
+                config.max_tokens = max_tokens;
+                match run_llama_inference(config) {
+                    Ok(rx) => rx,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(serde_json::json!({
+                                "error": { "message": format!("Error initializing Llama model: {}", e) }
+                            }))
+                        ));
+                    }
+                }
+            } else {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": "Llama configuration not available" }
+                    }))
+                ));
+            }
+        }
+    };

-    // Stream tokens via callback with repetition detection
+    // Spawn task to receive tokens from model and forward as SSE events
+    let response_id_clone = response_id.clone();
+    let model_id_clone = model_id.clone();
+    tokio::spawn(async move {
+        // Stream tokens with repetition detection
         let mut recent_tokens = Vec::new();
         let mut repetition_count = 0;
-        const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
-        const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
+        const MAX_REPETITION_COUNT: usize = 5;
+        const REPETITION_WINDOW: usize = 8;

-        let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
-            // Debug log to verify token content
-            tracing::debug!("Streaming token: '{}'", token);

+        while let Ok(token_result) = model_rx.recv() {
+            match token_result {
+                Ok(token) => {
                     // Skip sending empty tokens
                     if token.is_empty() {
-                        tracing::debug!("Skipping empty token");
-                        return Ok(());
+                        continue;
                     }

                     // Add token to recent history for repetition detection
-                    recent_tokens.push(token.to_string());
+                    recent_tokens.push(token.clone());
                     if recent_tokens.len() > REPETITION_WINDOW {
                         recent_tokens.remove(0);
                     }
@@ -506,20 +392,16 @@ async fn handle_streaming_request(
                         let last_token = &recent_tokens[recent_tokens.len() - 1];
                         let second_last = &recent_tokens[recent_tokens.len() - 2];

-                        // Check if we're repeating the same token or pattern
-                        if last_token == second_last ||
-                           (last_token.trim() == "plus" && second_last.trim() == "plus") ||
-                           (recent_tokens.len() >= 6 &&
-                            recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
+                        if last_token == second_last {
                             repetition_count += 1;
                             tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);

                             if repetition_count >= MAX_REPETITION_COUNT {
                                 tracing::info!("Stopping generation due to excessive repetition");
-                                return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
+                                break;
                             }
                         } else {
-                            repetition_count = 0; // Reset counter if pattern breaks
+                            repetition_count = 0;
                         }
                     }

@@ -527,24 +409,23 @@ async fn handle_streaming_request(
                         id: response_id_clone.clone(),
                         object: "chat.completion.chunk".to_string(),
                         created,
-                        model: model_id.clone(),
+                        model: model_id_clone.clone(),
                         choices: vec![ChatCompletionChunkChoice {
                             index: 0,
-                            delta: Delta { role: None, content: Some(token.to_string()) },
+                            delta: Delta { role: None, content: Some(token) },
                             finish_reason: None,
                         }],
                     };

                     if let Ok(json) = serde_json::to_string(&chunk) {
-                        tracing::debug!("Sending chunk with content: '{}'", token);
                         let _ = tx.send(Ok(Event::default().data(json)));
                     }
-            Ok(())
-        }).await;
-
-        // Log result of generation
-        match result {
-            Ok(_) => tracing::debug!("Text generation completed successfully"),
-            Err(e) => tracing::info!("Text generation stopped: {}", e),
+                }
+                Err(e) => {
+                    tracing::info!("Text generation stopped: {}", e);
+                    break;
+                }
+            }
         }

         // Send final stop chunk and DONE marker
@@ -552,7 +433,7 @@ async fn handle_streaming_request(
             id: response_id_clone.clone(),
             object: "chat.completion.chunk".to_string(),
             created,
-            model: model_id.clone(),
+            model: model_id_clone.clone(),
             choices: vec![ChatCompletionChunkChoice {
                 index: 0,
                 delta: Delta { role: None, content: None },
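The repetition guard in the streaming handler above only tracks a short sliding window of recent tokens; a self-contained sketch of that check, separated from the SSE plumbing (the window size and threshold mirror the constants above, and this is an illustration rather than the handler itself):

```rust
// Sketch of the streaming repetition guard: keep the last REPETITION_WINDOW tokens
// and stop once the same token repeats back-to-back MAX_REPETITION_COUNT times.
const MAX_REPETITION_COUNT: usize = 5;
const REPETITION_WINDOW: usize = 8;

fn truncate_on_repetition(tokens: &[&str]) -> usize {
    let mut recent: Vec<&str> = Vec::new();
    let mut repetition_count = 0;
    for (i, &token) in tokens.iter().enumerate() {
        recent.push(token);
        if recent.len() > REPETITION_WINDOW {
            recent.remove(0);
        }
        if recent.len() >= 2 && recent[recent.len() - 1] == recent[recent.len() - 2] {
            repetition_count += 1;
            if repetition_count >= MAX_REPETITION_COUNT {
                return i; // stop streaming at this index
            }
        } else {
            repetition_count = 0;
        }
    }
    tokens.len()
}

fn main() {
    let tokens = ["The", " answer", " is", " 4", ".", ".", ".", ".", ".", "."];
    // The run of "." tokens trips the guard before the slice is exhausted.
    assert!(truncate_on_repetition(&tokens) < tokens.len());
}
```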
@@ -594,6 +475,7 @@ pub fn create_router(app_state: AppState) -> Router {
 pub async fn list_models() -> Json<ModelListResponse> {
     // Get all available model variants from the Which enum
     let models = vec![
+        // Gemma models
         Model {
             id: "gemma-2b".to_string(),
             object: "model".to_string(),
@@ -690,6 +572,73 @@ pub async fn list_models() -> Json<ModelListResponse> {
             created: 1686935002,
             owned_by: "google".to_string(),
         },
+        // Llama models
+        Model {
+            id: "llama-3.2-1b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "meta".to_string(),
+        },
+        Model {
+            id: "llama-3.2-1b-instruct".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "meta".to_string(),
+        },
+        Model {
+            id: "llama-3.2-3b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "meta".to_string(),
+        },
+        Model {
+            id: "llama-3.2-3b-instruct".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "meta".to_string(),
+        },
+        Model {
+            id: "smollm2-135m".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "smollm2-135m-instruct".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "smollm2-360m".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "smollm2-360m-instruct".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "smollm2-1.7b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "smollm2-1.7b-instruct".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "huggingface".to_string(),
+        },
+        Model {
+            id: "tinyllama-1.1b-chat".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "tinyllama".to_string(),
+        },
     ];

     Json(ModelListResponse {

24 crates/llama-runner/Cargo.toml Normal file
@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

188 crates/llama-runner/README.md Normal file
@@ -0,0 +1,188 @@
# Llama Runner

A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.

## Features

- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
- ⚡ **Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults

## Supported Models

| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |

Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).

## Installation

```bash
# Clone the repository
git clone <repository-url>
cd llama-runner

# Build with GPU acceleration (recommended)
cargo build --release --features metal  # macOS
cargo build --release --features cuda   # Linux/Windows with NVIDIA GPU

# CPU-only build
cargo build --release
```

## Quick Start

```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"

# Specify a model and parameters
cargo run --features metal -- \
  --prompt "Write a short story about space exploration" \
  --model smollm2-360m \
  --max-tokens 100 \
  --temperature 0.8

# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```

## Usage Examples

### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"

# Creative writing with higher temperature
cargo run --features metal -- \
  --prompt "Once upon a time" \
  --temperature 1.0 \
  --max-tokens 200
```

### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
  --prompt "Explain artificial intelligence" \
  --top-k 40 \
  --top-p 0.9 \
  --temperature 0.7

# Reduce repetition
cargo run --features metal -- \
  --prompt "List the benefits of renewable energy" \
  --repeat-penalty 1.2 \
  --repeat-last-n 64
```

### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
  --prompt "Quick test" \
  --model smollm2-135m

# Better quality with larger model
cargo run --features metal -- \
  --prompt "Explain quantum physics" \
  --model llama-3.2-1b \
  --max-tokens 150
```

## Command-Line Options

| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `llama-3.2-1b-instruct` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |

## Performance

Typical performance on Apple M2 with Metal acceleration:

| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |

## Requirements

- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows

## GPU Support

### macOS (Metal)
```bash
cargo run --features metal -- [options]
```

### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```

### CPU Only
```bash
cargo run -- --cpu [options]
```

## Model Downloads

Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:

- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes

## Troubleshooting

### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model

### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues

### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`

## Contributing

Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.

## License

MIT License - see LICENSE file for details.

8 crates/llama-runner/src/lib.rs Normal file
@@ -0,0 +1,8 @@
pub mod llama_api;

use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

337 crates/llama-runner/src/llama_api.rs Normal file
@@ -0,0 +1,337 @@
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
    #[value(name = "llama-3.2-1b")]
    #[default]
    Llama32_1B,
    #[value(name = "llama-3.2-1b-instruct")]
    Llama32_1BInstruct,
    #[value(name = "llama-3.2-3b")]
    Llama32_3B,
    #[value(name = "llama-3.2-3b-instruct")]
    Llama32_3BInstruct,
    #[value(name = "smollm2-135m")]
    SmolLM2_135M,
    #[value(name = "smollm2-135m-instruct")]
    SmolLM2_135MInstruct,
    #[value(name = "smollm2-360m")]
    SmolLM2_360M,
    #[value(name = "smollm2-360m-instruct")]
    SmolLM2_360MInstruct,
    #[value(name = "smollm2-1.7b")]
    SmolLM2_1_7B,
    #[value(name = "smollm2-1.7b-instruct")]
    SmolLM2_1_7BInstruct,
    #[value(name = "tinyllama-1.1b-chat")]
    TinyLlama1_1BChat,
}

#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
    pub prompt: String,

    pub model: WhichModel,
    pub cpu: bool,
    pub temperature: f64,
    pub top_p: Option<f64>,
    pub top_k: Option<usize>,
    pub seed: u64,
    pub max_tokens: usize,
    pub no_kv_cache: bool,
    pub dtype: Option<String>,
    pub model_id: Option<String>,
    pub revision: Option<String>,
    pub use_flash_attn: bool,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
}

impl Default for LlamaInferenceConfig {
    fn default() -> Self {
        Self {
            // Leave prompt empty by default; let call sites set it.
            prompt: String::new(),

            // Keep your existing model choice; swap at call-site if needed.
            model: WhichModel::Llama32_1BInstruct,

            // Prefer GPU if available.
            cpu: false,

            // Sampling: balanced + stable
            temperature: 0.7,
            top_p: Some(0.95),
            top_k: Some(50),

            // Reproducible by default; override for variability.
            seed: 42,

            // Don’t run unbounded generations.
            max_tokens: 512,

            // Performance flags
            no_kv_cache: false,   // keep cache ON for speed
            use_flash_attn: true, // great speed boost if supported

            // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
            dtype: Some("bf16".to_string()),

            // Optional model source pinning (None = app defaults)
            model_id: None,
            revision: None,

            // Anti-repeat heuristics
            repeat_penalty: 1.15,
            repeat_last_n: 128,
        }
    }
}

fn device(cpu: bool) -> anyhow::Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}

fn hub_load_safetensors(
    api: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
    let json_file = api.get(json_file)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: serde_json::Value = serde_json::from_reader(&json_file)?;
    let weight_map = match json.get("weight_map") {
        None => bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file.to_string());
        }
    }
    let safetensors_files = safetensors_files
        .iter()
        .map(|v| api.get(v))
        .collect::<anyhow::Result<Vec<_>, _>>()?;
    Ok(safetensors_files)
}

pub fn run_llama_inference(
    cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
    // ---- Device & dtype -----------------------------------------------------
    let device = device(cfg.cpu)?;
    println!("Device: {:?}", device);

    let dtype = match cfg.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    println!("Using dtype: {:?}", dtype);

    // ---- Load model & tokenizer --------------------------------------------
    let (llama, tokenizer, mut cache) = {
        let api = Api::new()?;
        let model_id = cfg.model_id.clone().unwrap_or_else(|| {
            match cfg.model {
                WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
                WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
                WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
                WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
                WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
                WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
                WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
                WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
                WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
                WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
                WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            }
            .to_string()
        });
        println!("Loading model: {}", model_id);
        let revision = cfg.revision.clone().unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(cfg.use_flash_attn);

        let filenames = match cfg.model {
            WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
                hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
            _ => vec![api.get("model.safetensors")?],
        };

        let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let llama = Llama::load(vb, &config)?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        (llama, tokenizer, cache)
    };

    // ---- Prepare prompt & sampler ------------------------------------------
    let eos_token_id = tokenizer
        .token_to_id(EOS_TOKEN)
        .map(model::LlamaEosToks::Single);

    let mut tokens = tokenizer
        .encode(cfg.prompt.as_str(), true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    println!("Starting inference...");

    let mut logits_processor = {
        let temperature = cfg.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (cfg.top_k, cfg.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(cfg.seed, sampling)
    };

    // Channel for streaming decoded fragments to the caller.
    let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();

    // ---- Spawn generation thread -------------------------------------------
    std::thread::spawn(move || {
        let start_gen = std::time::Instant::now();
        let mut index_pos = 0usize;
        let mut token_generated = 0usize;

        for index in 0..cfg.max_tokens {
            // Use KV-cache for single-token step after the first pass.
            let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };

            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
                Ok(t) => t,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            let logits = match llama.forward(&input, context_index, &mut cache) {
                Ok(l) => l,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };
            let logits = match logits.squeeze(0) {
                Ok(l) => l,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            let logits = if cfg.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
                match candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    cfg.repeat_penalty,
                    &tokens[start_at..],
                ) {
                    Ok(l) => l,
                    Err(e) => {
                        let _ = tx.send(Err(e.into()));
                        break;
                    }
                }
            };

            index_pos += ctxt.len();

            let next_token = match logits_processor.sample(&logits) {
                Ok(t) => t,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            token_generated += 1;
            tokens.push(next_token);

            // Early stop on EOS.
            let stop = match eos_token_id {
                Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
                Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
                None => false,
            };
            if stop {
                break;
            }

            // Decode this token's text and stream it out.
            match tokenizer.decode(&[next_token], false) {
                Ok(text) => {
                    if !text.is_empty() {
                        // Best-effort send; if receiver is gone, just stop.
                        if tx.send(Ok(text)).is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
                    break;
                }
            }
        }

        // Optional: final stats as a debug line (not sent through the stream).
        let dt = start_gen.elapsed();
        eprintln!(
            "[llama-runner] {} tokens generated ({:.2} tokens/s)",
            token_generated,
            token_generated as f64 / dt.as_secs_f64(),
        );
        // Dropping tx closes the stream.
    });

    Ok(rx)
}

109 crates/llama-runner/src/llama_cli.rs Normal file
@@ -0,0 +1,109 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;

#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
    /// The prompt to generate text from
    #[arg(short, long, default_value = "The capital of France is")]
    prompt: String,

    /// The model to use
    #[arg(short, long, default_value = "llama-3.2-1b-instruct")]
    model: WhichModel,

    /// Run on CPU rather than GPU
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples
    #[arg(short, long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens)
    #[arg(short = 'n', long, default_value_t = 100)]
    max_tokens: usize,

    /// Disable the key-value cache
    #[arg(long)]
    no_kv_cache: bool,

    /// Use different dtype than f16
    #[arg(long)]
    dtype: Option<String>,

    /// Custom model ID from HuggingFace Hub
    #[arg(long)]
    model_id: Option<String>,

    /// Model revision
    #[arg(long)]
    revision: Option<String>,

    /// Use flash attention
    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

impl Into<LlamaInferenceConfig> for Args {
    fn into(self) -> LlamaInferenceConfig {
        LlamaInferenceConfig {
            prompt: self.prompt,
            model: self.model,
            cpu: self.cpu,
            temperature: self.temperature,
            top_p: self.top_p,
            top_k: self.top_k,
            seed: self.seed,
            max_tokens: self.max_tokens,
            no_kv_cache: self.no_kv_cache,
            dtype: self.dtype,
            model_id: self.model_id,
            revision: self.revision,
            use_flash_attn: self.use_flash_attn,
            repeat_penalty: self.repeat_penalty,
            repeat_last_n: self.repeat_last_n,
        }
    }
}

pub fn run_cli() -> anyhow::Result<()> {
    let args = Args::parse();
    let cfg = args.into();
    let rx = run_llama_inference(cfg)?;
    for msg in rx {
        match msg {
            Ok(tok) => {
                print!("{tok}");
                let _ = std::io::stdout().flush(); // <- force it out now
            }
            Err(e) => {
                eprintln!("generation error: {e}");
                break;
            }
        }
    }
    Ok(())
}

20 crates/llama-runner/src/main.rs Normal file
@@ -0,0 +1,20 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_cli;
mod llama_api;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use std::io::Write;

use crate::llama_cli::run_cli;

const EOS_TOKEN: &str = "</s>";

fn main() -> Result<()> {
    run_cli()
}
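Taken together, the new crate's public surface (re-exported from `lib.rs`) reduces to a single call that returns a channel of decoded fragments. A minimal consumer sketch, assuming `llama-runner` and `anyhow` are added as workspace dependencies; it mirrors what `run_cli()` and the server handlers do with the receiver:

```rust
// Minimal consumer sketch (assumes llama-runner is a path/workspace dependency).
use llama_runner::{run_llama_inference, LlamaInferenceConfig};

fn main() -> anyhow::Result<()> {
    let cfg = LlamaInferenceConfig {
        prompt: "Explain KV caching in one sentence.".to_string(),
        max_tokens: 64,
        ..Default::default()
    };
    // The receiver yields Result<String>; iteration ends when the generation
    // thread drops the sender (EOS, max_tokens, or an error).
    for fragment in run_llama_inference(cfg)? {
        print!("{}", fragment?);
    }
    Ok(())
}
```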
@@ -67,18 +67,7 @@ async fn main() {
     let embeddings_router = embeddings_engine::create_embeddings_router();

     // Create AppState with correct model configuration
-    use inference_engine::Which;
-    use inference_engine::server::{PipelineArgs, build_pipeline};
-    let mut pipeline_args = PipelineArgs::default();
-    pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
-    pipeline_args.which = Which::InstructV3_1B;
-
-    let text_generation = build_pipeline(pipeline_args.clone());
-    let app_state = AppState {
-        text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
-        model_id: "google/gemma-3-1b-it".to_string(),
-        build_args: pipeline_args,
-    };
+    let app_state = AppState::default();

     // Get the inference router directly from the inference engine
     let inference_router = inference_engine::create_router(app_state);
@@ -22,7 +22,7 @@ The Predict-Otron-9000 is a comprehensive multi-service AI platform built around
 graph TB
     subgraph "Core Components"
         A[Main Server<br/>predict-otron-9000]
-        B[Inference Engine<br/>Gemma via Candle]
+        B[Inference Engine<br/>Gemma/Llama via Candle]
         C[Embeddings Engine<br/>FastEmbed]
         D[Web Frontend<br/>Leptos WASM]
     end
@@ -52,7 +52,7 @@ graph TB

 ## Workspace Structure

-The project uses a 4-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
+The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.

 ```mermaid
 graph TD
@@ -62,24 +62,33 @@ graph TD
     end

     subgraph "AI Services"
-        B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Candle ML]
+        B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
+        J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
+        K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
         C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
     end

     subgraph "Frontend"
         D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
     end

+    subgraph "Tooling"
+        L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
+    end
     end

     subgraph "External Tooling"
-        E[cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
+        E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
     end

     subgraph "Dependencies"
         A --> B
         A --> C
         A --> D
-        B -.-> F[Candle 0.9.1]
+        B --> J
+        B --> K
+        J -.-> F[Candle 0.9.1]
+        K -.-> F
         C -.-> G[FastEmbed 4.x]
         D -.-> H[Leptos 0.8.0]
         E -.-> I[OpenAI SDK 5.16+]
@@ -87,9 +96,12 @@ graph TD

     style A fill:#e1f5fe
     style B fill:#f3e5f5
+    style J fill:#f3e5f5
+    style K fill:#f3e5f5
     style C fill:#e8f5e8
     style D fill:#fff3e0
     style E fill:#fce4ec
+    style L fill:#fff9c4
 ```

 ## Deployment Configurations

30 scripts/run_llama.sh Normal file
@@ -0,0 +1,30 @@
#!/usr/bin/env bash
set -euo pipefail

PROMPT=${1:-"Say hello in one short sentence."}
MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
MAX_NEW=${3:-64}
FORCE_CPU=${FORCE_CPU:-0}

# Optional: keep HF cache local to repo if not already set
export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}

BIN="$(dirname "$0")/../target/release/llama-runner"

if [[ ! -x "$BIN" ]]; then
  echo "Building llama-runner (release)..."
  cargo build -p llama-runner --release
fi

echo "Running llama inference..." >&2
ARGS=(
  --model-id "$MODEL"
  --prompt "$PROMPT"
  --max-tokens "$MAX_NEW"
)

if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
  ARGS+=( --cpu )
fi

"$BIN" "${ARGS[@]}"