//! HTTP server for the text-generation router: axum routes, SSE streaming,
//! Prometheus metrics and graceful shutdown.
use crate::health::Health; | |
/// HTTP Server logic | |
use crate::infer::{InferError, InferResponse, InferStreamResponse}; | |
use crate::validation::ValidationError; | |
use crate::{ | |
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, | |
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, | |
StreamDetails, StreamResponse, Token, Validation, | |
}; | |
use axum::extract::Extension; | |
use axum::http::{HeaderMap, Method, StatusCode}; | |
use axum::response::sse::{Event, KeepAlive, Sse}; | |
use axum::response::{IntoResponse, Response}; | |
use axum::routing::{get, post}; | |
use axum::{http, Json, Router}; | |
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; | |
use futures::stream::StreamExt; | |
use futures::Stream; | |
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; | |
use std::convert::Infallible; | |
use std::net::SocketAddr; | |
use std::sync::atomic::AtomicBool; | |
use std::sync::Arc; | |
use text_generation_client::{ShardInfo, ShardedClient}; | |
use tokenizers::Tokenizer; | |
use tokio::signal; | |
use tokio::time::Instant; | |
use tower_http::cors::{AllowOrigin, CorsLayer}; | |
use tracing::{info_span, instrument, Instrument}; | |
use utoipa::OpenApi; | |
use utoipa_swagger_ui::SwaggerUi; | |
/// Generate tokens if `stream == false` or a stream of token if `stream == true` | |
async fn compat_generate( | |
default_return_full_text: Extension<bool>, | |
infer: Extension<Infer>, | |
req: Json<CompatGenerateRequest>, | |
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { | |
let mut req = req.0; | |
// default return_full_text given the pipeline_tag | |
if req.parameters.return_full_text.is_none() { | |
req.parameters.return_full_text = Some(default_return_full_text.0) | |
} | |
// switch on stream | |
if req.stream { | |
Ok(generate_stream(infer, Json(req.into())) | |
.await | |
.into_response()) | |
} else { | |
let (headers, generation) = generate(infer, Json(req.into())).await?; | |
// wrap generation inside a Vec to match api-inference | |
Ok((headers, Json(vec![generation.0])).into_response()) | |
} | |
} | |
/// Text Generation Inference endpoint info | |
async fn get_model_info(info: Extension<Info>) -> Json<Info> { | |
Json(info.0) | |
} | |
/// Health check method | |
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { | |
match health.check().await { | |
true => Ok(()), | |
false => Err(( | |
StatusCode::SERVICE_UNAVAILABLE, | |
Json(ErrorResponse { | |
error: "unhealthy".to_string(), | |
error_type: "healthcheck".to_string(), | |
}), | |
)), | |
} | |
} | |
/// Generate tokens
///
/// Handles a single non-streaming generation request: runs inference
/// (optionally `best_of` parallel sequences), assembles optional token-level
/// `Details`, records timings in the current tracing span, the response
/// headers and Prometheus metrics, and returns the generated text (prefixed
/// with the prompt when `return_full_text` is set).
///
/// Errors: `InferError`s propagate via `?` and are converted to
/// `(StatusCode, Json<ErrorResponse>)` by the `From` impl in this file.
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    // Raw input size, surfaced later in the `x-compute-characters` header
    let compute_characters = req.0.inputs.chars().count();

    // Keep a copy of the prompt so it can be prepended to the output when
    // `return_full_text` was requested
    let mut add_prompt = None;
    if req.0.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.0.inputs.clone());
    }

    let details = req.0.parameters.details;

    // Inference: with `best_of > 1` the engine generates several candidate
    // sequences and returns the winner plus the remaining candidates
    let (response, best_of_responses) = match req.0.parameters.best_of {
        Some(best_of) if best_of > 1 => {
            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
            (response, Some(best_of_responses))
        }
        _ => (infer.generate(req.0).await?, None),
    };

    // Token details (only assembled when the client asked for them)
    let details = match details {
        true => {
            // Convert the non-winning best_of candidates into BestOfSequence values
            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
                responses
                    .into_iter()
                    .map(|response: InferResponse| {
                        // Add prompt if return_full_text
                        let mut output_text = response.generated_text.text;
                        if let Some(prompt) = &add_prompt {
                            output_text = prompt.clone() + &output_text;
                        }

                        BestOfSequence {
                            generated_text: output_text,
                            finish_reason: FinishReason::from(
                                response.generated_text.finish_reason,
                            ),
                            generated_tokens: response.generated_text.generated_tokens,
                            prefill: response.prefill,
                            tokens: response.tokens,
                            seed: response.generated_text.seed,
                        }
                    })
                    .collect()
            });

            Some(Details {
                finish_reason: FinishReason::from(response.generated_text.finish_reason),
                generated_tokens: response.generated_text.generated_tokens,
                prefill: response.prefill,
                tokens: response.tokens,
                seed: response.generated_text.seed,
                best_of_sequences,
            })
        }
        false => None,
    };

    // Timings, derived from the request lifecycle instants recorded by the
    // inference engine: start_time -> response.queued -> response.start -> now
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    // NOTE(review): `Duration / u32` panics when `generated_tokens == 0`;
    // presumably the backend always emits at least one token — confirm upstream.
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));

    // Headers: expose compute/timing information to the client (milliseconds)
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Metrics (histogram buckets for these are configured in `run`)
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_validation_duration",
        validation_time.as_secs_f64()
    );
    metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_inference_duration",
        inference_time.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_mean_time_per_token_duration",
        time_per_token.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response, prepending the prompt when `return_full_text` was set
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    tracing::info!("Output: {}", output_text);

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}
/// Generate a stream of token using Server-Sent Events
///
/// Streams one SSE event per generated token, then a final event carrying the
/// full generated text, timings and optional details. Errors are delivered as
/// SSE events too, since the HTTP status line is already committed once
/// streaming starts. `best_of > 1` is rejected (only one sequence can be
/// streamed).
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    // Raw input size, surfaced in the `x-compute-characters` header
    let compute_characters = req.0.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );

    let stream = async_stream::stream! {
        // Inference
        // `end_reached` / `error` track whether the stream terminated normally
        // or already emitted an error event, so exactly one terminal event is sent
        let mut end_reached = false;
        let mut error = false;

        // Keep a copy of the prompt to prepend to the final output when
        // `return_full_text` is set
        let mut add_prompt = None;
        if req.0.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.0.inputs.clone());
        }
        let details = req.0.parameters.details;

        // `best_of` is incompatible with streaming: only one sequence can be streamed
        let best_of = req.0.parameters.best_of.unwrap_or(1);
        if best_of == 1 {
            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
                Ok((_permit, mut response_stream)) => {
                    // Server-Sent Event stream
                    while let Some(response) = response_stream.next().await {
                        match response {
                            Ok(response) => {
                                match response {
                                    // Prefill is ignored
                                    InferStreamResponse::Prefill(_) => {}
                                    // Yield event for every new token
                                    InferStreamResponse::Token(token) => {
                                        // StreamResponse
                                        let stream_token = StreamResponse {
                                            token,
                                            generated_text: None,
                                            details: None,
                                        };

                                        yield Ok(Event::default().json_data(stream_token).unwrap())
                                    }
                                    // Yield event for last token and compute timings
                                    InferStreamResponse::End {
                                        token,
                                        generated_text,
                                        start,
                                        queued,
                                    } => {
                                        // Token details (only when the client asked for them)
                                        let details = match details {
                                            true => Some(StreamDetails {
                                                finish_reason: FinishReason::from(generated_text.finish_reason),
                                                generated_tokens: generated_text.generated_tokens,
                                                seed: generated_text.seed,
                                            }),
                                            false => None,
                                        };

                                        // Timings, from the request lifecycle instants:
                                        // start_time -> queued -> start -> now
                                        let total_time = start_time.elapsed();
                                        let validation_time = queued - start_time;
                                        let queue_time = start - queued;
                                        let inference_time = Instant::now() - start;
                                        // NOTE(review): `Duration / u32` panics when
                                        // `generated_tokens == 0`; presumably the backend
                                        // always emits at least one token — confirm upstream.
                                        let time_per_token = inference_time / generated_text.generated_tokens;

                                        // Tracing metadata
                                        span.record("total_time", format!("{total_time:?}"));
                                        span.record("validation_time", format!("{validation_time:?}"));
                                        span.record("queue_time", format!("{queue_time:?}"));
                                        span.record("inference_time", format!("{inference_time:?}"));
                                        span.record("time_per_token", format!("{time_per_token:?}"));
                                        span.record("seed", format!("{:?}", generated_text.seed));

                                        // Metrics
                                        metrics::increment_counter!("tgi_request_success");
                                        metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
                                        metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                        // StreamResponse: final event carries the full text
                                        end_reached = true;

                                        let mut output_text = generated_text.text;
                                        if let Some(prompt) = add_prompt {
                                            output_text = prompt + &output_text;
                                        }

                                        tracing::info!(parent: &span, "Output: {}", output_text);

                                        let stream_token = StreamResponse {
                                            token,
                                            generated_text: Some(output_text),
                                            details
                                        };

                                        yield Ok(Event::default().json_data(stream_token).unwrap());
                                        break;
                                    }
                                }
                            }
                            // yield error
                            Err(err) => {
                                error = true;
                                yield Ok(Event::from(err));
                                break;
                            }
                        }
                    }
                },
                // yield error
                Err(err) => {
                    error = true;
                    yield Ok(Event::from(err));
                }
            }
            // Check if generation reached the end
            // Skip if we already sent an error
            if !end_reached && !error {
                let err = InferError::IncompleteGeneration;
                metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                tracing::error!("{err}");
                yield Ok(Event::from(err));
            }
        } else {
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
    };

    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
