Phoenix21's picture
Update app.py
bc3abd0 verified
raw
history blame
1.7 kB
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import uvicorn
from huggingface_hub import login
# Log in to the Hugging Face Hub before any model is downloaded below.
# The token is read from the HF_TOKEN environment variable; an unset or
# empty value counts as missing and aborts startup with a ValueError.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "Hugging Face token not found. Please set the HF_TOKEN environment variable."
    )
login(token=HF_TOKEN)
app = FastAPI(title="Financial Chatbot API")


# Request body accepted by POST /generate: a single free-text question,
# validated by pydantic (the `text` field is required and must be a string).
class Query(BaseModel):
    text: str
# Base checkpoint that the finance adapter below was fine-tuned on.
base_model_name = "meta-llama/Llama-3.2-3B"
# Llama is a built-in transformers architecture, so trust_remote_code is not
# required. Leaving it off ensures no arbitrary Python code shipped inside the
# model repository is ever executed at load time.
# device_map="auto" lets accelerate place the weights on the available
# device(s).
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
)

# Wrap the base model with the fine-tuned PEFT adapter weights from the Hub.
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
model = PeftModel.from_pretrained(base_model, peft_model_id)

# The adapter does not change the vocabulary, so the base model's tokenizer is
# reused (same reasoning as above: no remote code needed for Llama).
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Reuse EOS as the padding token; presumably the base tokenizer defines no pad
# token of its own — TODO confirm.
tokenizer.pad_token = tokenizer.eos_token
# Shared text-generation pipeline over the adapter-wrapped model, built once at
# import time and reused by every request.
# temperature and top_p only influence decoding when sampling is enabled. Set
# do_sample=True explicitly so these settings are always honoured, instead of
# depending on whatever the checkpoint's generation_config happens to specify.
# Without sampling, transformers decodes greedily and ignores both settings.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
@app.post("/generate")
def generate(query: Query):
    """Answer a free-text question with the fine-tuned finance model.

    The question is wrapped in a fixed Question/Answer prompt template and
    passed to the module-level ``chat_pipe``.

    Returns:
        A dict ``{"response": <generated answer text>}``.
    """
    prompt = f"Question: {query.text}\nAnswer: "
    # By default the text-generation pipeline returns prompt + continuation,
    # which would echo the template and the user's question back to the client.
    # return_full_text=False keeps only the newly generated answer.
    outputs = chat_pipe(prompt, return_full_text=False)
    return {"response": outputs[0]["generated_text"]}
if __name__ == "__main__":
    # Listen on all interfaces. The port is taken from $PORT, falling back to
    # 7860 when it is unset.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")