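# FastAPI app that serves the gorilla-llm/gorilla-openfunctions-v1 model behind a
# /query_gorilla endpoint and a static front end mounted from the ./static directory.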
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import datetime
import json
import subprocess
import torch
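
# Timestamped logging helper; flush=True so messages appear immediately in the container logs.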
def log(msg):
    print(str(datetime.datetime.now()) + ': ' + str(msg), flush=True)
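
# Build the Gorilla OpenFunctions prompt: the user question, optionally followed by a JSON
# dump of the candidate function descriptions. Illustrative example (values are made up):
#   get_prompt("Get the weather", [{"name": "get_weather"}])
#   -> 'USER: <<question>> Get the weather <<function>> [{"name": "get_weather"}]\nASSISTANT: '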
def get_prompt(user_query: str, functions: list = []) -> str:
"""
Generates a conversation prompt based on the user's query and a list of functions.
Parameters:
- user_query (str): The user's query.
- functions (list): A list of functions to include in the prompt.
Returns:
- str: The formatted conversation prompt.
"""
if len(functions) == 0:
return f"USER: <<question>> {user_query}\nASSISTANT: "
functions_string = json.dumps(functions)
return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
log('Device: ' + device)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
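
# Debug output: OS release, working directory, and the contents of the Hugging Face and Gorilla caches.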
result = subprocess.run('cat /etc/os-release && pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)
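
# Download/load the tokenizer and model weights; low_cpu_mem_usage reduces peak RAM during loading.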
model_id: str = "gorilla-llm/gorilla-openfunctions-v1"
log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
log('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
result = subprocess.run('pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)
log('model.to(device) ...')
model.to(device)
log('FastAPI setup ...')
app = FastAPI()
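
# JSON endpoint: expects a body of the form {"query": "...", "functions": [...]} and returns the
# raw pipeline output under "val". Example request (hypothetical host/port, illustrative payload):
#   curl -X POST http://localhost:8000/query_gorilla \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "Get the weather in Paris", "functions": [{"name": "get_weather"}]}'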
@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    body = await req.body()
    parsedBody = json.loads(body)
    log(parsedBody['query'])
    log(parsedBody['functions'])

    log('Generate prompt and obtain model output')
    prompt = get_prompt(parsedBody['query'], functions=parsedBody['functions'])

    log('Pipeline setup ...')
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    log('Get answer ...')
    output = pipe(prompt)
    return {
        "val": output
    }
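
# Serve the static front end; html=True makes the mount return static/index.html for "/".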
app.mount("/", StaticFiles(directory="static", html=True), name="static")
@app.get("/")
def index() -> FileResponse:
return FileResponse(path="/app/static/index.html", media_type="text/html")
log('Initialization done.')