cleanup, add ci

This commit is contained in:
geoffsee
2025-08-31 10:31:07 -04:00
parent 419e1c2ea7
commit f5d2a85f2e
42 changed files with 1740 additions and 705 deletions

View File

@@ -1,14 +1,14 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
}
}
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
}
}
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
Ok(rx)
}