mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
fixes issue with model selection
This commit is contained in:
@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
|
||||
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
|
||||
|
||||
> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
|
||||
Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
|
||||
Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
|
||||
|
||||
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
|
||||
|
||||
|
4
bun.lock
4
bun.lock
@@ -4,7 +4,7 @@
|
||||
"": {
|
||||
"name": "predict-otron-9000",
|
||||
},
|
||||
"crates/cli/package": {
|
||||
"integration/cli/package": {
|
||||
"name": "cli",
|
||||
"dependencies": {
|
||||
"install": "^0.13.0",
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
},
|
||||
"packages": {
|
||||
"cli": ["cli@workspace:crates/cli/package"],
|
||||
"cli": ["cli@workspace:integration/cli/package"],
|
||||
|
||||
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
|
||||
|
||||
|
@@ -365,7 +365,7 @@ fn ChatPage() -> impl IntoView {
|
||||
|
||||
// State for available models and selected model
|
||||
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
|
||||
let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
|
||||
let selected_model = RwSignal::new(String::from("")); // Default model
|
||||
|
||||
// State for streaming response
|
||||
let streaming_content = RwSignal::new(String::new());
|
||||
@@ -382,6 +382,7 @@ fn ChatPage() -> impl IntoView {
|
||||
match fetch_models().await {
|
||||
Ok(models) => {
|
||||
available_models.set(models);
|
||||
selected_model.set(String::from("gemma-3-1b-it"));
|
||||
}
|
||||
Err(error) => {
|
||||
console::log_1(&format!("Failed to fetch models: {}", error).into());
|
||||
|
@@ -7,6 +7,7 @@ use axum::{
|
||||
};
|
||||
use futures_util::stream::{self, Stream};
|
||||
use std::convert::Infallible;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
@@ -20,7 +21,7 @@ use crate::openai_types::{
|
||||
use crate::Which;
|
||||
use either::Either;
|
||||
use embeddings_engine::models_list;
|
||||
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
|
||||
use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
|
||||
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
|
||||
use serde_json::Value;
|
||||
// -------------------------
|
||||
@@ -35,12 +36,13 @@ pub enum ModelType {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub model_type: ModelType,
|
||||
pub model_type: Option<ModelType>,
|
||||
pub model_id: String,
|
||||
pub gemma_config: Option<GemmaInferenceConfig>,
|
||||
pub llama_config: Option<LlamaInferenceConfig>,
|
||||
}
|
||||
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
// Configure a default model to prevent 503 errors from the chat-ui
|
||||
@@ -48,12 +50,12 @@ impl Default for AppState {
|
||||
let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
|
||||
|
||||
let gemma_config = GemmaInferenceConfig {
|
||||
model: gemma_runner::WhichModel::InstructV3_1B,
|
||||
model: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Self {
|
||||
model_type: ModelType::Gemma,
|
||||
model_type: None,
|
||||
model_id: default_model_id,
|
||||
gemma_config: Some(gemma_config),
|
||||
llama_config: None,
|
||||
@@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
|
||||
"gemma-2-9b-it" => Some(Which::InstructV2_9B),
|
||||
"gemma-3-1b" => Some(Which::BaseV3_1B),
|
||||
"gemma-3-1b-it" => Some(Which::InstructV3_1B),
|
||||
"llama-3.2-1b" => Some(Which::Llama32_1B),
|
||||
"llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
|
||||
"llama-3.2-3b" => Some(Which::Llama32_3B),
|
||||
"llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
|
||||
_ => None,
|
||||
}
|
||||
@@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
// Get streaming receiver based on model type
|
||||
let rx = if which_model.is_llama_model() {
|
||||
// Create Llama configuration dynamically
|
||||
let mut config = LlamaInferenceConfig::default();
|
||||
let llama_model = match which_model {
|
||||
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
|
||||
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
|
||||
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
|
||||
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Model {} is not a Llama model", model_id) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
};
|
||||
let mut config = LlamaInferenceConfig::new(llama_model);
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| (
|
||||
@@ -201,14 +219,35 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
))?
|
||||
} else {
|
||||
// Create Gemma configuration dynamically
|
||||
let gemma_model = if which_model.is_v3_model() {
|
||||
gemma_runner::WhichModel::InstructV3_1B
|
||||
} else {
|
||||
gemma_runner::WhichModel::InstructV3_1B // Default fallback
|
||||
let gemma_model = match which_model {
|
||||
Which::Base2B => gemma_runner::WhichModel::Base2B,
|
||||
Which::Base7B => gemma_runner::WhichModel::Base7B,
|
||||
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
|
||||
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
|
||||
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
|
||||
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
|
||||
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
|
||||
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
|
||||
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
|
||||
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
|
||||
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
|
||||
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
|
||||
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
|
||||
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
|
||||
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
|
||||
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut config = GemmaInferenceConfig {
|
||||
model: gemma_model,
|
||||
model: Some(gemma_model),
|
||||
..Default::default()
|
||||
};
|
||||
config.prompt = prompt.clone();
|
||||
@@ -348,7 +387,21 @@ async fn handle_streaming_request(
|
||||
// Get streaming receiver based on model type
|
||||
let model_rx = if which_model.is_llama_model() {
|
||||
// Create Llama configuration dynamically
|
||||
let mut config = LlamaInferenceConfig::default();
|
||||
let llama_model = match which_model {
|
||||
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
|
||||
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
|
||||
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
|
||||
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Model {} is not a Llama model", model_id) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
};
|
||||
let mut config = LlamaInferenceConfig::new(llama_model);
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_llama_inference(config) {
|
||||
@@ -364,14 +417,35 @@ async fn handle_streaming_request(
|
||||
}
|
||||
} else {
|
||||
// Create Gemma configuration dynamically
|
||||
let gemma_model = if which_model.is_v3_model() {
|
||||
gemma_runner::WhichModel::InstructV3_1B
|
||||
} else {
|
||||
gemma_runner::WhichModel::InstructV3_1B // Default fallback
|
||||
let gemma_model = match which_model {
|
||||
Which::Base2B => gemma_runner::WhichModel::Base2B,
|
||||
Which::Base7B => gemma_runner::WhichModel::Base7B,
|
||||
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
|
||||
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
|
||||
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
|
||||
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
|
||||
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
|
||||
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
|
||||
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
|
||||
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
|
||||
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
|
||||
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
|
||||
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
|
||||
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
|
||||
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
|
||||
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut config = GemmaInferenceConfig {
|
||||
model: gemma_model,
|
||||
model: Some(gemma_model),
|
||||
..Default::default()
|
||||
};
|
||||
config.prompt = prompt.clone();
|
||||
|
@@ -1,13 +1,8 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use clap::ValueEnum;
|
||||
|
||||
// Removed gemma_cli import as it's not needed for the API
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
@@ -21,8 +16,10 @@ use std::thread;
|
||||
use tokenizers::Tokenizer;
|
||||
use utils::hub_load_safetensors;
|
||||
use utils::token_output_stream::TokenOutputStream;
|
||||
use std::str::FromStr;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "gemma-2b")]
|
||||
Base2B,
|
||||
@@ -58,6 +55,56 @@ pub enum WhichModel {
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
impl FromStr for WhichModel {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"gemma-2b" => Ok(Self::Base2B),
|
||||
"gemma-7b" => Ok(Self::Base7B),
|
||||
"gemma-2b-it" => Ok(Self::Instruct2B),
|
||||
"gemma-7b-it" => Ok(Self::Instruct7B),
|
||||
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
|
||||
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
|
||||
"codegemma-2b" => Ok(Self::CodeBase2B),
|
||||
"codegemma-7b" => Ok(Self::CodeBase7B),
|
||||
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
|
||||
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
|
||||
"gemma-2-2b" => Ok(Self::BaseV2_2B),
|
||||
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
|
||||
"gemma-2-9b" => Ok(Self::BaseV2_9B),
|
||||
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
|
||||
"gemma-3-1b" => Ok(Self::BaseV3_1B),
|
||||
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
|
||||
_ => Err(format!("Unknown model: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for WhichModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let name = match self {
|
||||
Self::Base2B => "gemma-2b",
|
||||
Self::Base7B => "gemma-7b",
|
||||
Self::Instruct2B => "gemma-2b-it",
|
||||
Self::Instruct7B => "gemma-7b-it",
|
||||
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
|
||||
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
|
||||
Self::CodeBase2B => "codegemma-2b",
|
||||
Self::CodeBase7B => "codegemma-7b",
|
||||
Self::CodeInstruct2B => "codegemma-2b-it",
|
||||
Self::CodeInstruct7B => "codegemma-7b-it",
|
||||
Self::BaseV2_2B => "gemma-2-2b",
|
||||
Self::InstructV2_2B => "gemma-2-2b-it",
|
||||
Self::BaseV2_9B => "gemma-2-9b",
|
||||
Self::InstructV2_9B => "gemma-2-9b-it",
|
||||
Self::BaseV3_1B => "gemma-3-1b",
|
||||
Self::InstructV3_1B => "gemma-3-1b-it",
|
||||
};
|
||||
write!(f, "{}", name)
|
||||
}
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
@@ -145,7 +192,7 @@ impl TextGeneration {
|
||||
// Make sure stdout isn't holding anything (if caller also prints).
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let mut _generated_tokens = 0usize;
|
||||
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
@@ -183,7 +230,7 @@ impl TextGeneration {
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
_generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
@@ -210,7 +257,7 @@ impl TextGeneration {
|
||||
pub struct GemmaInferenceConfig {
|
||||
pub tracing: bool,
|
||||
pub prompt: String,
|
||||
pub model: WhichModel,
|
||||
pub model: Option<WhichModel>,
|
||||
pub cpu: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
@@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig {
|
||||
Self {
|
||||
tracing: false,
|
||||
prompt: "Hello".to_string(),
|
||||
model: WhichModel::InstructV2_2B,
|
||||
model: Some(WhichModel::InstructV2_2B),
|
||||
cpu: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
@@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
}
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
println!("Raw model string: {:?}", cfg.model_id);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = cfg.model_id.unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Base2B => "google/gemma-2b",
|
||||
WhichModel::Base7B => "google/gemma-7b",
|
||||
WhichModel::Instruct2B => "google/gemma-2b-it",
|
||||
WhichModel::Instruct7B => "google/gemma-7b-it",
|
||||
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
|
||||
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
|
||||
WhichModel::CodeBase2B => "google/codegemma-2b",
|
||||
WhichModel::CodeBase7B => "google/codegemma-7b",
|
||||
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
|
||||
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
|
||||
WhichModel::BaseV2_2B => "google/gemma-2-2b",
|
||||
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
|
||||
WhichModel::BaseV2_9B => "google/gemma-2-9b",
|
||||
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
|
||||
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
|
||||
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
|
||||
Some(WhichModel::Base2B) => "google/gemma-2b",
|
||||
Some(WhichModel::Base7B) => "google/gemma-7b",
|
||||
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
|
||||
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
|
||||
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
|
||||
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
|
||||
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
|
||||
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
|
||||
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
|
||||
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
|
||||
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
|
||||
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
|
||||
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
|
||||
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
|
||||
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
|
||||
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
|
||||
None => "google/gemma-2-2b-it", // default fallback
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
@@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
|
||||
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("Retrieved files in {:?}", start.elapsed());
|
||||
@@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
|
||||
let model: Model = match cfg.model {
|
||||
WhichModel::Base2B
|
||||
| WhichModel::Base7B
|
||||
| WhichModel::Instruct2B
|
||||
| WhichModel::Instruct7B
|
||||
| WhichModel::InstructV1_1_2B
|
||||
| WhichModel::InstructV1_1_7B
|
||||
| WhichModel::CodeBase2B
|
||||
| WhichModel::CodeBase7B
|
||||
| WhichModel::CodeInstruct2B
|
||||
| WhichModel::CodeInstruct7B => {
|
||||
Some(WhichModel::Base2B)
|
||||
| Some(WhichModel::Base7B)
|
||||
| Some(WhichModel::Instruct2B)
|
||||
| Some(WhichModel::Instruct7B)
|
||||
| Some(WhichModel::InstructV1_1_2B)
|
||||
| Some(WhichModel::InstructV1_1_7B)
|
||||
| Some(WhichModel::CodeBase2B)
|
||||
| Some(WhichModel::CodeBase7B)
|
||||
| Some(WhichModel::CodeInstruct2B)
|
||||
| Some(WhichModel::CodeInstruct7B) => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
WhichModel::BaseV2_2B
|
||||
| WhichModel::InstructV2_2B
|
||||
| WhichModel::BaseV2_9B
|
||||
| WhichModel::InstructV2_9B => {
|
||||
Some(WhichModel::BaseV2_2B)
|
||||
| Some(WhichModel::InstructV2_2B)
|
||||
| Some(WhichModel::BaseV2_9B)
|
||||
| Some(WhichModel::InstructV2_9B)
|
||||
| None => { // default to V2 model
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
@@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
);
|
||||
|
||||
let prompt = match cfg.model {
|
||||
WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::InstructV3_1B) => {
|
||||
format!(
|
||||
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
|
||||
cfg.prompt
|
||||
|
@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl LlamaInferenceConfig {
|
||||
pub fn new(model: WhichModel) -> Self {
|
||||
Self {
|
||||
prompt: String::new(),
|
||||
model,
|
||||
cpu: false,
|
||||
temperature: 1.0,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
seed: 42,
|
||||
max_tokens: 512,
|
||||
no_kv_cache: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
revision: None,
|
||||
use_flash_attn: true,
|
||||
repeat_penalty: 1.1,
|
||||
repeat_last_n: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Default for LlamaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
Reference in New Issue
Block a user