File size: 2,713 Bytes
f99e419
02b4c7b
 
8ce0ba9
f423eb3
25b91cd
bd8563e
8ce0ba9
02b4c7b
f423eb3
13dc3e9
f423eb3
759408c
 
 
02b4c7b
759408c
 
 
02b4c7b
759408c
 
 
 
 
 
 
0bc8a9d
 
20165cd
0bc8a9d
 
f2d6701
877da1f
bd8563e
efec77b
f423eb3
0bc8a9d
f423eb3
cc93917
0bc8a9d
18340a0
877da1f
bd8563e
18340a0
0bc8a9d
 
f423eb3
759408c
 
 
 
 
 
f423eb3
 
759408c
f423eb3
77485e2
f423eb3
 
 
 
 
 
 
 
 
 
 
 
220d69e
759408c
 
 
 
 
 
 
 
 
 
 
 
0bc8a9d
f423eb3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import datetime
import functools
import json
import subprocess
import threading

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

def log(msg):
    """Print *msg* to stdout, prefixed with the current local timestamp.

    The line is flushed immediately so it is not held back in the stdout buffer.
    """
    stamp = str(datetime.datetime.now())
    line = ': '.join((stamp, str(msg)))
    print(line, flush=True)

def get_prompt(user_query: str, functions: "list | None" = None) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list | None): Function descriptions to include in the prompt,
      serialized with ``json.dumps``. ``None`` or an empty list means no
      ``<<function>>`` section is emitted.

    Returns:
    - str: The formatted conversation prompt.
    """
    # None default instead of a shared mutable []; `not functions` also treats
    # an explicit None (e.g. JSON null from a request) as "no functions".
    if not functions:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

# Use the first CUDA GPU with half-precision weights when one is visible;
# otherwise run on the CPU in full float32 precision.
device: str
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
log('Device: ' + device)

# Startup diagnostics: OS release, working directory, and the contents of the
# Hugging Face / gorilla cache directories. Commands are joined with ';' rather
# than '&&' so that one failing command (e.g. `find` on a cache directory that
# does not exist yet) does not suppress the output of the ones after it, and
# stderr is logged so such failures are visible instead of silently dropped.
result = subprocess.run(
    'cat /etc/os-release; pwd; ls -lH; find /.cache/huggingface/hub; find /.cache/gorilla',
    shell=True,
    capture_output=True,
    text=True,
    check=False,
)
log(result.stdout)
if result.stderr:
    log(result.stderr)

# Hugging Face Hub identifier of the checkpoint this service loads.
model_id: str = "gorilla-llm/gorilla-openfunctions-v1"

log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)

log('AutoModelForCausalLM.from_pretrained ...')
# Load the weights in the dtype chosen during device selection; the
# low_cpu_mem_usage flag asks transformers to keep peak host RAM down while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)

# Post-load diagnostics: list the working directory and cache directories again
# after the tokenizer/model have been loaded. ';' instead of '&&' keeps one
# failing command from hiding the rest, and stderr is logged rather than dropped.
result = subprocess.run(
    'pwd; ls -lH; find /.cache/huggingface/hub; find /.cache/gorilla',
    shell=True,
    capture_output=True,
    text=True,
    check=False,
)
log(result.stdout)
if result.stderr:
    log(result.stderr)

# Place the loaded weights on the selected device once, at startup.
log('model.to(device) ...')
model.to(device)

# The ASGI application that the route decorators below register against.
log('FastAPI setup ...')
app: FastAPI = FastAPI()

# Serializes generation. The original async handler ran inference directly on
# the event loop, so only one generation could run at a time; moving inference
# to a worker thread keeps that guarantee explicit, since sharing one model and
# pipeline across concurrent threads is not something this code relies on.
_generation_lock = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_pipe():
    """Build the text-generation pipeline once and reuse it for all requests.

    The model and tokenizer are loaded at import time, so rebuilding the
    pipeline wrapper on every request (as before) was pure overhead.
    """
    log('Pipeline setup ...')
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )


def _generate(prompt: str):
    """Run the blocking pipeline on *prompt*, one generation at a time."""
    with _generation_lock:
        pipe = _get_pipe()
        log('Get answer ...')
        return pipe(prompt)


@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    """Generate a Gorilla OpenFunctions answer for a JSON request.

    Expects a JSON object body with a ``query`` field and an optional
    ``functions`` field (function descriptions embedded in the prompt; missing
    or null means none).

    Returns:
        ``{"val": <pipeline output>}``.

    Raises:
        HTTPException(400): the body is not valid JSON.
        HTTPException(422): the body is not a JSON object with a ``query`` field.
    """
    body = await req.body()
    try:
        parsedBody = json.loads(body)
    except ValueError as err:  # JSONDecodeError and invalid UTF-8 are both ValueErrors
        raise HTTPException(status_code=400, detail=f"Request body is not valid JSON: {err}") from err
    if not isinstance(parsedBody, dict) or 'query' not in parsedBody:
        raise HTTPException(
            status_code=422,
            detail="Request body must be a JSON object with a 'query' field.",
        )

    query = parsedBody['query']
    functions = parsedBody.get('functions') or []
    log(query)
    log(functions)

    log('Generate prompt and obtain model output')
    prompt = get_prompt(query, functions=functions)

    # Generation is blocking; run it in a worker thread so the event loop keeps
    # serving other requests (e.g. static files) while the model runs.
    output = await run_in_threadpool(_generate, prompt)

    return {
      "val": output
    }

# Directory served by the static mount. index() reads index.html from the same
# place, so the two cannot disagree about where the front-end lives.
_STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page (``index.html``) for ``GET /``."""
    return FileResponse(path=f"{_STATIC_DIR}/index.html", media_type="text/html")


# Mount the static files LAST: routes are matched in registration order and a
# mount at "/" matches every path, so any route registered after it (as index()
# previously was) would never be reached.
app.mount("/", StaticFiles(directory=_STATIC_DIR, html=True), name="static")


log('Initialization done.')