# gorilla-test2 / main.py
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import datetime
import json
import subprocess
import torch


def log(msg):
    # Timestamped print that flushes immediately so messages show up in container logs.
    print(str(datetime.datetime.now()) + ': ' + str(msg), flush=True)


def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    if len(functions) == 0:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

# Use the GPU with half-precision weights when available; otherwise fall back to CPU and float32.
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
log('Device: ' + device)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Diagnostic dump of the container filesystem and Hugging Face cache before the model download.
result = subprocess.run('cat /etc/os-release && pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)

# Load the Gorilla OpenFunctions tokenizer and model from the Hugging Face Hub.
model_id: str = "gorilla-llm/gorilla-openfunctions-v1"
log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
log('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Second diagnostic dump, after the download has populated the cache.
result = subprocess.run('pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)

log('model.to(device) ...')
model.to(device)

log('FastAPI setup ...')
app = FastAPI()


@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    # Parse the JSON request body; it is expected to carry a 'query' string and a 'functions' list.
    body = await req.body()
    parsed_body = json.loads(body)
    log(parsed_body['query'])
    log(parsed_body['functions'])

    log('Generate prompt and obtain model output')
    prompt = get_prompt(parsed_body['query'], functions=parsed_body['functions'])

    log('Pipeline setup ...')
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    log('Get answer ...')
    # The text-generation pipeline returns a list of dicts, each with a 'generated_text' key.
    output = pipe(prompt)
    return {
        "val": output
    }


# Serve the static frontend; with html=True this mount answers "/" with static/index.html.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def index() -> FileResponse:
    # Explicit index route; the mount above already handles "/", so this acts as a fallback.
    return FileResponse(path="/app/static/index.html", media_type="text/html")


log('Initialization done.')
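
# Example request (a sketch, assuming the app is served by uvicorn on the default
# port 8000; the JSON fields match what query_gorilla reads above, and the function
# schema is a hypothetical placeholder):
#
#   curl -X POST http://localhost:8000/query_gorilla \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What is the weather in Boston?", "functions": [{"name": "get_weather", "parameters": {"city": "str"}}]}'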