/// Prometheus metrics scrape endpoint | |
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { | |
prom_handle.render() | |
} | |
/// Serving method | |
pub async fn run( | |
model_info: HubModelInfo, | |
shard_info: ShardInfo, | |
compat_return_full_text: bool, | |
max_concurrent_requests: usize, | |
max_best_of: usize, | |
max_stop_sequences: usize, | |
max_input_length: usize, | |
max_total_tokens: usize, | |
waiting_served_ratio: f32, | |
max_batch_total_tokens: u32, | |
max_waiting_tokens: usize, | |
client: ShardedClient, | |
tokenizer: Option<Tokenizer>, | |
validation_workers: usize, | |
addr: SocketAddr, | |
allow_origin: Option<AllowOrigin>, | |
) { | |
// OpenAPI documentation | |
struct ApiDoc; | |
// Create state | |
let validation = Validation::new( | |
validation_workers, | |
tokenizer, | |
max_best_of, | |
max_stop_sequences, | |
max_input_length, | |
max_total_tokens, | |
); | |
let generation_health = Arc::new(AtomicBool::new(false)); | |
let health_ext = Health::new(client.clone(), generation_health.clone()); | |
let infer = Infer::new( | |
client, | |
validation, | |
waiting_served_ratio, | |
max_batch_total_tokens, | |
max_waiting_tokens, | |
max_concurrent_requests, | |
shard_info.requires_padding, | |
generation_health, | |
); | |
// Duration buckets | |
let duration_matcher = Matcher::Suffix(String::from("duration")); | |
let n_duration_buckets = 35; | |
let mut duration_buckets = Vec::with_capacity(n_duration_buckets); | |
// Minimum duration in seconds | |
let mut value = 0.0001; | |
for _ in 0..n_duration_buckets { | |
// geometric sequence | |
value *= 1.5; | |
duration_buckets.push(value); | |
} | |
// Input Length buckets | |
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length")); | |
let input_length_buckets: Vec<f64> = (0..100) | |
.map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64) | |
.collect(); | |
// Generated tokens buckets | |
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens")); | |
let generated_tokens_buckets: Vec<f64> = (0..100) | |
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64) | |
.collect(); | |
// Input Length buckets | |
let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens")); | |
let max_new_tokens_buckets: Vec<f64> = (0..100) | |
.map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64) | |
.collect(); | |
// Batch size buckets | |
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); | |
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect(); | |
// Prometheus handler | |
let builder = PrometheusBuilder::new() | |
.set_buckets_for_metric(duration_matcher, &duration_buckets) | |
.unwrap() | |
.set_buckets_for_metric(input_length_matcher, &input_length_buckets) | |
.unwrap() | |
.set_buckets_for_metric(generated_tokens_matcher, &generated_tokens_buckets) | |
.unwrap() | |
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) | |
.unwrap() | |
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) | |
.unwrap(); | |
let prom_handle = builder | |
.install_recorder() | |
.expect("failed to install metrics recorder"); | |
// CORS layer | |
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); | |
let cors_layer = CorsLayer::new() | |
.allow_methods([Method::GET, Method::POST]) | |
.allow_headers([http::header::CONTENT_TYPE]) | |
.allow_origin(allow_origin); | |
// Endpoint info | |
let info = Info { | |
model_id: model_info.model_id, | |
model_sha: model_info.sha, | |
model_dtype: shard_info.dtype, | |
model_device_type: shard_info.device_type, | |
model_pipeline_tag: model_info.pipeline_tag, | |
max_concurrent_requests, | |
max_best_of, | |
max_stop_sequences, | |
max_input_length, | |
max_total_tokens, | |
waiting_served_ratio, | |
max_batch_total_tokens, | |
max_waiting_tokens, | |
validation_workers, | |
version: env!("CARGO_PKG_VERSION"), | |
sha: option_env!("VERGEN_GIT_SHA"), | |
docker_label: option_env!("DOCKER_LABEL"), | |
}; | |
// Create router | |
let app = Router::new() | |
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) | |
// Base routes | |
.route("/", post(compat_generate)) | |
.route("/info", get(get_model_info)) | |
.route("/generate", post(generate)) | |
.route("/generate_stream", post(generate_stream)) | |
// AWS Sagemaker route | |
.route("/invocations", post(compat_generate)) | |
// Base Health route | |
.route("/health", get(health)) | |
// Inference API health route | |
.route("/", get(health)) | |
// AWS Sagemaker health route | |
.route("/ping", get(health)) | |
// Prometheus metrics route | |
.route("/metrics", get(metrics)) | |
.layer(Extension(info)) | |
.layer(Extension(health_ext)) | |
.layer(Extension(compat_return_full_text)) | |
.layer(Extension(infer)) | |
.layer(Extension(prom_handle)) | |
.layer(opentelemetry_tracing_layer()) | |
.layer(cors_layer); | |
// Run server | |
axum::Server::bind(&addr) | |
.serve(app.into_make_service()) | |
// Wait until all requests are finished to shut down | |
.with_graceful_shutdown(shutdown_signal()) | |
.await | |
.unwrap(); | |
} | |
/// Shutdown signal handler | |
async fn shutdown_signal() { | |
let ctrl_c = async { | |
signal::ctrl_c() | |
.await | |
.expect("failed to install Ctrl+C handler"); | |
}; | |
let terminate = async { | |
signal::unix::signal(signal::unix::SignalKind::terminate()) | |
.expect("failed to install signal handler") | |
.recv() | |
.await; | |
}; | |
let terminate = std::future::pending::<()>(); | |
tokio::select! { | |
_ = ctrl_c => {}, | |
_ = terminate => {}, | |
} | |
tracing::info!("signal received, starting graceful shutdown"); | |
opentelemetry::global::shutdown_tracer_provider(); | |
} | |
impl From<i32> for FinishReason { | |
fn from(finish_reason: i32) -> Self { | |
let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap(); | |
match finish_reason { | |
text_generation_client::FinishReason::Length => FinishReason::Length, | |
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, | |
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence, | |
} | |
} | |
} | |
/// Convert to Axum supported formats | |
impl From<InferError> for (StatusCode, Json<ErrorResponse>) { | |
fn from(err: InferError) -> Self { | |
let status_code = match err { | |
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY, | |
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, | |
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, | |
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, | |
}; | |
( | |
status_code, | |
Json(ErrorResponse { | |
error: err.to_string(), | |
error_type: err.error_type().to_string(), | |
}), | |
) | |
} | |
} | |
impl From<InferError> for Event { | |
fn from(err: InferError) -> Self { | |
Event::default() | |
.json_data(ErrorResponse { | |
error: err.to_string(), | |
error_type: err.error_type().to_string(), | |
}) | |
.unwrap() | |
} | |
} | |