mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
chat client only displays available models
This commit is contained in:
@@ -10,15 +10,15 @@ edition = "2021"
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-examples = { git = "https://github.com/huggingface/candle.git" }
|
||||
hf-hub = "0.4"
|
||||
tokenizers = "0.21"
|
||||
tokenizers = "0.22.0"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.0", features = ["derive", "string"] }
|
||||
serde_json = "1.0"
|
||||
tracing = "0.1"
|
||||
tracing-chrome = "0.7"
|
||||
tracing-subscriber = "0.3"
|
||||
utils = {path = "../utils"}
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
|
@@ -10,16 +10,17 @@ use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use clap::ValueEnum;
|
||||
|
||||
// Removed gemma_cli import as it's not needed for the API
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use std::sync::mpsc::{self, Receiver, Sender};
|
||||
use std::thread;
|
||||
use tokenizers::Tokenizer;
|
||||
use utils::hub_load_safetensors;
|
||||
use utils::token_output_stream::TokenOutputStream;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
pub enum WhichModel {
|
||||
@@ -85,9 +86,9 @@ pub struct TextGeneration {
|
||||
fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
} else if candle_core::utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
} else if candle_core::utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
@@ -98,7 +99,7 @@ impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
@@ -262,10 +263,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
utils::with_avx(),
|
||||
utils::with_neon(),
|
||||
utils::with_simd128(),
|
||||
utils::with_f16c()
|
||||
candle_core::utils::with_avx(),
|
||||
candle_core::utils::with_neon(),
|
||||
candle_core::utils::with_simd128(),
|
||||
candle_core::utils::with_f16c()
|
||||
);
|
||||
|
||||
let device = device(cfg.cpu)?;
|
||||
@@ -318,7 +319,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("Retrieved files in {:?}", start.elapsed());
|
||||
|
||||
|
Reference in New Issue
Block a user