cleanup, add ci

geoffsee
2025-08-31 10:31:07 -04:00
parent 419e1c2ea7
commit f5d2a85f2e
42 changed files with 1740 additions and 705 deletions

View File

@@ -18,6 +18,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", features = ["
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
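
On macOS the Metal-enabled candle crates stay non-optional, while the not(target_os = "macos") target declares optional CUDA builds of the same crates that the cuda entry under [features] switches on (cargo build --features cuda). A minimal sketch of the runtime fallback this layout supports, assuming anyhow and candle-core as dependencies; the helper name pick_device is illustrative and not part of this commit, though it mirrors the shape of the device() helper touched later in the diff:

    // Pick the best available backend at runtime.
    use candle_core::{utils, Device};

    fn pick_device(force_cpu: bool) -> anyhow::Result<Device> {
        if force_cpu {
            return Ok(Device::Cpu);
        }
        if utils::cuda_is_available() {
            // Only succeeds when the crate was built with --features cuda.
            return Ok(Device::new_cuda(0)?);
        }
        if utils::metal_is_available() {
            // The Metal backend is always compiled in on macOS per the block above.
            return Ok(Device::new_metal(0)?);
        }
        Ok(Device::Cpu)
    }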

View File

@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

View File

@@ -1,14 +1,14 @@
+use crate::EOS_TOKEN;
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
-use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
+use candle_transformers::models::llama::{Llama, LlamaConfig};
+use clap::ValueEnum;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
-use clap::ValueEnum;
-use crate::{EOS_TOKEN};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
max_tokens: 512,
// Performance flags
-no_kv_cache: false, // keep cache ON for speed
-use_flash_attn: true, // great speed boost if supported
+no_kv_cache: false, // keep cache ON for speed
+use_flash_attn: true, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
}
}
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
}
}
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
-.to_string()
+.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
Ok(rx)
}
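
run_llama_inference resolves the model id from the WhichModel match in this hunk (or an explicit model_id) and returns an mpsc Receiver, the Ok(rx) at the end. A rough consumption sketch, assuming the re-exports from lib.rs are in scope, that the config is passed by value as the only argument, and that each channel item is an anyhow::Result<String> token chunk; the prompt field name and the channel item type are assumptions, while model_id (an Option<String> per the unwrap in this hunk) and the defaults come from the diff:

    fn stream_demo() -> anyhow::Result<()> {
        let cfg = LlamaInferenceConfig {
            prompt: "Explain KV caching in one sentence.".to_string(), // field name assumed
            model_id: Some("TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string()), // repo listed in this hunk
            ..Default::default() // bf16, flash attention on, 512 max tokens
        };
        // Assumed: the receiver yields anyhow::Result<String> chunks as they decode.
        let rx = run_llama_inference(cfg)?;
        for chunk in rx {
            print!("{}", chunk?);
        }
        Ok(())
    }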

View File

@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
}
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
@@ -106,4 +105,4 @@ pub fn run_cli() -> anyhow::Result<()> {
}
}
Ok(())
-}
+}
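
run_cli still builds the engine config through the impl Into<LlamaInferenceConfig> for Args block and the let cfg = args.into() call. For reference, the same conversion written as a From impl (the form clippy's from_over_into lint prefers, which supplies Into for free) would look roughly like this sketch; the Args field names are assumptions, not taken from this diff:

    impl From<Args> for LlamaInferenceConfig {
        fn from(args: Args) -> Self {
            LlamaInferenceConfig {
                prompt: args.prompt,         // field names assumed for illustration
                max_tokens: args.max_tokens,
                ..Default::default()
            }
        }
    }

The args.into() call in run_cli would keep working unchanged, since From provides the matching Into implementation.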

View File

@@ -2,8 +2,8 @@
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-mod llama_cli;
mod llama_api;
+mod llama_cli;
use anyhow::Result;
use clap::{Parser, ValueEnum};
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> {
run_cli()
-}
+}