Mirror of https://github.com/geoffsee/predict-otron-9001.git
Synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
@@ -18,6 +18,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 
+[target.'cfg(not(target_os = "macos"))'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
+candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
+
 [features]
 default = []
+cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
-
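
The hunk above splits the GPU backends by platform: the Metal builds of candle stay unconditional on macOS, while the CUDA builds become optional dependencies gated behind the new cuda feature. A minimal sketch (not code from this repo) of how a caller can pick a device under that layout, assuming candle-core's Device constructors and the crate-level cuda feature declared above:

use candle_core::Device;

// Sketch: choose an accelerator matching the platform-split dependencies.
// All three constructors exist unconditionally in candle-core; the CUDA and
// Metal ones simply error at runtime if that backend wasn't compiled in.
fn pick_device() -> candle_core::Result<Device> {
    if cfg!(target_os = "macos") {
        Device::new_metal(0)
    } else if cfg!(feature = "cuda") {
        Device::new_cuda(0)
    } else {
        Ok(Device::Cpu)
    }
}
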
@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
 
 // Re-export constants and types that might be needed
 pub const EOS_TOKEN: &str = "</s>";
-

@@ -1,14 +1,14 @@
+use crate::EOS_TOKEN;
 use anyhow::{bail, Error as E};
 use candle_core::{utils, DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
-use candle_transformers::models::llama::{Llama, LlamaConfig};
 use candle_transformers::models::llama as model;
+use candle_transformers::models::llama::{Llama, LlamaConfig};
+use clap::ValueEnum;
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
 use std::sync::mpsc::{self, Receiver};
-use clap::ValueEnum;
-use crate::{EOS_TOKEN};
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
 pub enum WhichModel {
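
WhichModel derives clap's ValueEnum together with Default, so the enum doubles as the set of accepted model values on the CLI. A hedged sketch of that wiring: the two variants shown are the ones visible in a later hunk, the full enum is truncated here, and which variant carries #[default] is an assumption:

use clap::{Parser, ValueEnum};

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
#[allow(non_camel_case_types)] // the underscored variant names trip this lint
enum WhichModel {
    #[default]
    SmolLM2_1_7BInstruct,
    TinyLlama1_1BChat,
}

#[derive(Parser)]
struct Args {
    // clap derives kebab-case value names from the variants.
    #[arg(long, value_enum, default_value_t = WhichModel::default())]
    which: WhichModel,
}

fn main() {
    println!("{:?}", Args::parse().which);
}
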
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
             max_tokens: 512,
 
             // Performance flags
-            no_kv_cache: false, // keep cache ON for speed
-            use_flash_attn: true, // great speed boost if supported
+            no_kv_cache: false, // keep cache ON for speed
+            use_flash_attn: true, // great speed boost if supported
 
             // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
             dtype: Some("bf16".to_string()),
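
The defaults keep the KV cache on, enable flash attention, and ask for bf16. A sketch of how a textual dtype like that is typically mapped onto candle's DType later in the pipeline (the actual parsing code is outside this hunk):

use candle_core::DType;

fn parse_dtype(spec: Option<&str>) -> anyhow::Result<DType> {
    match spec {
        Some("f16") => Ok(DType::F16),
        // bf16 as the fallback mirrors the config comment above.
        Some("bf16") | None => Ok(DType::BF16),
        Some("f32") => Ok(DType::F32),
        Some(other) => anyhow::bail!("unsupported dtype: {other}"),
    }
}
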
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
     }
 }
-
-
 
 fn device(cpu: bool) -> anyhow::Result<Device> {
     if cpu {
         Ok(Device::Cpu)
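
device() honors an explicit CPU request first; the rest of its body falls outside this hunk. In candle examples this helper conventionally probes CUDA, then Metal, then falls back to CPU, which is presumably why utils is imported above. A sketch of that conventional shape; only the cpu branch is actually visible in the diff:

use candle_core::{utils, Device};

fn device(cpu: bool) -> anyhow::Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        // No accelerator available; fall back to CPU.
        Ok(Device::Cpu)
    }
}
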
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
     }
 }
-
 
 fn hub_load_safetensors(
     api: &hf_hub::api::sync::ApiRepo,
     json_file: &str,
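
hub_load_safetensors takes the repo handle plus the name of an index JSON, which in candle's examples means resolving a sharded checkpoint: fetch model.safetensors.index.json, collect the shard filenames from its weight_map, and download each shard. A sketch of that well-known pattern (it assumes a serde_json dependency; the function body is cut off in this hunk):

use std::collections::HashSet;
use std::path::PathBuf;

fn hub_load_safetensors(
    api: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> anyhow::Result<Vec<PathBuf>> {
    // Download (or locate in cache) the index file and parse it.
    let json_path = api.get(json_file)?;
    let json: serde_json::Value = serde_json::from_reader(std::fs::File::open(json_path)?)?;
    // weight_map maps tensor names to shard filenames; dedupe the shards.
    let mut shards = HashSet::new();
    if let Some(weight_map) = json["weight_map"].as_object() {
        for file in weight_map.values() {
            if let Some(name) = file.as_str() {
                shards.insert(name.to_string());
            }
        }
    }
    // Fetch every referenced shard and return the local paths.
    shards.into_iter().map(|name| Ok(api.get(&name)?)).collect()
}
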
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
             WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
             WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         }
-        .to_string()
+        .to_string()
     });
     println!("Loading model: {}", model_id);
     let revision = cfg.revision.clone().unwrap_or("main".to_string());
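
Each WhichModel variant maps to a Hugging Face repo id, which is then combined with the revision (defaulting to "main") to address the model on the Hub. A sketch of that resolution step using the hf-hub sync API imported earlier:

use hf_hub::api::sync::{Api, ApiRepo};
use hf_hub::{Repo, RepoType};

// Sketch: turn the resolved model id and revision into a repo handle that
// later downloads (tokenizer, config, weights) can run against.
fn repo_for(model_id: String, revision: String) -> anyhow::Result<ApiRepo> {
    let api = Api::new()?;
    Ok(api.repo(Repo::with_revision(model_id, RepoType::Model, revision)))
}
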
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
 
     Ok(rx)
 }
-
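
run_llama_inference hands back the receiving end of an mpsc channel, so callers can stream tokens while generation runs on a worker thread. A hedged usage sketch: the channel's item type is assumed to be anyhow::Result<String>, which is not visible in this diff:

fn main() -> anyhow::Result<()> {
    let cfg = LlamaInferenceConfig::default();
    let rx = run_llama_inference(cfg)?;
    // The iterator blocks on each token and ends when the sender hangs up.
    for token in rx {
        print!("{}", token?);
    }
    Ok(())
}
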
@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
     }
 }
-
 
 pub fn run_cli() -> anyhow::Result<()> {
     let args = Args::parse();
     let cfg = args.into();

@@ -106,4 +105,4 @@ pub fn run_cli() -> anyhow::Result<()> {
         }
     }
     Ok(())
-}
+}
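
One nit the cleanup leaves in place: impl Into<LlamaInferenceConfig> for Args. Implementing From in the opposite direction yields the same args.into() call through the standard blanket impl and is the direction std recommends. A sketch with the field mapping elided, since Args' definition is outside the hunk:

impl From<Args> for LlamaInferenceConfig {
    fn from(_args: Args) -> Self {
        // Map CLI flags onto config fields here.
        LlamaInferenceConfig::default()
    }
}
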
@@ -2,8 +2,8 @@
 extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod llama_cli;
 mod llama_api;
+mod llama_cli;
 
 use anyhow::Result;
 use clap::{Parser, ValueEnum};

@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
 
 const EOS_TOKEN: &str = "</s>";
 
 fn main() -> Result<()> {
     run_cli()
 }
-}