mirror of https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
```diff
@@ -3,16 +3,14 @@ name = "gemma-runner"
 version = "0.1.0"
 edition = "2021"
 
 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git" }
 candle-examples = { git = "https://github.com/huggingface/candle.git" }
 
-[target.'cfg(target_os = "macos")'.dependencies]
-candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 hf-hub = "0.4"
 tokenizers = "0.21"
 anyhow = "1.0"
```
```diff
@@ -22,6 +20,12 @@ tracing = "0.1"
 tracing-chrome = "0.7"
 tracing-subscriber = "0.3"
+
+[target.'cfg(target_os = "macos")'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+
 
 [features]
 default = []
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
```
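Taken together, the two manifest hunks move the Metal-enabled macOS dependency block below the base dependencies and add an opt-in `cuda` feature (`cargo build --features cuda`). Runtime backend selection usually mirrors that split; below is a minimal sketch using candle's device helpers (`pick_device` is illustrative and not part of this commit):

```rust
use candle_core::{utils, Device, Result};

/// Pick the best available backend: CUDA when built with the `cuda`
/// feature and a GPU is visible, Metal on macOS, otherwise CPU.
fn pick_device() -> Result<Device> {
    if utils::cuda_is_available() {
        Device::new_cuda(0)
    } else if utils::metal_is_available() {
        Device::new_metal(0)
    } else {
        Ok(Device::Cpu)
    }
}
```

On macOS the `[target.'cfg(target_os = "macos")'.dependencies]` table applies automatically, so the Metal-featured candle crates need no feature flag at build time.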
```diff
@@ -4,10 +4,10 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use clap::ValueEnum;
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
+use clap::ValueEnum;
 
 // Removed gemma_cli import as it's not needed for the API
 use candle_core::{utils, DType, Device, Tensor};
```
```diff
@@ -119,7 +119,12 @@ impl TextGeneration {
 
     /// Stream-only generation: sends freshly generated token strings over `tx`.
     /// (Does not send the prompt tokens; only newly generated model tokens.)
-    fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
+    fn run_stream(
+        &mut self,
+        prompt: &str,
+        sample_len: usize,
+        tx: Sender<Result<String>>,
+    ) -> Result<()> {
         self.tokenizer.clear();
 
         // Encode prompt (context only; do not emit prompt tokens to the stream).
```
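The reformatted `run_stream` signature makes the streaming contract easier to see: tokens go out through `tx` as they are generated, and `run_gemma_api` (next hunk) exposes the matching `Receiver`. A hedged consumer sketch, assuming a `std::sync::mpsc`-style channel and a ready-made `GemmaInferenceConfig` named `cfg`:

```rust
// Sketch only: building `cfg` is omitted; error handling is minimal.
let rx = run_gemma_api(cfg)?;
for token in rx {
    match token {
        // Print each freshly generated token as it arrives.
        Ok(text) => print!("{text}"),
        Err(e) => {
            eprintln!("generation error: {e}");
            break;
        }
    }
}
```

Iterating the receiver blocks until the generation thread drops its `Sender`, so the loop ends naturally when `sample_len` tokens have been produced.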
```diff
@@ -303,7 +308,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
             WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
         }
-            .to_string()
+        .to_string()
     });
 
     println!("Loading model: {}", &model_id);
```
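The `.to_string()` re-indent above is purely cosmetic: the call is chained onto the whole `match` expression, converting whichever `&str` arm was selected into the owned `String` that `unwrap_or_else` must return. In context (the `cfg.which` field name and the catch-all arm are assumptions, since the diff shows only two arms):

```rust
let model_id = cfg.model_id.unwrap_or_else(|| {
    match cfg.which {
        WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
        WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
        // Remaining variants elided in this hunk; placeholder arm.
        _ => "google/gemma-2b",
    }
    .to_string()
});
```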
```diff
@@ -337,7 +342,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
             Model::V1(model)
         }
-        WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
+        WhichModel::BaseV2_2B
+        | WhichModel::InstructV2_2B
+        | WhichModel::BaseV2_9B
+        | WhichModel::InstructV2_9B => {
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
```
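The `Model::V1(..)` and `Model::V2(..)` constructors imply a small wrapper enum dispatching across Gemma generations. Its definition is not in this diff, so the following is an assumption consistent with the `gemma`/`gemma2`/`gemma3` imports at the top of the file:

```rust
// Hypothetical wrapper over the three candle_transformers Gemma models.
enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}
```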
```diff
@@ -1,6 +1,6 @@
-use std::io::Write;
-use clap::Parser;
 use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
+use clap::Parser;
+use std::io::Write;
 
 #[derive(Parser, Debug)]
 #[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
```
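The import reorder feeds the derive-based CLI declared immediately below it. A minimal standalone sketch of the same `#[command]` attributes (the `prompt` flag is hypothetical, added only to make the example self-contained):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
struct Args {
    /// Prompt to run through the model (hypothetical example flag).
    #[arg(long, default_value = "Hello")]
    prompt: String,
}

fn main() {
    let args = Args::parse();
    println!("prompt = {}", args.prompt);
}
```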
```diff
@@ -94,4 +94,4 @@ pub fn run_cli() -> anyhow::Result<()> {
         }
     }
     Ok(())
-}
\ No newline at end of file
+}
```
```diff
@@ -2,8 +2,8 @@
 extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod gemma_cli;
 mod gemma_api;
+mod gemma_cli;
 
 use anyhow::Error;
 use clap::{Parser, ValueEnum};
```
```diff
@@ -14,4 +14,4 @@ use std::io::Write;
 /// just a placeholder, not used for anything
 fn main() -> std::result::Result<(), Error> {
     run_cli()
-}
\ No newline at end of file
+}
```