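"""FastAPI service wrapping the google/gemma-2-9b-it chat model.

Exposes GET / as a simple liveness route and POST /ask, which accepts
{"prompt": "..."} and returns {"answer": "..."}.
"""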
from fastapi import FastAPI, HTTPException, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os

print("Google Gemma 2 Chatbot is starting...")

# read the Hugging Face access token from the environment; the gated Gemma 2
# weights can only be downloaded by an authenticated account
access_token = os.getenv('HF_TOKEN')
if not access_token:
    raise RuntimeError("HF_TOKEN environment variable is not set")
login(access_token)

model_id = "google/gemma-2-9b-it"

print("Model loading started")
tokenizer = AutoTokenizer.from_pretrained(model_id)
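# device_map="auto" lets Accelerate place the weights on the available GPU(s),
# offloading to CPU if necessary; bfloat16 roughly halves memory versus float32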
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model loading completed")

# device_map="auto" has already decided placement; log where the model landed
print("Selected device:", model.device)

app = FastAPI()


@app.get('/')
def home():
    return {"hello": "Bitfumes"}


@app.post('/ask')
async def ask(request: Request):
    data = await request.json()
    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is missing")

    print("Device of the model:", model.device)
    messages = [
        {"role": "user", "content": f"{prompt}"},
    ]
    print("Messages:", messages)
    print("Tokenizer process started")
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True).to("cuda")
    print("Tokenizer process completed")

    print("Model process started")
    outputs = model.generate(**input_ids, max_new_tokens=256)

    print("Tokenizer decode process started")
    answer = tokenizer.decode(outputs[0]).split("<end_of_turn>")[1].strip()

    return {"answer": answer}