from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os
print("Google Gemma 2 Chatbot is starting...")
# read access token from environment variable
access_token = os.getenv('HF_TOKEN')
login(access_token)
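# (Assumption: HF_TOKEN is configured as a secret on the Space; a token is
# required because google/gemma-2-9b-it is a gated model on the Hugging Face Hub.)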
model_id = "google/gemma-2-9b-it"
print("Model loading started")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model loading completed")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Selected device:", device)
app = FastAPI()
@app.get('/')
def home():
    return {"hello": "Bitfumes"}
@app.post('/ask')
async def ask(request: Request):
    data = await request.json()
    prompt = data.get("prompt")
    if not prompt:
        return {"error": "Prompt is missing"}
    print("Device of the model:", model.device)
    messages = [
        {"role": "user", "content": f"{prompt}"},
    ]
    print("Messages:", messages)
    print("Tokenizer process started")
    # Tokenize with the Gemma chat template and move the tensors to the model's
    # device instead of hard-coding "cuda", so the endpoint also works on
    # CPU-only hardware.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to(model.device)
    print("Tokenizer process completed")
    print("Model process started")
    outputs = model.generate(**input_ids, max_new_tokens=256)
    print("Tokenizer decode process started")
    # Decode only the newly generated tokens so the answer does not include
    # the echoed prompt or Gemma's special turn markers.
    answer = tokenizer.decode(
        outputs[0][input_ids["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()
    return {"answer": answer}
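# A minimal sketch of how the /ask endpoint could be exercised once the app is
# running (hypothetical host/port; assuming this file is saved as main.py and
# served with `uvicorn main:app --host 0.0.0.0 --port 7860`):
#
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain what FastAPI is in one sentence."}'
#
# The response body should have the shape {"answer": "..."}.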