# beam-app/app.py
from threading import Thread
import torch
from beam import Image, Volume, GpuType, asgi
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import (
AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer,
PreTrainedTokenizerFast, PreTrainedModel, StoppingCriteriaList
)
from utils import MaxPostsStoppingCriteria, Body, fallback

SETTINGS = {
"model_name": "Error410/JVCGPT-Medium",
"beam_volume_path": "./cached_models",
}

# @see https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation
DEFAULTS = {
"max_length": 2048, # 512
"temperature": 0.9, # 1
"top_p": 1, # 0.95
"top_k": 0, # 40
"repetition_penalty": 1.0, # 1.0
"no_repeat_ngram_size": 0, # 0
"do_sample": True, # True
}


def load_models():
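    """Load the tokenizer and model, caching weights on the mounted Beam volume."""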
tokenizer = AutoTokenizer.from_pretrained(
SETTINGS["model_name"],
cache_dir=SETTINGS["beam_volume_path"]
)
    tokenizer.pad_token = tokenizer.eos_token  # the checkpoint defines no pad token, so reuse EOS
model = AutoModelForCausalLM.from_pretrained(
SETTINGS["model_name"],
device_map="auto",
torch_dtype=torch.float16,
cache_dir=SETTINGS["beam_volume_path"],
)
return model, tokenizer


def stream(model: PreTrainedModel, tokenizer: PreTrainedTokenizerFast, body: Body):
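    """Yield generated text chunks for a request, falling back to DEFAULTS for unset parameters."""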
generate_args = {
"max_length": fallback(body.max_length, DEFAULTS["max_length"]),
"temperature": fallback(body.temperature, DEFAULTS["temperature"]),
"top_p": fallback(body.top_p, DEFAULTS["top_p"]),
"top_k": fallback(body.top_k, DEFAULTS["top_k"]),
"repetition_penalty": fallback(body.repetition_penalty, DEFAULTS["repetition_penalty"]),
"no_repeat_ngram_size": fallback(body.no_repeat_ngram_size, DEFAULTS["no_repeat_ngram_size"]),
"do_sample": fallback(body.do_sample, DEFAULTS["do_sample"]),
"use_cache": True,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
}
inputs = tokenizer(body.prompt, return_tensors="pt", padding=True)
    # Send inputs to the device selected by device_map="auto" rather than hardcoding "cuda"
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
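    # TextIteratorStreamer queues tokens produced by the generation thread;
    # iterating blocks until the next chunk arrives (or raises after 240 s idle)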
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=240)
    # No torch.no_grad() needed here: model.generate() already runs under no_grad
thread = Thread(
target=model.generate,
kwargs={
"input_ids": input_ids,
"attention_mask": attention_mask,
"streamer": streamer,
"stopping_criteria": StoppingCriteriaList([MaxPostsStoppingCriteria(tokenizer, body.posts_count)]),
**generate_args,
}
)
thread.start()
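    # Drain the streamer: the background thread fills it as tokens are generated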
for token in streamer:
yield token


@asgi(
name="jvcgpt",
on_start=load_models,
cpu=2.0,
memory="16Gi",
gpu=GpuType.A100_40,
gpu_count=1,
    timeout=5 * 60,  # time budget to load the model and start the server
    keep_warm_seconds=5 * 60,  # keep the container warm for 5 minutes after the last request
image=Image(
python_version="python3.12",
python_packages=[
"fastapi",
"torch",
"transformers",
"accelerate",
"huggingface_hub[hf-transfer]",
],
env_vars=["HF_HUB_ENABLE_HF_TRANSFER=1"],
),
volumes=[
Volume(
name="cached_models",
mount_path=SETTINGS["beam_volume_path"],
)
],
)
def server(context):
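    """Build the FastAPI app, wiring the preloaded model and tokenizer into /stream."""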
model, tokenizer = context.on_start_value
app = FastAPI()
@app.post("/stream")
async def stream_endpoint(body: Body) -> StreamingResponse:
return StreamingResponse(
stream(model, tokenizer, body),
media_type='text/event-stream',
headers={"Cache-Control": "no-cache"},
)
return app
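

# A minimal client sketch for consuming the stream. The endpoint URL is a
# placeholder, and only "prompt" and "posts_count" are shown because those are
# the Body fields this file demonstrably uses; other fields are optional overrides.
#
#   import requests
#
#   with requests.post(
#       "https://<your-beam-deployment>/stream",
#       json={"prompt": "Bonjour JVC", "posts_count": 3},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)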