import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import uvicorn
from huggingface_hub import login
# Authenticate with Hugging Face Hub (needed for the gated Llama base weights)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
class Query(BaseModel):
    text: str

app = FastAPI(title="Financial Chatbot API")
# Load the base model. The adapter below appears to have been fine-tuned from
# Llama 3.2 3B (per its repo name), and the base must match the adapter's base
# for the LoRA weights to load. Update this if your adapter used a different base.
base_model_name = "meta-llama/Llama-3.2-3B"
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype=torch.float16,  # half precision to keep the model within GPU memory
    trust_remote_code=True,
)
# Load adapter from your checkpoint
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
model = PeftModel.from_pretrained(model, peft_model_id)
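# Optional: if the adapter is a LoRA adapter, the weights can be merged into the
# base model for faster inference (merge_and_unload() is standard PEFT API):
# model = model.merge_and_unload()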
# Load tokenizer from base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
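# Llama tokenizers ship without a pad token; reuse EOS so the pipeline can pad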
tokenizer.pad_token = tokenizer.eos_token
# Build a text-generation pipeline around the adapted model
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,          # required for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
    return_full_text=False,  # return only the completion, not the prompt echo
)
@app.post("/generate")
def generate(query: Query):
    prompt = f"Question: {query.text}\nAnswer: "
    # With return_full_text=False, generated_text holds only the model's answer
    response = chat_pipe(prompt)[0]["generated_text"]
    return {"response": response.strip()}
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
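
# Example request once the server is running (the query text is illustrative;
# the port defaults to 7860 as above):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "What is dollar-cost averaging?"}'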