chat client only displays available models

commit 2deecb5e51 (parent 545e0c9831)
Author: geoffsee
Date:   2025-09-01 22:29:54 -04:00

20 changed files with 3314 additions and 484 deletions

@@ -10,15 +10,15 @@ edition = "2021"
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git" }
-candle-examples = { git = "https://github.com/huggingface/candle.git" }
 hf-hub = "0.4"
-tokenizers = "0.21"
+tokenizers = "0.22.0"
 anyhow = "1.0"
 clap = { version = "4.0", features = ["derive", "string"] }
 serde_json = "1.0"
 tracing = "0.1"
 tracing-chrome = "0.7"
 tracing-subscriber = "0.3"
+utils = {path = "../utils"}

 [target.'cfg(target_os = "macos")'.dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
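
The dropped candle-examples dependency is replaced by the local utils crate, which the next file imports for hub_load_safetensors and token_output_stream::TokenOutputStream. That crate is among the 20 changed files but is not shown in this excerpt; below is a plausible sketch of its streaming-decode helper, modeled on the candle-examples implementation it replaces (names inferred from the imports, not taken from the actual ../utils source):

    /// Incremental detokenizer in the style of candle-examples' TokenOutputStream.
    /// It buffers sampled token ids and only emits text once the decoded string
    /// grows and ends on an alphanumeric char, so partially merged tokens and
    /// incomplete UTF-8 sequences never reach the caller mid-stream.
    pub struct TokenOutputStream {
        tokenizer: tokenizers::Tokenizer,
        tokens: Vec<u32>,
        prev_index: usize,
        current_index: usize,
    }

    impl TokenOutputStream {
        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
            Self { tokenizer, tokens: Vec::new(), prev_index: 0, current_index: 0 }
        }

        fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
            self.tokenizer
                .decode(tokens, /* skip_special_tokens = */ true)
                .map_err(|e| anyhow::anyhow!("cannot decode: {e}"))
        }

        /// Feed one sampled token id; returns any newly completed text.
        pub fn next_token(&mut self, token: u32) -> anyhow::Result<Option<String>> {
            let prev_text = if self.tokens.is_empty() {
                String::new()
            } else {
                self.decode(&self.tokens[self.prev_index..self.current_index])?
            };
            self.tokens.push(token);
            let text = self.decode(&self.tokens[self.prev_index..])?;
            if text.len() > prev_text.len()
                && text.chars().last().is_some_and(|c| c.is_alphanumeric())
            {
                let (_, fresh) = text.split_at(prev_text.len());
                self.prev_index = self.current_index;
                self.current_index = self.tokens.len();
                Ok(Some(fresh.to_string()))
            } else {
                Ok(None)
            }
        }
    }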

@@ -10,16 +10,17 @@ use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
 use clap::ValueEnum;
 // Removed gemma_cli import as it's not needed for the API
-use candle_core::{utils, DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
-use tokenizers::Tokenizer;
 use std::sync::mpsc::{self, Receiver, Sender};
 use std::thread;
+use tokenizers::Tokenizer;
+use utils::hub_load_safetensors;
+use utils::token_output_stream::TokenOutputStream;

 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 pub enum WhichModel {
@@ -85,9 +86,9 @@ pub struct TextGeneration {
 fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
-    } else if utils::cuda_is_available() {
+    } else if candle_core::utils::cuda_is_available() {
         Ok(Device::new_cuda(0)?)
-    } else if utils::metal_is_available() {
+    } else if candle_core::utils::metal_is_available() {
         Ok(Device::new_metal(0)?)
     } else {
         Ok(Device::Cpu)
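
The fully qualified candle_core::utils:: calls here (and in the hunks below) are most plausibly forced by the new dependency: with a crate named utils in the extern prelude, keeping utils inside the use candle_core::{...} list would make later utils:: paths ambiguous (rustc E0659). This is an inference from the diff, not stated in the commit; a minimal sketch of the resolved form:

    use candle_core::Device;
    // Local crate from Cargo.toml (`utils = {path = "../utils"}`); importing
    // `candle_core::utils` alongside it would clash on the name `utils`.
    use utils::hub_load_safetensors;

    fn pick_device() -> candle_core::Result<Device> {
        // Fully qualified, so no bare `utils` name is left to collide with.
        if candle_core::utils::cuda_is_available() {
            Device::new_cuda(0)
        } else if candle_core::utils::metal_is_available() {
            Device::new_metal(0)
        } else {
            Ok(Device::Cpu)
        }
    }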
@@ -98,7 +99,7 @@ impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
     fn new(
         model: Model,
-        tokenizer: Tokenizer,
+        tokenizer: tokenizers::Tokenizer,
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
@@ -262,10 +263,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        utils::with_avx(),
-        utils::with_neon(),
-        utils::with_simd128(),
-        utils::with_f16c()
+        candle_core::utils::with_avx(),
+        candle_core::utils::with_neon(),
+        candle_core::utils::with_simd128(),
+        candle_core::utils::with_f16c()
     );

     let device = device(cfg.cpu)?;
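
For orientation only (this pattern is not part of the diff): in candle-based runners the device handle typically flows straight into weight loading. A hedged sketch of that common pattern, using upstream candle APIs; load_weights is a hypothetical name:

    use candle_core::{DType, Device};
    use candle_nn::VarBuilder;
    use std::path::PathBuf;

    /// Memory-map safetensors shards into a VarBuilder, picking BF16 on CUDA
    /// and F32 elsewhere (the usual heuristic in candle's examples).
    fn load_weights(filenames: &[PathBuf], device: &Device) -> anyhow::Result<VarBuilder<'static>> {
        let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
        // Safety: mmap of model files, same contract as candle's own examples.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(filenames, dtype, device)? };
        Ok(vb)
    }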
@@ -318,7 +319,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let config_filename = repo.get("config.json")?;

     let filenames = match cfg.model {
         WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
-        _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());