Upload main.rs
main.rs
ADDED
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use std::path::PathBuf;

use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};
use serde::{Deserialize, Serialize};

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Disable the key-value cache during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    // The prompt arrives per-request through the HTTP API, so no CLI flag is
    // needed for it.
    // #[arg(long)]
    // prompt: Option<String>,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "flan-t5-xl")]
    which: Which,
}
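
// Example invocation (illustrative values; clap derives kebab-case flag names
// from the fields above):
//
//     cargo run --release -- --which flan-t5-small --temperature 0.7 --top-p 0.9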

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "deepfile/flan-t5-xl-gguf".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match &args.config_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("config.json")?,
                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
            },
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };

        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }

    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
        let local_filename = std::path::PathBuf::from(filename);
        if local_filename.exists() {
            Ok(local_filename)
        } else {
            Ok(api.get(filename)?)
        }
    }
}
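
// Standalone usage sketch for the builder (hypothetical; `generate_answer`
// below performs these same steps for each request):
//
//     let (builder, tokenizer) = T5ModelBuilder::load(&args)?;
//     let model = builder.build_model()?;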

fn generate_answer(prompt: String, args: &Args) -> Result<String> {
    let mut generated_text = String::new();

    let (builder, mut tokenizer) = T5ModelBuilder::load(args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    // T5 decoding starts from the pad token.
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();

    // Fixed seed; temperature and nucleus sampling come from the CLI flags.
    let mut logits_processor =
        LogitsProcessor::new(299792458, Some(args.temperature), args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;

    let start = std::time::Instant::now();

    for index in 0.. {
        // Hard cap on the generated sequence length.
        if output_token_ids.len() > 512 {
            break;
        }
        // With the KV cache enabled, only the newest token is fed back after
        // the first decoding step.
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            // SentencePiece marks word boundaries with '▁' and newlines with
            // the byte token '<0x0A>'.
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            generated_text.push_str(&text);
            print!("{}", text);
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );

    Ok(generated_text)
}

// JSON payloads for the HTTP API.
#[derive(Deserialize)]
struct Request {
    prompt: String,
}

#[derive(Serialize)]
struct Response {
    answer: String,
}

#[post("/generate")]
async fn generate(req_body: web::Json<Request>) -> impl Responder {
    // Note: arguments are re-parsed and the model is reloaded from scratch on
    // every request; see the sketch below for sharing the model instead.
    let args = Args::parse();
    match generate_answer(req_body.prompt.clone(), &args) {
        Ok(answer) => HttpResponse::Ok().json(Response { answer }),
        Err(e) => HttpResponse::InternalServerError().body(e.to_string()),
    }
}
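
// A possible refinement (sketch only, not part of the original design): load
// the model once in `main` and share it across requests via actix app state
// instead of reloading it in the handler. Names here are hypothetical.
//
//     let state = web::Data::new(std::sync::Mutex::new(loaded));
//     HttpServer::new(move || App::new().app_data(state.clone()).service(generate))
//
// The handler would then take `state: web::Data<std::sync::Mutex<Loaded>>` as
// an extractor argument rather than calling `Args::parse()` itself.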

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    println!("Starting server at: http://localhost:7000");
    HttpServer::new(|| App::new().service(generate))
        .bind("localhost:7000")?
        .run()
        .await
}
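
// Example request once the server is running (illustrative prompt):
//
//     curl -X POST http://localhost:7000/generate \
//          -H 'Content-Type: application/json' \
//          -d '{"prompt": "Translate to German: Hello, world!"}'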