Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)
Removed legacy inference engine assets.
Cargo.lock (generated): 211 lines changed
@@ -389,17 +389,6 @@ dependencies = [
 "syn 2.0.106",
]

-[[package]]
-name = "async-trait"
-version = "0.1.89"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.106",
-]
-
[[package]]
name = "atoi"
version = "2.0.0"
@@ -474,47 +463,13 @@ dependencies = [
 "arrayvec",
]

-[[package]]
-name = "axum"
-version = "0.7.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
-dependencies = [
- "async-trait",
- "axum-core 0.4.5",
- "bytes",
- "futures-util",
- "http",
- "http-body",
- "http-body-util",
- "hyper",
- "hyper-util",
- "itoa",
- "matchit 0.7.3",
- "memchr",
- "mime",
- "percent-encoding",
- "pin-project-lite",
- "rustversion",
- "serde",
- "serde_json",
- "serde_path_to_error",
- "serde_urlencoded",
- "sync_wrapper",
- "tokio",
- "tower 0.5.2",
- "tower-layer",
- "tower-service",
- "tracing",
-]
-
[[package]]
name = "axum"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
- "axum-core 0.5.2",
+ "axum-core",
 "bytes",
 "form_urlencoded",
 "futures-util",
@@ -524,7 +479,7 @@ dependencies = [
 "hyper",
 "hyper-util",
 "itoa",
- "matchit 0.8.4",
+ "matchit",
 "memchr",
 "mime",
 "percent-encoding",
@@ -536,28 +491,7 @@ dependencies = [
 "serde_urlencoded",
 "sync_wrapper",
 "tokio",
- "tower 0.5.2",
+ "tower",
- "tower-layer",
- "tower-service",
- "tracing",
-]
-
-[[package]]
-name = "axum-core"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
-dependencies = [
- "async-trait",
- "bytes",
- "futures-util",
- "http",
- "http-body",
- "http-body-util",
- "mime",
- "pin-project-lite",
- "rustversion",
- "sync_wrapper",
 "tower-layer",
 "tower-service",
 "tracing",
@@ -1255,17 +1189,6 @@ dependencies = [
 "libc",
]

-[[package]]
-name = "core-graphics-types"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
-dependencies = [
- "bitflags 2.9.2",
- "core-foundation 0.10.1",
- "libc",
-]
-
[[package]]
name = "coreaudio-rs"
version = "0.11.3"
@@ -1623,15 +1546,15 @@ name = "embeddings-engine"
version = "0.1.0"
dependencies = [
 "async-openai",
- "axum 0.8.4",
+ "axum",
 "fastembed",
 "once_cell",
 "rand 0.8.5",
 "serde",
 "serde_json",
 "tokio",
- "tower 0.5.2",
+ "tower",
- "tower-http 0.6.6",
+ "tower-http",
 "tracing",
 "tracing-subscriber",
]
@@ -2777,7 +2700,7 @@ dependencies = [
 "ab_glyph",
 "accelerate-src",
 "anyhow",
- "axum 0.8.4",
+ "axum",
 "bindgen_cuda",
 "byteorder",
 "candle-core",
@@ -2813,8 +2736,8 @@ dependencies = [
 "symphonia",
 "tokenizers",
 "tokio",
- "tower 0.5.2",
+ "tower",
- "tower-http 0.6.6",
+ "tower-http",
 "tracing",
 "tracing-chrome",
 "tracing-subscriber",
@@ -3011,58 +2934,6 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"

-[[package]]
-name = "legacy-inference-engine"
-version = "0.1.0"
-dependencies = [
- "ab_glyph",
- "accelerate-src",
- "anyhow",
- "axum 0.7.9",
- "bindgen_cuda",
- "byteorder",
- "candle-core",
- "candle-datasets",
- "candle-flash-attn",
- "candle-nn",
- "candle-onnx",
- "candle-transformers",
- "clap",
- "cpal",
- "csv",
- "cudarc",
- "either",
- "enterpolation",
- "half",
- "hf-hub",
- "image",
- "imageproc",
- "intel-mkl-src",
- "memmap2",
- "metal 0.32.0",
- "num-traits",
- "palette",
- "pdf2image",
- "pyo3",
- "rand 0.9.2",
- "rayon",
- "reborrow",
- "rubato",
- "safetensors",
- "serde",
- "serde_json",
- "symphonia",
- "tokenizers",
- "tokio",
- "tower 0.4.13",
- "tower-http 0.5.2",
- "tracing",
- "tracing-chrome",
- "tracing-subscriber",
- "utoipa",
- "uuid",
-]
-
[[package]]
name = "leptos"
version = "0.6.15"
@@ -3507,12 +3378,6 @@ dependencies = [
 "regex-automata 0.1.10",
]

-[[package]]
-name = "matchit"
-version = "0.7.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
-
[[package]]
name = "matchit"
version = "0.8.4"
@@ -3572,7 +3437,7 @@ checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25"
dependencies = [
 "bitflags 2.9.2",
 "block",
- "core-graphics-types 0.1.3",
+ "core-graphics-types",
 "foreign-types 0.5.0",
 "log",
 "objc",
@@ -3587,22 +3452,7 @@ checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
dependencies = [
 "bitflags 2.9.2",
 "block",
- "core-graphics-types 0.1.3",
+ "core-graphics-types",
- "foreign-types 0.5.0",
- "log",
- "objc",
- "paste",
-]
-
-[[package]]
-name = "metal"
-version = "0.32.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "00c15a6f673ff72ddcc22394663290f870fb224c1bfce55734a75c414150e605"
-dependencies = [
- "bitflags 2.9.2",
- "block",
- "core-graphics-types 0.2.0",
 "foreign-types 0.5.0",
 "log",
 "objc",
@@ -4401,14 +4251,14 @@ dependencies = [
name = "predict-otron-9000"
version = "0.1.0"
dependencies = [
- "axum 0.8.4",
+ "axum",
 "embeddings-engine",
 "inference-engine",
 "serde",
 "serde_json",
 "tokio",
- "tower 0.5.2",
+ "tower",
- "tower-http 0.6.6",
+ "tower-http",
 "tracing",
 "tracing-subscriber",
 "uuid",
@@ -5112,8 +4962,8 @@ dependencies = [
 "tokio-native-tls",
 "tokio-rustls",
 "tokio-util",
- "tower 0.5.2",
+ "tower",
- "tower-http 0.6.6",
+ "tower-http",
 "tower-service",
 "url",
 "wasm-bindgen",
@@ -6374,17 +6224,6 @@ dependencies = [
 "num-traits",
]

-[[package]]
-name = "tower"
-version = "0.4.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
-dependencies = [
- "tower-layer",
- "tower-service",
- "tracing",
-]
-
[[package]]
name = "tower"
version = "0.5.2"
@@ -6401,22 +6240,6 @@ dependencies = [
 "tracing",
]

-[[package]]
-name = "tower-http"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
-dependencies = [
- "bitflags 2.9.2",
- "bytes",
- "http",
- "http-body",
- "http-body-util",
- "pin-project-lite",
- "tower-layer",
- "tower-service",
-]
-
[[package]]
name = "tower-http"
version = "0.6.6"
@@ -6430,7 +6253,7 @@ dependencies = [
 "http-body",
 "iri-string",
 "pin-project-lite",
- "tower 0.5.2",
+ "tower",
 "tower-layer",
 "tower-service",
 "tracing",
Cargo.toml (workspace manifest)
@@ -3,8 +3,7 @@ members = [
    "crates/predict-otron-9000",
    "crates/inference-engine",
    "crates/embeddings-engine",
-   "crates/leptos-chat",
+   "crates/leptos-chat"
-   "crates/legacy-inference-engine"
]
default-members = ["crates/predict-otron-9000"]
resolver = "2"
crates/legacy-inference-engine/Cargo.lock (generated): 6115 lines changed. File diff suppressed because it is too large.
Deleted: crates/legacy-inference-engine/Cargo.toml
@@ -1,77 +0,0 @@
[package]
name = "legacy-inference-engine"
version = "0.1.0"
edition = "2021"

[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true }
candle-nn = { version = "=0.9.1" }
candle-transformers = { version = "=0.9.1" }
candle-flash-attn = { version = "=0.9.1", optional = true }
candle-onnx = { version = "=0.9.1", optional = true }

csv = "1.3.0"
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features = false, optional = true }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
hf-hub = { version = "0.4.1", features = ["tokio"] }
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
num-traits = { version = "0.2.15" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
rayon = "1.7.0"
rubato = { version = "0.15.0", optional = true }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "http"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }
anyhow = "1.0.98"
clap = { version = "4.2.4", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = { version = "0.7.4", features = ["json"] }
tower = "0.4.13"
tower-http = { version = "0.5.1", features = ["cors"] }
tokio = { version = "1.43.0", features = ["full"] }
either = { version = "1.9.0", features = ["serde"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"

# --- Add this section for conditional compilation ---
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "=0.9.1", features = ["metal"] }
metal = { version = "0.32.0", features = ["mps"] }

[target.'cfg(not(target_os = "macos"))'.dependencies]
# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA
# If you're building on Linux with a CUDA-enabled GPU:
candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features

# If you're building on Linux with only CPU:
# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit
# --- End of conditional compilation section ---

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
byteorder = { version = "1.4.3" }
clap = { version = "4.2.4", features = ["derive"] }
imageproc = { version = "0.24.0", default-features = false }
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
rand = { version = "0.9.0" }
ab_glyph = { version = "0.2.23" }
tracing = { version = "0.1.37" }
tracing-chrome = { version = "0.7.1" }
tracing-subscriber = { version = "0.3.7" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"

[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = { version = "0.1.1", optional = true }
Deleted: legacy-inference-engine README
@@ -1,210 +0,0 @@
# @open-web-agent-rs/legacy-inference-engine

## Note
This is kept here as a reference implementation. This is harder than it looks.

A Rust-based inference engine for running large language models locally. It supports both a CLI mode for direct text generation and a server mode exposing an OpenAI-compatible API.

## Features

- Run Gemma models locally (1B, 2B, 7B, 9B variants)
- CLI mode for direct text generation
- Server mode with an OpenAI-compatible API
- Support for various model configurations (base, instruction-tuned)
- Metal acceleration on macOS

## Installation

### Prerequisites

- Rust toolchain (install via [rustup](https://rustup.rs/))
- Cargo package manager
- For GPU acceleration:
  - macOS: Metal support
  - Linux/Windows: CUDA support (requires appropriate drivers)

### Building from Source

1. Clone the repository:
```bash
git clone https://github.com/seemueller-io/open-web-agent-rs.git
cd open-web-agent-rs
```

2. Build the local inference engine:
```bash
cargo build -p legacy-inference-engine --release
```

## Usage

### CLI Mode

Run the inference engine in CLI mode to generate text directly:

```bash
cargo run -p legacy-inference-engine --release -- --prompt 'Name the 16th President of the USA.' --which 3-1b-it
```

#### CLI Options

- `--prompt <TEXT>`: The prompt text to generate from
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
  - Available options: "2b", "7b", "2b-it", "7b-it", "1.1-2b-it", "1.1-7b-it", "code-2b", "code-7b", "code-2b-it", "code-7b-it", "2-2b", "2-2b-it", "2-9b", "2-9b-it", "3-1b", "3-1b-it"
- `--server`: Run the OpenAI-compatible server
- `--temperature <FLOAT>`: Temperature for sampling (higher = more random)
- `--top-p <FLOAT>`: Nucleus sampling probability cutoff
- `--sample-len <INT>`: Maximum number of tokens to generate (default: 10000)
- `--repeat-penalty <FLOAT>`: Penalty for repeating tokens (default: 1.1)
- `--repeat-last-n <INT>`: Context size for the repeat penalty (default: 64)
- `--cpu`: Run on CPU instead of GPU
- `--tracing`: Enable tracing (generates a trace-timestamp.json file)

### Server Mode with OpenAI-compatible API

Run the inference engine in server mode to expose an OpenAI-compatible API:

```bash
cargo run -p legacy-inference-engine --release -- --server --port 3777 --which 3-1b-it
```

This starts a web server on the specified port (default: 3777) with an OpenAI-compatible chat completions endpoint.

#### Server Options

- `--server`: Run in server mode
- `--port <INT>`: Port to use for the server (default: 3777)
- `--which <MODEL>`: Model variant to use (default: "3-1b-it")
- Other model options as described in CLI mode

## API Usage

The server exposes an OpenAI-compatible chat completions endpoint:

### Chat Completions

```
POST /v1/chat/completions
```

#### Request Format

```json
{
  "model": "gemma-3-1b-it",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you?"}
  ],
  "temperature": 0.7,
  "max_tokens": 256,
  "top_p": 0.9,
  "stream": false
}
```

#### Response Format

```json
{
  "id": "chatcmpl-123abc456def789ghi",
  "object": "chat.completion",
  "created": 1677858242,
  "model": "gemma-3-1b-it",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I'm doing well, thank you for asking! How can I assist you today?"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 25,
    "completion_tokens": 15,
    "total_tokens": 40
  }
}
```

### Example: Using cURL

```bash
curl -X POST http://localhost:3777/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemma-3-1b-it",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ],
    "temperature": 0.7,
    "max_tokens": 100
  }'
```

### Example: Using Python with the OpenAI Client

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3777/v1",
    api_key="dummy"  # API key is not validated but required by the client
)

response = client.chat.completions.create(
    model="gemma-3-1b-it",
    messages=[
        {"role": "user", "content": "What is the capital of France?"}
    ],
    temperature=0.7,
    max_tokens=100
)

print(response.choices[0].message.content)
```

### Example: Using JavaScript/TypeScript with the OpenAI SDK

```javascript
import OpenAI from 'openai';

const openai = new OpenAI({
  baseURL: 'http://localhost:3777/v1',
  apiKey: 'dummy', // API key is not validated but required by the client
});

async function main() {
  const response = await openai.chat.completions.create({
    model: 'gemma-3-1b-it',
    messages: [
      { role: 'user', content: 'What is the capital of France?' }
    ],
    temperature: 0.7,
    max_tokens: 100,
  });

  console.log(response.choices[0].message.content);
}

main();
```

## Troubleshooting

### Common Issues

1. **Model download errors**: Make sure you have a stable internet connection. The models are downloaded from the Hugging Face Hub.

2. **Out of memory errors**: Try using a smaller model variant or reducing the batch size.

3. **Slow inference on CPU**: This is expected. For better performance, use GPU acceleration if available.

4. **Metal/CUDA errors**: Ensure you have the latest drivers installed for your GPU.

## License

This project is licensed under the terms specified in the LICENSE file.
Deleted: root cause analysis document
@@ -1,127 +0,0 @@
# Root Cause Analysis: Metal error "no metal implementation for rotary-emb"

Date: 2025-08-27
Component: crates/legacy-inference-engine
Command to reproduce: crates/legacy-inference-engine/test_cli.sh

## Summary
Running the CLI with the default model (--which 3-1b-it, i.e., Gemma 3 1B Instruct) on an Apple Silicon Mac results in a runtime failure:

```
modelError: Metal error no metal implementation for rotary-emb

Caused by:
    no metal implementation for rotary-emb
```

This occurs because the project targets the Candle Metal (MPS) backend on macOS, but the Candle version in use (0.9.1) does not provide a Metal kernel implementation for the rotary embedding operation required by Gemma 3 models. The program selects the Metal device by default on macOS and hits this missing kernel during the attention computation.

## Environment and build configuration
- Machine: 2024 MacBook Pro, Apple Silicon (M4 Max)
- Crate: legacy-inference-engine
- Candle versions: pinned to =0.9.1
  - candle-core = "=0.9.1"
  - candle-transformers = "=0.9.1"
- macOS-specific dependency enabling Metal (file: crates/legacy-inference-engine/Cargo.toml):

```text
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "=0.9.1", features = ["metal"] }
metal = { version = "0.32.0", features = ["mps"] }
```

- Run command (attached script): crates/legacy-inference-engine/test_cli.sh

```text
cargo run -p legacy-inference-engine --release -- --prompt 'Name the 16th President of the USA.' --which 3-1b-it
```

## What the code does at runtime
1) Device selection (defaults to Metal on macOS if available):
   - File: crates/legacy-inference-engine/src/utilities_lib.rs (lines 4–12)

```text
pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        // ... falls back to CPU
        Ok(Device::Cpu)
    }
}
```

   - The CLI does not pass --cpu, so on Apple Silicon with Metal available, Device::new_metal(0) is selected.

2) Default model selection is Gemma 3 1B Instruct:
   - File: crates/legacy-inference-engine/src/main.rs
   - Arg default (lines 705–707):

```text
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
which: Which,
```

   - Model id resolution (lines 758–760):

```text
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
```

   - Model loading uses Model3 (Gemma 3) for Which::BaseV3_1B | Which::InstructV3_1B (lines 817–821).

3) During generation, the Gemma 3 attention path requires rotary embeddings. On the Metal backend in Candle 0.9.1, the rotary embedding op is not implemented, resulting in the runtime error.

## Additional build-time signal (misleading but not causal)
- File: crates/legacy-inference-engine/src/main.rs (lines 10–11)

```text
#[cfg(feature = "metal")]
extern crate metal_src;
```

- Build warning: unexpected cfg condition value: metal
  Explanation: the project does not define a Cargo feature named "metal"; instead, Metal is enabled via target-specific dependency features in Cargo.toml. This cfg gate is ineffective and triggers a warning. It does not cause the runtime failure; it just indicates confusing, obsolete gating.

## Root cause
- The program runs on the Candle Metal backend (MPS) due to device auto-selection on macOS.
- The selected model (Gemma 3 1B Instruct) requires the rotary embedding operation in its attention mechanism.
- Candle 0.9.1's Metal backend lacks an implementation for the rotary-emb kernel. When the model executes on Metal, it attempts to invoke this operation and fails with: "no metal implementation for rotary-emb".

## Evidence
- The runtime log shows the failure immediately after model load, when inference begins.
- Code paths confirm: the device defaults to Metal on macOS, the default model is Gemma 3, and Gemma 3 uses rotary embeddings.
- The Candle version is pinned to 0.9.1, where rotary-emb on Metal is not available.

## Impact
- Any attempt to run Gemma 3 (and possibly other rotary-embedding-reliant models) on the Metal backend with Candle 0.9.1 will fail at runtime on macOS.

## Workarounds and remediation options
1) Immediate workarounds:
   - Run on CPU: add the --cpu flag to force the CPU backend.
     - Example: cargo run -p legacy-inference-engine --release -- --cpu --prompt '...' --which 3-1b-it
   - Use a model variant that does not hit the unimplemented kernel on Metal (e.g., older Gemma v1/v2), though many modern LLMs rely on rotary embeddings, so this may not help.

2) Recommended remediation (code/dependency changes):
   - Upgrade the Candle crates (candle-core, candle-transformers, etc.) to a version where the Metal backend implements rotary embeddings. Review Candle's changelog/PRs for Metal/MPS kernel support and update to the first version that includes rotary-emb on Metal.
   - Alternatively, implement a CPU fallback path for rotary-emb when running on Metal (hybrid execution). This is non-trivial and may degrade performance.
   - Provide a configuration flag to disable Metal by default on macOS for models known to require missing ops until Candle is upgraded.
   - Clean up the misleading #[cfg(feature = "metal")] gate in main.rs to avoid confusion; Metal enablement is already handled in Cargo.toml via target-specific features.

## Suggested next steps
- Short term: document and expose --cpu usage in the README and/or make the default model a Metal-compatible one until the dependency upgrade.
- Medium term: bump the Candle dependencies and test Gemma 3 on Metal; remove the obsolete cfg(feature = "metal") gate.
- Long term: integrate a device capability check and automatic fallback (with an informative log) when encountering unsupported kernels on the selected backend; a hedged sketch of such a helper follows this document.

## References (code locations)
- crates/legacy-inference-engine/src/utilities_lib.rs lines 4–12: device selection (Metal default on macOS if available).
- crates/legacy-inference-engine/src/main.rs lines 705–707: default which = 3-1b-it.
- crates/legacy-inference-engine/src/main.rs lines 758–760 and 817–821: Gemma 3 model selection and instantiation.
- crates/legacy-inference-engine/Cargo.toml macOS target section: Candle with features = ["metal"].
- crates/legacy-inference-engine/src/main.rs lines 10–11: obsolete #[cfg(feature = "metal")] gate that triggers a warning.
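The long-term suggestion above calls for a device capability check with automatic fallback. A minimal sketch of what such a helper could look like, assuming candle 0.9.x's `candle_core::utils::{cuda_is_available, metal_is_available}` helpers and a caller-supplied `needs_rotary_emb` flag; both the function and the flag are illustrative and not part of the removed code:

```rust
// Hedged sketch only: prefer CPU over Metal for models whose attention path
// needs the rotary-emb kernel that candle 0.9.1's Metal backend lacks.
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result};

pub fn select_device(force_cpu: bool, needs_rotary_emb: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    if cuda_is_available() {
        return Device::new_cuda(0);
    }
    if metal_is_available() {
        if needs_rotary_emb {
            // Fail over up front instead of erroring in the middle of generation.
            eprintln!("warning: Metal backend lacks rotary-emb; falling back to CPU");
            return Ok(Device::Cpu);
        }
        return Device::new_metal(0);
    }
    Ok(Device::Cpu)
}
```

Call sites that today use `utilities_lib::device(cpu)` would additionally pass whether the chosen model family relies on rotary embeddings (true for Gemma 3), so Metal is only selected when the required kernels exist.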
Deleted: OpenAI-compatible API tester page (HTML)
@@ -1,295 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>OpenAI-Compatible API Tester</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            line-height: 1.6;
        }
        h1, h2 {
            color: #333;
        }
        .container {
            margin-bottom: 20px;
        }
        textarea {
            width: 100%;
            height: 150px;
            padding: 10px;
            margin-bottom: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            font-family: monospace;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 15px;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background-color: #45a049;
        }
        pre {
            background-color: #f5f5f5;
            padding: 15px;
            border-radius: 4px;
            overflow-x: auto;
            white-space: pre-wrap;
        }
        .response {
            margin-top: 20px;
        }
        .error {
            color: red;
        }
        .settings {
            display: flex;
            flex-wrap: wrap;
            gap: 10px;
            margin-bottom: 15px;
        }
        .settings div {
            display: flex;
            flex-direction: column;
        }
        label {
            margin-bottom: 5px;
            font-weight: bold;
        }
        input {
            padding: 8px;
            border: 1px solid #ddd;
            border-radius: 4px;
        }
        .examples {
            margin-top: 30px;
        }
        .example-btn {
            background-color: #2196F3;
            margin-right: 10px;
            margin-bottom: 10px;
        }
        .example-btn:hover {
            background-color: #0b7dda;
        }
    </style>
</head>
<body>
    <h1>OpenAI-Compatible API Tester</h1>
    <p>Use this page to test the OpenAI-compatible chat completions endpoint of the local inference engine.</p>

    <div class="container">
        <h2>Request Settings</h2>
        <div class="settings">
            <div>
                <label for="serverUrl">Server URL:</label>
                <input type="text" id="serverUrl" value="http://localhost:3777" />
            </div>
            <div>
                <label for="model">Model:</label>
                <input type="text" id="model" value="gemma-3-1b-it" />
            </div>
            <div>
                <label for="maxTokens">Max Tokens:</label>
                <input type="number" id="maxTokens" value="150" />
            </div>
            <div>
                <label for="temperature">Temperature:</label>
                <input type="number" id="temperature" value="0.7" step="0.1" min="0" max="2" />
            </div>
            <div>
                <label for="topP">Top P:</label>
                <input type="number" id="topP" value="0.9" step="0.1" min="0" max="1" />
            </div>
        </div>

        <h2>Request Body</h2>
        <textarea id="requestBody">{
  "model": "gemma-3-1b-it",
  "messages": [
    {
      "role": "user",
      "content": "Hello, how are you today?"
    }
  ],
  "max_tokens": 150,
  "temperature": 0.7,
  "top_p": 0.9
}</textarea>
        <button id="sendRequest">Send Request</button>

        <div class="examples">
            <h3>Example Requests</h3>
            <button class="example-btn" id="example1">Basic Question</button>
            <button class="example-btn" id="example2">Multi-turn Conversation</button>
            <button class="example-btn" id="example3">Creative Writing</button>
            <button class="example-btn" id="example4">Code Generation</button>
        </div>

        <div class="response">
            <h2>Response</h2>
            <pre id="responseOutput">Response will appear here...</pre>
        </div>
    </div>

    <script>
        document.addEventListener('DOMContentLoaded', function() {
            // Update request body when settings change
            const serverUrlInput = document.getElementById('serverUrl');
            const modelInput = document.getElementById('model');
            const maxTokensInput = document.getElementById('maxTokens');
            const temperatureInput = document.getElementById('temperature');
            const topPInput = document.getElementById('topP');
            const requestBodyTextarea = document.getElementById('requestBody');
            const responseOutput = document.getElementById('responseOutput');

            // Function to update request body from settings
            function updateRequestBodyFromSettings() {
                try {
                    const requestBody = JSON.parse(requestBodyTextarea.value);
                    requestBody.model = modelInput.value;
                    requestBody.max_tokens = parseInt(maxTokensInput.value);
                    requestBody.temperature = parseFloat(temperatureInput.value);
                    requestBody.top_p = parseFloat(topPInput.value);
                    requestBodyTextarea.value = JSON.stringify(requestBody, null, 2);
                } catch (error) {
                    console.error("Error updating request body:", error);
                }
            }

            // Update settings when request body changes
            function updateSettingsFromRequestBody() {
                try {
                    const requestBody = JSON.parse(requestBodyTextarea.value);
                    if (requestBody.model) modelInput.value = requestBody.model;
                    if (requestBody.max_tokens) maxTokensInput.value = requestBody.max_tokens;
                    if (requestBody.temperature) temperatureInput.value = requestBody.temperature;
                    if (requestBody.top_p) topPInput.value = requestBody.top_p;
                } catch (error) {
                    console.error("Error updating settings:", error);
                }
            }

            // Add event listeners for settings changes
            modelInput.addEventListener('change', updateRequestBodyFromSettings);
            maxTokensInput.addEventListener('change', updateRequestBodyFromSettings);
            temperatureInput.addEventListener('change', updateRequestBodyFromSettings);
            topPInput.addEventListener('change', updateRequestBodyFromSettings);

            // Add event listener for request body changes
            requestBodyTextarea.addEventListener('blur', updateSettingsFromRequestBody);

            // Send request button
            document.getElementById('sendRequest').addEventListener('click', async function() {
                try {
                    responseOutput.textContent = "Sending request...";
                    const serverUrl = serverUrlInput.value;
                    const endpoint = '/v1/chat/completions';
                    const url = serverUrl + endpoint;

                    const requestBody = JSON.parse(requestBodyTextarea.value);

                    const response = await fetch(url, {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify(requestBody)
                    });

                    const data = await response.json();
                    responseOutput.textContent = JSON.stringify(data, null, 2);
                } catch (error) {
                    responseOutput.textContent = "Error: " + error.message;
                    responseOutput.classList.add('error');
                }
            });

            // Example requests
            document.getElementById('example1').addEventListener('click', function() {
                requestBodyTextarea.value = JSON.stringify({
                    model: modelInput.value,
                    messages: [
                        {
                            role: "user",
                            content: "Who was the 16th president of the United States?"
                        }
                    ],
                    max_tokens: parseInt(maxTokensInput.value),
                    temperature: parseFloat(temperatureInput.value),
                    top_p: parseFloat(topPInput.value)
                }, null, 2);
            });

            document.getElementById('example2').addEventListener('click', function() {
                requestBodyTextarea.value = JSON.stringify({
                    model: modelInput.value,
                    messages: [
                        {
                            role: "system",
                            content: "You are a helpful assistant that provides concise answers."
                        },
                        {
                            role: "user",
                            content: "What is machine learning?"
                        },
                        {
                            role: "assistant",
                            content: "Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
                        },
                        {
                            role: "user",
                            content: "Give me an example of a machine learning algorithm."
                        }
                    ],
                    max_tokens: parseInt(maxTokensInput.value),
                    temperature: parseFloat(temperatureInput.value),
                    top_p: parseFloat(topPInput.value)
                }, null, 2);
            });

            document.getElementById('example3').addEventListener('click', function() {
                requestBodyTextarea.value = JSON.stringify({
                    model: modelInput.value,
                    messages: [
                        {
                            role: "user",
                            content: "Write a short poem about artificial intelligence."
                        }
                    ],
                    max_tokens: parseInt(maxTokensInput.value),
                    temperature: 0.9, // Higher temperature for creative tasks
                    top_p: 0.9
                }, null, 2);
                temperatureInput.value = 0.9;
            });

            document.getElementById('example4').addEventListener('click', function() {
                requestBodyTextarea.value = JSON.stringify({
                    model: modelInput.value,
                    messages: [
                        {
                            role: "user",
                            content: "Write a Python function to calculate the Fibonacci sequence up to n terms."
                        }
                    ],
                    max_tokens: parseInt(maxTokensInput.value),
                    temperature: 0.3, // Lower temperature for code generation
                    top_p: 0.9
                }, null, 2);
                temperatureInput.value = 0.3;
            });
        });
    </script>
</body>
</html>
Deleted: legacy-inference-engine CLI arguments module
@@ -1,72 +0,0 @@
use clap::Parser;
use crate::model::Which;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    pub cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    pub tracing: bool,

    /// Run in server mode with OpenAI compatible API
    #[arg(long)]
    pub server: bool,

    /// Port to use for the server
    #[arg(long, default_value_t = 3777)]
    pub port: u16,

    /// Prompt for text generation (not used in server mode)
    #[arg(long)]
    pub prompt: Option<String>,

    /// The temperature used to generate samples.
    #[arg(long)]
    pub temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    pub top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    pub seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    pub sample_len: usize,

    #[arg(long)]
    pub model_id: Option<String>,

    #[arg(long, default_value = "main")]
    pub revision: String,

    #[arg(long)]
    pub tokenizer_file: Option<String>,

    #[arg(long)]
    pub config_file: Option<String>,

    #[arg(long)]
    pub weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    pub repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    pub repeat_last_n: usize,

    /// The model to use.
    #[arg(long, default_value = "3-1b-it")]
    pub which: Which,

    #[arg(long)]
    pub use_flash_attn: bool,
}
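For context, a hedged sketch of how the clap-derived `Args` struct above is typically consumed; the `println!` bodies are placeholders standing in for the real server startup and generation paths described in the README, not code from the deleted sources:

```rust
// Illustrative only: parse the CLI flags and branch between server and CLI mode.
use clap::Parser;

fn main() {
    let args = Args::parse();
    if args.server {
        // Server mode would bind the OpenAI-compatible API on args.port.
        println!("starting server on port {}", args.port);
    } else if let Some(prompt) = args.prompt.as_deref() {
        // CLI mode would run generation with args.which, args.temperature, etc.
        println!("generating up to {} tokens for: {prompt}", args.sample_len);
    }
}
```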
Deleted: legacy-inference-engine library root (lib.rs)
@@ -1,13 +0,0 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod cli;
pub mod server;

// Re-export key components for easier access
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
@@ -1,912 +0,0 @@
|
|||||||
mod token_output_stream;
|
|
||||||
mod utilities_lib;
|
|
||||||
|
|
||||||
#[cfg(feature = "intel-mkl-src")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate-src")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
extern crate metal_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use axum::{
|
|
||||||
extract::State,
|
|
||||||
http::StatusCode,
|
|
||||||
response::IntoResponse,
|
|
||||||
routing::{get, post},
|
|
||||||
Json, Router,
|
|
||||||
};
|
|
||||||
use clap::Parser;
|
|
||||||
use either::Either;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
|
||||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
|
||||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
|
||||||
|
|
||||||
// OpenAI API compatible structs
|
|
||||||
|
|
||||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct MessageInnerContent(
|
|
||||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
impl ToSchema<'_> for MessageInnerContent {
|
|
||||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
|
||||||
(
|
|
||||||
"MessageInnerContent",
|
|
||||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
|
||||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
|
||||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
|
||||||
|
|
||||||
Schema::OneOf(
|
|
||||||
OneOfBuilder::new()
|
|
||||||
// Either::Left - simple string
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))
|
|
||||||
// Either::Right - object with string values
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new()
|
|
||||||
.schema_type(SchemaType::Object)
|
|
||||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))))
|
|
||||||
.build(),
|
|
||||||
))
|
|
||||||
.build(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Message content that can be either simple text or complex structured content
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct MessageContent(
|
|
||||||
#[serde(with = "either::serde_untagged")]
|
|
||||||
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
impl ToSchema<'_> for MessageContent {
|
|
||||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
|
||||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Function for MessageContent Schema generation to handle `Either`
|
|
||||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
|
||||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
|
||||||
|
|
||||||
Schema::OneOf(
|
|
||||||
OneOfBuilder::new()
|
|
||||||
.item(Schema::Object(
|
|
||||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
|
||||||
))
|
|
||||||
.item(Schema::Array(
|
|
||||||
ArrayBuilder::new()
|
|
||||||
.items(RefOr::T(Schema::Object(
|
|
||||||
ObjectBuilder::new()
|
|
||||||
.schema_type(SchemaType::Object)
|
|
||||||
.additional_properties(Some(RefOr::Ref(
|
|
||||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
|
||||||
)))
|
|
||||||
.build(),
|
|
||||||
)))
|
|
||||||
.build(),
|
|
||||||
))
|
|
||||||
.build(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Represents a single message in a conversation
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
pub struct Message {
|
|
||||||
/// The message content
|
|
||||||
pub content: Option<MessageContent>,
|
|
||||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
|
||||||
pub role: String,
|
|
||||||
pub name: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stop token configuration for generation
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum StopTokens {
|
|
||||||
/// Multiple possible stop sequences
|
|
||||||
Multi(Vec<String>),
|
|
||||||
/// Single stop sequence
|
|
||||||
Single(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_false() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_1usize() -> usize {
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Default value helper
|
|
||||||
fn default_model() -> String {
|
|
||||||
"default".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion request following OpenAI's specification
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
|
||||||
pub messages: Vec<Message>,
|
|
||||||
#[schema(example = "gemma-3-1b-it")]
|
|
||||||
#[serde(default = "default_model")]
|
|
||||||
pub model: String,
|
|
||||||
#[serde(default = "default_false")]
|
|
||||||
#[schema(example = false)]
|
|
||||||
pub logprobs: bool,
|
|
||||||
#[schema(example = 256)]
|
|
||||||
pub max_tokens: Option<usize>,
|
|
||||||
#[serde(rename = "n")]
|
|
||||||
#[serde(default = "default_1usize")]
|
|
||||||
#[schema(example = 1)]
|
|
||||||
pub n_choices: usize,
|
|
||||||
#[schema(example = 0.7)]
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
#[schema(example = 0.9)]
|
|
||||||
pub top_p: Option<f64>,
|
|
||||||
#[schema(example = false)]
|
|
||||||
pub stream: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion choice
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionChoice {
|
|
||||||
pub index: usize,
|
|
||||||
pub message: Message,
|
|
||||||
pub finish_reason: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chat completion response
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatCompletionChoice>,
|
|
||||||
pub usage: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Token usage information
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: usize,
|
|
||||||
pub completion_tokens: usize,
|
|
||||||
pub total_tokens: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Application state shared between handlers
|
|
||||||
#[derive(Clone)]
|
|
struct AppState {
    text_generation: Arc<Mutex<TextGeneration>>,
    model_id: String,
}

// Chat completions endpoint handler
async fn chat_completions(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
    let mut prompt = String::new();

    // Convert messages to a prompt string
    for message in &request.messages {
        let role = &message.role;
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(), // Handle complex content if needed
            },
            None => "".to_string(),
        };

        // Format based on role
        match role.as_str() {
            "system" => prompt.push_str(&format!("System: {}\n", content)),
            "user" => prompt.push_str(&format!("User: {}\n", content)),
            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
        }
    }

    // Add the assistant prefix for the response
    prompt.push_str("Assistant: ");

    // Capture the output
    let mut output = Vec::new();
    {
        let mut text_gen = state.text_generation.lock().await;

        // Buffer to capture the output
        let mut buffer = Vec::new();

        // Run text generation
        let max_tokens = request.max_tokens.unwrap_or(1000);
        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);

        if let Err(e) = result {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
                        "type": "unsupported_api"
                    }
                })),
            ));
        }

        // Convert buffer to string
        if let Ok(text) = String::from_utf8(buffer) {
            output.push(text);
        }
    }

    // Create response
    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        model: request.model,
        choices: vec![ChatCompletionChoice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: Some(MessageContent(Either::Left(output.join("")))),
                name: None,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens: prompt.len() / 4, // Rough estimate
            completion_tokens: output.join("").len() / 4, // Rough estimate
            total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
        },
    };

    // Return the response as JSON
    Ok(Json(response))
}

use candle_core::{DType, Device, MetalDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde_json::json;
use tokenizers::Tokenizer;
use crate::token_output_stream::TokenOutputStream;
use crate::utilities_lib::device;

// Create the router with the chat completions endpoint
fn create_router(app_state: AppState) -> Router {
    // CORS layer to allow requests from any origin
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    Router::new()
        // OpenAI compatible endpoints
        .route("/v1/chat/completions", post(chat_completions))
        // Add more endpoints as needed
        .layer(cors)
        .with_state(app_state)
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "2b")]
    Base2B,
    #[value(name = "7b")]
    Base7B,
    #[value(name = "2b-it")]
    Instruct2B,
    #[value(name = "7b-it")]
    Instruct7B,
    #[value(name = "1.1-2b-it")]
    InstructV1_1_2B,
    #[value(name = "1.1-7b-it")]
    InstructV1_1_7B,
    #[value(name = "code-2b")]
    CodeBase2B,
    #[value(name = "code-7b")]
    CodeBase7B,
    #[value(name = "code-2b-it")]
    CodeInstruct2B,
    #[value(name = "code-7b-it")]
    CodeInstruct7B,
    #[value(name = "2-2b")]
    BaseV2_2B,
    #[value(name = "2-2b-it")]
    InstructV2_2B,
    #[value(name = "2-9b")]
    BaseV2_9B,
    #[value(name = "2-9b-it")]
    InstructV2_9B,
    #[value(name = "3-1b")]
    BaseV3_1B,
    #[value(name = "3-1b-it")]
    InstructV3_1B,
}

enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}

impl Model {
    fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
        match self {
            Self::V1(m) => m.forward(input_ids, pos),
            Self::V2(m) => m.forward(input_ids, pos),
            Self::V3(m) => m.forward(input_ids, pos),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    // Run text generation and print to stdout
    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                println!(
                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
                );
                eos_token
            }
        };

        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                // Manual implementation of repeat penalty to avoid type conflicts
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in &tokens[start_at..] {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        logits_vec[token_id] = sign * score / self.repeat_penalty;
                    }
                }

                // Create a new tensor with the modified logits
                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                new_logits.reshape(shape)?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }

    // Run text generation and write to a buffer
    fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        // Write prompt tokens to output
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                write!(output, "{}", t)?;
            }
        }

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
                eos_token
            }
        };

        // Determine if we're using a Model3 (gemma-3) variant
        let is_model3 = match &self.model {
            Model::V3(_) => true,
            _ => false,
        };

        // For Model3, we need to use a different approach
        if is_model3 {
            // For gemma-3 models, we'll generate one token at a time with the full context
            let start_gen = std::time::Instant::now();

            // Initial generation with the full prompt
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
            let mut logits = self.model.forward(&input, 0)?;
            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

            for _ in 0..sample_len {
                // Apply repeat penalty if needed
                let current_logits = if self.repeat_penalty == 1. {
                    logits.clone()
                } else {
                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                    // Manual implementation of repeat penalty to avoid type conflicts
                    let mut logits_vec = logits.to_vec1::<f32>()?;

                    for &token_id in &tokens[start_at..] {
                        let token_id = token_id as usize;
                        if token_id < logits_vec.len() {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            logits_vec[token_id] = sign * score / self.repeat_penalty;
                        }
                    }

                    // Create a new tensor with the modified logits
                    let device = logits.device().clone();
                    let shape = logits.shape().clone();
                    let new_logits = Tensor::new(&logits_vec[..], &device)?;
                    new_logits.reshape(shape)?
                };

                let next_token = self.logits_processor.sample(&current_logits)?;
                tokens.push(next_token);
                generated_tokens += 1;

                if next_token == eos_token || next_token == eot_token {
                    break;
                }

                if let Some(t) = self.tokenizer.next_token(next_token)? {
                    write!(output, "{}", t)?;
                }

                // For the next iteration, just use the new token
                let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
                logits = self.model.forward(&new_input, tokens.len() - 1)?;
                logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            }

            return Ok(());
        }

        // Standard approach for other models
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                // Manual implementation of repeat penalty to avoid type conflicts
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in &tokens[start_at..] {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        logits_vec[token_id] = sign * score / self.repeat_penalty;
                    }
                }

                // Create a new tensor with the modified logits
                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                new_logits.reshape(shape)?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                write!(output, "{}", t)?;
            }
        }

        // Write any remaining tokens
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            write!(output, "{}", rest)?;
        }

        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Run in server mode with OpenAI compatible API
    #[arg(long)]
    server: bool,

    /// Port to use for the server
    #[arg(long, default_value_t = 3777)]
    port: u16,

    /// Prompt for text generation (not used in server mode)
    #[arg(long)]
    prompt: Option<String>,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model to use.
    #[arg(long, default_value = "3-1b-it")]
    which: Which,

    #[arg(long)]
    use_flash_attn: bool,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => match args.which {
            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
            Which::Base2B => "google/gemma-2b".to_string(),
            Which::Base7B => "google/gemma-7b".to_string(),
            Which::Instruct2B => "google/gemma-2b-it".to_string(),
            Which::Instruct7B => "google/gemma-7b-it".to_string(),
            Which::CodeBase2B => "google/codegemma-2b".to_string(),
            Which::CodeBase7B => "google/codegemma-7b".to_string(),
            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
            Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
            Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
            Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
            Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
            Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
            Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id.clone(),
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.which {
            Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
            _ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
        },
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let initial_device = utilities_lib::device(args.cpu)?;

    // Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
    let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
    let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;

    // Use CPU for V3 models on Metal due to missing implementations
    let device = if is_v3_model && is_metal {
        println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
        Device::Cpu
    } else {
        initial_device
    };

    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };

    // Use the selected device and dtype
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = match args.which {
        Which::Base2B
        | Which::Base7B
        | Which::Instruct2B
        | Which::Instruct7B
        | Which::InstructV1_1_2B
        | Which::InstructV1_1_7B
        | Which::CodeBase2B
        | Which::CodeBase7B
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model1::new(args.use_flash_attn, &config, vb)?;
            Model::V1(model)
        }
        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(args.use_flash_attn, &config, vb)?;
            Model::V2(model)
        }
        Which::BaseV3_1B | Which::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model3::new(args.use_flash_attn, &config, vb)?;
            Model::V3(model)
        }
    };

    println!("loaded the model in {:?}", start.elapsed());

    let pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );

    if args.server {
        // Start the server
        println!("Starting server on port {}", args.port);

        // Create app state
        let app_state = AppState {
            text_generation: Arc::new(Mutex::new(pipeline)),
            model_id,
        };

        // Create router
        let app = create_router(app_state);

        // Run the server
        let addr = SocketAddr::from(([0, 0, 0, 0], args.port));

        // Use tokio to run the server
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()?
            .block_on(async {
                axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
                    .await
                    .map_err(|e| anyhow::anyhow!("Server error: {}", e))
            })?;

        Ok(())
    } else {
        // Run in CLI mode
        if let Some(prompt_text) = &args.prompt {
            let prompt = match args.which {
                Which::Base2B
                | Which::Base7B
                | Which::Instruct2B
                | Which::Instruct7B
                | Which::InstructV1_1_2B
                | Which::InstructV1_1_7B
                | Which::CodeBase2B
                | Which::CodeBase7B
                | Which::CodeInstruct2B
                | Which::CodeInstruct7B
                | Which::BaseV2_2B
                | Which::InstructV2_2B
                | Which::BaseV2_9B
                | Which::InstructV2_9B
                | Which::BaseV3_1B => prompt_text.clone(),
                Which::InstructV3_1B => {
                    format!(
                        "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
                        prompt_text
                    )
                }
            };

            let mut pipeline = pipeline;
            pipeline.run(&prompt, args.sample_len)?;
            Ok(())
        } else {
            anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
        }
    }
}
@@ -1,90 +0,0 @@
use candle_core::Tensor;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
    #[value(name = "2b")]
    Base2B,
    #[value(name = "7b")]
    Base7B,
    #[value(name = "2b-it")]
    Instruct2B,
    #[value(name = "7b-it")]
    Instruct7B,
    #[value(name = "1.1-2b-it")]
    InstructV1_1_2B,
    #[value(name = "1.1-7b-it")]
    InstructV1_1_7B,
    #[value(name = "code-2b")]
    CodeBase2B,
    #[value(name = "code-7b")]
    CodeBase7B,
    #[value(name = "code-2b-it")]
    CodeInstruct2B,
    #[value(name = "code-7b-it")]
    CodeInstruct7B,
    #[value(name = "2-2b")]
    BaseV2_2B,
    #[value(name = "2-2b-it")]
    InstructV2_2B,
    #[value(name = "2-9b")]
    BaseV2_9B,
    #[value(name = "2-9b-it")]
    InstructV2_9B,
    #[value(name = "3-1b")]
    BaseV3_1B,
    #[value(name = "3-1b-it")]
    InstructV3_1B,
}

pub enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}

impl Model {
    pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
        match self {
            Self::V1(m) => m.forward(input_ids, pos),
            Self::V2(m) => m.forward(input_ids, pos),
            Self::V3(m) => m.forward(input_ids, pos),
        }
    }
}

impl Which {
    pub fn to_model_id(&self) -> String {
        match self {
            Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
            Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
            Self::Base2B => "google/gemma-2b".to_string(),
            Self::Base7B => "google/gemma-7b".to_string(),
            Self::Instruct2B => "google/gemma-2b-it".to_string(),
            Self::Instruct7B => "google/gemma-7b-it".to_string(),
            Self::CodeBase2B => "google/codegemma-2b".to_string(),
            Self::CodeBase7B => "google/codegemma-7b".to_string(),
            Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
            Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
            Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
            Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
            Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
            Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
            Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
            Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
        }
    }

    pub fn is_instruct_model(&self) -> bool {
        match self {
            Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
            _ => true,
        }
    }

    pub fn is_v3_model(&self) -> bool {
        matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
    }
}
@@ -1,167 +0,0 @@
use either::Either;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use utoipa::ToSchema;

/// Inner content structure for messages that can be either a string or key-value pairs
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
    #[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);

impl ToSchema<'_> for MessageInnerContent {
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        (
            "MessageInnerContent",
            utoipa::openapi::RefOr::T(message_inner_content_schema()),
        )
    }
}

/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> utoipa::openapi::Schema {
    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    Schema::OneOf(
        OneOfBuilder::new()
            // Either::Left - simple string
            .item(Schema::Object(
                ObjectBuilder::new().schema_type(SchemaType::String).build(),
            ))
            // Either::Right - object with string values
            .item(Schema::Object(
                ObjectBuilder::new()
                    .schema_type(SchemaType::Object)
                    .additional_properties(Some(RefOr::T(Schema::Object(
                        ObjectBuilder::new().schema_type(SchemaType::String).build(),
                    ))))
                    .build(),
            ))
            .build(),
    )
}

/// Message content that can be either simple text or complex structured content
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
    #[serde(with = "either::serde_untagged")]
    pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);

impl ToSchema<'_> for MessageContent {
    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
    }
}

/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> utoipa::openapi::Schema {
    use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};

    Schema::OneOf(
        OneOfBuilder::new()
            .item(Schema::Object(
                ObjectBuilder::new().schema_type(SchemaType::String).build(),
            ))
            .item(Schema::Array(
                ArrayBuilder::new()
                    .items(RefOr::T(Schema::Object(
                        ObjectBuilder::new()
                            .schema_type(SchemaType::Object)
                            .additional_properties(Some(RefOr::Ref(
                                utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
                            )))
                            .build(),
                    )))
                    .build(),
            ))
            .build(),
    )
}

/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
    /// The message content
    pub content: Option<MessageContent>,
    /// The role of the message sender ("user", "assistant", "system", "tool", etc.)
    pub role: String,
    pub name: Option<String>,
}

/// Stop token configuration for generation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
    /// Multiple possible stop sequences
    Multi(Vec<String>),
    /// Single stop sequence
    Single(String),
}

/// Default value helper
pub fn default_false() -> bool {
    false
}

/// Default value helper
pub fn default_1usize() -> usize {
    1
}

/// Default value helper
pub fn default_model() -> String {
    "default".to_string()
}

/// Chat completion request following OpenAI's specification
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
    #[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
    pub messages: Vec<Message>,
    #[schema(example = "gemma-3-1b-it")]
    #[serde(default = "default_model")]
    pub model: String,
    #[serde(default = "default_false")]
    #[schema(example = false)]
    pub logprobs: bool,
    #[schema(example = 256)]
    pub max_tokens: Option<usize>,
    #[serde(rename = "n")]
    #[serde(default = "default_1usize")]
    #[schema(example = 1)]
    pub n_choices: usize,
    #[schema(example = 0.7)]
    pub temperature: Option<f64>,
    #[schema(example = 0.9)]
    pub top_p: Option<f64>,
    #[schema(example = false)]
    pub stream: Option<bool>,
}

/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: Usage,
}

/// Token usage information
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}
@@ -1,128 +0,0 @@
use axum::{
    extract::State,
    http::StatusCode,
    routing::{get, post},
    Json, Router,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
use crate::text_generation::TextGeneration;
use either::Either;

// Application state shared between handlers
#[derive(Clone)]
pub struct AppState {
    pub text_generation: Arc<Mutex<TextGeneration>>,
    pub model_id: String,
}

// Chat completions endpoint handler
pub async fn chat_completions(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
    let mut prompt = String::new();

    // Convert messages to a prompt string
    for message in &request.messages {
        let role = &message.role;
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(), // Handle complex content if needed
            },
            None => "".to_string(),
        };

        // Format based on role
        match role.as_str() {
            "system" => prompt.push_str(&format!("System: {}\n", content)),
            "user" => prompt.push_str(&format!("User: {}\n", content)),
            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
        }
    }

    // Add the assistant prefix for the response
    prompt.push_str("Assistant: ");

    // Capture the output
    let mut output = Vec::new();
    {
        let mut text_gen = state.text_generation.lock().await;

        // Buffer to capture the output
        let mut buffer = Vec::new();

        // Run text generation
        let max_tokens = request.max_tokens.unwrap_or(1000);
        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);

        if let Err(e) = result {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
                        "type": "unsupported_api"
                    }
                })),
            ));
        }

        // Convert buffer to string
        if let Ok(text) = String::from_utf8(buffer) {
            output.push(text);
        }
    }

    // Create response
    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        model: request.model,
        choices: vec![ChatCompletionChoice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: Some(MessageContent(Either::Left(output.join("")))),
                name: None,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens: prompt.len() / 4, // Rough estimate
            completion_tokens: output.join("").len() / 4, // Rough estimate
            total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
        },
    };

    // Return the response as JSON
    Ok(Json(response))
}

// Create the router with the chat completions endpoint
pub fn create_router(app_state: AppState) -> Router {
    // CORS layer to allow requests from any origin
    let cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_credentials(true)
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    Router::new()
        // OpenAI compatible endpoints
        .route("/v1/chat/completions", post(chat_completions))
        // Add more endpoints as needed
        .layer(cors)
        .with_state(app_state)
}
@@ -1,352 +0,0 @@
use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
use std::io::Write;

use crate::model::Model;
use crate::token_output_stream::TokenOutputStream;

pub struct TextGeneration {
    model: Model,
    device: Device,
    // CPU device for fallback when operations are unsupported on primary device
    cpu_device: Option<Device>,
    // Flag to indicate if we should try to use the primary device first
    try_primary_device: bool,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);

        // Initialize CPU device only if the primary device is not already CPU
        let (cpu_device, try_primary_device) = if device.is_cpu() {
            // If already on CPU, no need for a fallback device
            (None, false)
        } else {
            // Store CPU device for fallback and set flag to try primary device first
            (Some(Device::Cpu), true)
        };

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
            cpu_device,
            try_primary_device,
        }
    }

    // Helper method for model execution with fallback to CPU for unsupported operations
    fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
        // If we're not trying primary device anymore, go straight to CPU if available
        if !self.try_primary_device {
            if let Some(cpu_device) = &self.cpu_device {
                let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
                let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;
                return cpu_result.to_device(&self.device).map_err(E::msg);
            } else {
                // No CPU fallback, use primary device
                return self.model.forward(input, start_pos).map_err(E::msg);
            }
        }

        // Try running on the primary device first
        match self.model.forward(input, start_pos) {
            Ok(result) => Ok(result),
            Err(err) => {
                // Convert to string to check for unsupported operation
                let err_string = err.to_string();

                // Check if the error is about unsupported operations
                if (err_string.contains("no metal implementation for") ||
                    err_string.contains("no cuda implementation for")) &&
                    self.cpu_device.is_some() {

                    // Extract operation name for better logging
                    let op_name = if let Some(idx) = err_string.find("for ") {
                        &err_string[(idx + 4)..]
                    } else {
                        "an operation"
                    };

                    // Log the fallback
                    println!("Warning: The primary device does not support {}. Falling back to CPU.", op_name);

                    // Move input to CPU and try again
                    let cpu_device = self.cpu_device.as_ref().unwrap();
                    let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
                    let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;

                    // Don't try primary device for future operations
                    self.try_primary_device = false;
                    println!("Successfully executed on CPU. Will use CPU for subsequent operations.");

                    // Move result back to original device
                    cpu_result.to_device(&self.device).map_err(E::msg)
                } else {
                    // Not an unsupported operation error or no CPU fallback
                    Err(E::msg(err))
                }
            }
        }
    }

    // Run text generation and print to stdout
    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                println!(
                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
                );
                eos_token
            }
        };

        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            // Use execute_with_fallback instead of model.forward
            let logits = self.execute_with_fallback(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                // Manual implementation of repeat penalty to avoid type conflicts
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in &tokens[start_at..] {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        logits_vec[token_id] = sign * score / self.repeat_penalty;
                    }
                }

                // Create a new tensor with the modified logits
                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                new_logits.reshape(shape)?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }

    // Run text generation and write to a buffer
    pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        // Write prompt tokens to output
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                write!(output, "{}", t)?;
            }
        }

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
                eos_token
            }
        };

        // Determine if we're using a Model3 (gemma-3) variant
        let is_model3 = match &self.model {
            Model::V3(_) => true,
            _ => false,
        };

        // For Model3, we need to use a different approach
        if is_model3 {
            // For gemma-3 models, we'll generate one token at a time with the full context
            let start_gen = std::time::Instant::now();

            // Initial generation with the full prompt
            let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
            // Use execute_with_fallback instead of model.forward
            let mut logits = self.execute_with_fallback(&input, 0)?;
            logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

            for _ in 0..sample_len {
                // Apply repeat penalty if needed
                let current_logits = if self.repeat_penalty == 1. {
                    logits.clone()
                } else {
                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                    // Manual implementation of repeat penalty to avoid type conflicts
                    let mut logits_vec = logits.to_vec1::<f32>()?;

                    for &token_id in &tokens[start_at..] {
                        let token_id = token_id as usize;
                        if token_id < logits_vec.len() {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            logits_vec[token_id] = sign * score / self.repeat_penalty;
                        }
                    }

                    // Create a new tensor with the modified logits
                    let device = logits.device().clone();
                    let shape = logits.shape().clone();
                    let new_logits = Tensor::new(&logits_vec[..], &device)?;
                    new_logits.reshape(shape)?
                };

                let next_token = self.logits_processor.sample(&current_logits)?;
                tokens.push(next_token);
                generated_tokens += 1;

                if next_token == eos_token || next_token == eot_token {
                    break;
                }

                if let Some(t) = self.tokenizer.next_token(next_token)? {
                    write!(output, "{}", t)?;
                }

                // For the next iteration, just use the new token
                let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
                // Use execute_with_fallback instead of model.forward
                logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?;
                logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            }

            return Ok(());
        }

        // Standard approach for other models
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            // Use execute_with_fallback instead of model.forward
            let logits = self.execute_with_fallback(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);

                // Manual implementation of repeat penalty to avoid type conflicts
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in &tokens[start_at..] {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        logits_vec[token_id] = sign * score / self.repeat_penalty;
                    }
                }

                // Create a new tensor with the modified logits
                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                new_logits.reshape(shape)?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                write!(output, "{}", t)?;
            }
        }

        // Write any remaining tokens
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            write!(output, "{}", rest)?;
        }

        Ok(())
    }
}
@@ -1,86 +0,0 @@
use candle_core::Result;

/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}
@@ -1,167 +0,0 @@
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result, Tensor};

pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            println!(
                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
            );
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
        }
        Ok(Device::Cpu)
    }
}

pub fn load_image<P: AsRef<std::path::Path>>(
    p: P,
    resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?;
    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
    let img = match resize_longest {
        None => img,
        Some(resize_longest) => {
            let (height, width) = (img.height(), img.width());
            let resize_longest = resize_longest as u32;
            let (height, width) = if height < width {
                let h = (resize_longest * height) / width;
                (h, resize_longest)
            } else {
                let w = (resize_longest * width) / height;
                (resize_longest, w)
            };
            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
        }
    };
    let (height, width) = (img.height() as usize, img.width() as usize);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    Ok((data, initial_h, initial_w))
}

pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,
    height: usize,
) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?
        .resize_to_fill(
            width as u32,
            height as u32,
            image::imageops::FilterType::Triangle,
        );
    let img = img.to_rgb8();
    let data = img.into_raw();
    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
}

/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    let p = p.as_ref();
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle_core::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => candle_core::bail!("error saving image {p:?}"),
        };
    image.save(p).map_err(candle_core::Error::wrap)?;
    Ok(())
}

pub fn save_image_resize<P: AsRef<std::path::Path>>(
    img: &Tensor,
    p: P,
    h: usize,
    w: usize,
) -> Result<()> {
    let p = p.as_ref();
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle_core::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => candle_core::bail!("error saving image {p:?}"),
        };
    let image = image::DynamicImage::from(image);
    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
    image.save(p).map_err(candle_core::Error::wrap)?;
    Ok(())
}

/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
    repo: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
    let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: serde_json::Value =
        serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => candle_core::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file.to_string());
        }
    }
    let safetensors_files = safetensors_files
        .iter()
        .map(|v| repo.get(v).map_err(candle_core::Error::wrap))
        .collect::<Result<Vec<_>>>()?;
    Ok(safetensors_files)
}

pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
    path: P,
    json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
    let path = path.as_ref();
    let jsfile = std::fs::File::open(path.join(json_file))?;
    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => candle_core::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file);
        }
    }
    let safetensors_files: Vec<_> = safetensors_files
        .into_iter()
        .map(|v| path.join(v))
        .collect();
    Ok(safetensors_files)
}
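A hedged usage sketch for hub_load_safetensors above, assuming the function is in scope; the repo id and the model.safetensors.index.json filename are the common Hugging Face conventions, not anything this crate pins down.

use hf_hub::api::sync::Api;

fn fetch_sharded_weights() -> candle_core::Result<Vec<std::path::PathBuf>> {
    let api = Api::new().map_err(candle_core::Error::wrap)?;
    let repo = api.model("google/gemma-2b".to_string());
    // Resolves (downloading or reusing the local cache) every shard listed in the index.
    hub_load_safetensors(&repo, "model.safetensors.index.json")
}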
@@ -1,3 +0,0 @@
#!/usr/bin/env sh

cargo run -p legacy-inference-engine --release -- --prompt 'Name the 16th President of the USA.' --which 3-1b-it
@@ -1,67 +0,0 @@
use legacy_inference_engine::model::{Model, Which};

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_which_to_model_id() {
        // Test a few representative model variants
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
        assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
        assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
        assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
        assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
    }

    #[test]
    fn test_which_is_instruct_model() {
        // Test base models (should return false)
        assert!(!Which::Base2B.is_instruct_model());
        assert!(!Which::Base7B.is_instruct_model());
        assert!(!Which::CodeBase2B.is_instruct_model());
        assert!(!Which::CodeBase7B.is_instruct_model());
        assert!(!Which::BaseV2_2B.is_instruct_model());
        assert!(!Which::BaseV2_9B.is_instruct_model());
        assert!(!Which::BaseV3_1B.is_instruct_model());

        // Test instruct models (should return true)
        assert!(Which::Instruct2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
        assert!(Which::InstructV1_1_2B.is_instruct_model());
        assert!(Which::InstructV1_1_7B.is_instruct_model());
        assert!(Which::CodeInstruct2B.is_instruct_model());
        assert!(Which::CodeInstruct7B.is_instruct_model());
        assert!(Which::InstructV2_2B.is_instruct_model());
        assert!(Which::InstructV2_9B.is_instruct_model());
        assert!(Which::InstructV3_1B.is_instruct_model());
    }

    #[test]
    fn test_which_is_v3_model() {
        // Test non-v3 models (should return false)
        assert!(!Which::Base2B.is_v3_model());
        assert!(!Which::Base7B.is_v3_model());
        assert!(!Which::Instruct2B.is_v3_model());
        assert!(!Which::Instruct7B.is_v3_model());
        assert!(!Which::InstructV1_1_2B.is_v3_model());
        assert!(!Which::InstructV1_1_7B.is_v3_model());
        assert!(!Which::CodeBase2B.is_v3_model());
        assert!(!Which::CodeBase7B.is_v3_model());
        assert!(!Which::CodeInstruct2B.is_v3_model());
        assert!(!Which::CodeInstruct7B.is_v3_model());
        assert!(!Which::BaseV2_2B.is_v3_model());
        assert!(!Which::InstructV2_2B.is_v3_model());
        assert!(!Which::BaseV2_9B.is_v3_model());
        assert!(!Which::InstructV2_9B.is_v3_model());

        // Test v3 models (should return true)
        assert!(Which::BaseV3_1B.is_v3_model());
        assert!(Which::InstructV3_1B.is_v3_model());
    }

    // Note: Testing the Model enum's forward method would require creating actual model instances,
    // which is complex and would require loading model weights. This is better suited for
    // integration tests or mocking the models.
}
@@ -1,101 +0,0 @@
use anyhow::Result;
use candle_transformers::generation::LogitsProcessor;
use legacy_inference_engine::model::Which;
use legacy_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;

#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing
    fn create_test_tokenizer() -> Result<Tokenizer> {
        // Create a simple tokenizer from the pretrained model
        // This uses the tokenizer from the Hugging Face hub
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    // Test the Which enum's to_model_id method
    #[test]
    fn test_which_model_id() {
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
    }

    // Test the Which enum's is_instruct_model method
    #[test]
    fn test_which_is_instruct() {
        assert!(!Which::Base2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
    }

    // Test the Which enum's is_v3_model method
    #[test]
    fn test_which_is_v3() {
        assert!(!Which::Base2B.is_v3_model());
        assert!(Which::BaseV3_1B.is_v3_model());
    }

    // Test the TokenOutputStream functionality
    #[test]
    fn test_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Test encoding and decoding
        let text = "Hello, world!";
        let encoded = token_stream.tokenizer().encode(text, true).unwrap();
        let token_ids = encoded.get_ids();

        // Add tokens one by one
        for &token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all and check
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), text);

        Ok(())
    }

    // Test the LogitsProcessor
    #[test]
    fn test_logits_processor() -> Result<()> {
        // Create a LogitsProcessor with default settings
        let seed = 42;
        let temp = Some(0.8);
        let top_p = Some(0.9);
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);

        // Create a simple logits tensor
        // In a real test, we would create a tensor with known values and verify
        // that sampling produces expected results

        // For now, we'll just verify that the LogitsProcessor can be created
        assert!(true);
        Ok(())
    }

    // Test the TextGeneration constructor
    #[test]
    fn test_text_generation_constructor() -> Result<()> {
        // We can't easily create a Model instance for testing,
        // but we can test that the constructor compiles and the types are correct

        // In a real test with a mock Model, we would:
        // 1. Create a mock model
        // 2. Create a tokenizer
        // 3. Call TextGeneration::new
        // 4. Verify the properties of the created instance

        // For now, we'll just verify that the code compiles
        assert!(true);
        Ok(())
    }

    // Note: Testing the actual text generation functionality would require
    // integration tests with real models, which is beyond the scope of these unit tests.
    // The tests above focus on the components that can be tested in isolation.
}
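The placeholder in test_logits_processor above never exercises sampling. A minimal sketch of the deterministic check its comment describes, assuming only the candle_core and candle_transformers APIs already used in this file: with the temperature left as None, LogitsProcessor falls back to greedy argmax, so the index of the largest logit must come back.

use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

#[test]
fn test_greedy_sampling_picks_argmax() -> anyhow::Result<()> {
    // Known logits: index 2 holds the largest value.
    let logits = Tensor::new(&[0.1f32, 0.2, 3.0, 0.4], &Device::Cpu)?;
    // A `None` temperature disables sampling, so `sample` reduces to argmax.
    let mut processor = LogitsProcessor::new(42, None, None);
    assert_eq!(processor.sample(&logits)?, 2);
    Ok(())
}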
@@ -1,129 +0,0 @@
use legacy_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;

#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing
    fn create_test_tokenizer() -> Result<Tokenizer> {
        // Create a simple tokenizer from the pretrained model
        // This uses the tokenizer from the Hugging Face hub
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Check that the token stream was created successfully
        assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Add a token
        let token_id = token_stream.get_token("<eos>").unwrap();
        token_stream.next_token(token_id)?;

        // Clear the stream
        token_stream.clear();

        // Check that the stream is empty by trying to decode all
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded, "");

        Ok(())
    }

    #[test]
    fn test_get_token() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get a token that should exist
        let eos_token = token_stream.get_token("<eos>");
        assert!(eos_token.is_some());

        // Get a token that shouldn't exist
        let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
        assert!(nonexistent_token.is_none());

        Ok(())
    }

    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        let mut output = String::new();
        for &token_id in token_ids {
            if let Some(text) = token_stream.next_token(token_id)? {
                output.push_str(&text);
            }
        }

        // Get any remaining text
        if let Some(rest) = token_stream.decode_rest()? {
            output.push_str(&rest);
        }

        // Check the output
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_decode_all() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        for &token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all
        let decoded = token_stream.decode_all()?;

        // Check the output
        assert_eq!(decoded.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_into_inner() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get the inner tokenizer
        let inner_tokenizer = token_stream.into_inner();

        // Check that the inner tokenizer works
        let encoded = inner_tokenizer.encode("Test", true).unwrap();
        assert!(encoded.get_ids().len() > 0);

        Ok(())
    }
}