import os
import torch
import json
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import uvicorn
from huggingface_hub import login, hf_hub_download
# Authenticate with Hugging Face Hub using the HF_TOKEN environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
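# Note: on a Hugging Face Space, HF_TOKEN is typically stored as a repository secret;
# when running locally, exporting HF_TOKEN=<your-access-token> (placeholder value) before launch works as well.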
# Define a Pydantic model for request validation
class Query(BaseModel):
    text: str
app = FastAPI(title="Financial Chatbot API")
# Load the base model from Meta-Llama
base_model_name = "meta-llama/Llama-3.2-3B"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    trust_remote_code=True,
)
# Load the tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
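# Llama tokenizers ship without a dedicated pad token, so reuse the EOS token for padding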
tokenizer.pad_token = tokenizer.eos_token
# Load the finetuned adapter using PEFT with handling for eva_config
peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
try:
    # Try loading the adapter directly first
    model = PeftModel.from_pretrained(base_model, peft_model_id)
except TypeError as e:
    if "eva_config" in str(e):
        print("Handling eva_config compatibility issue...")
        # Download the adapter config and patch it manually
        config_path = hf_hub_download(repo_id=peft_model_id, filename="adapter_config.json")
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        # Remove the parameter the installed peft version does not recognize
        if 'eva_config' in config_dict:
            del config_dict['eva_config']
        # Save the cleaned config to a local file
        modified_config_path = "modified_adapter_config.json"
        with open(modified_config_path, 'w') as f:
            json.dump(config_dict, f)
        # PeftConfig.from_json_file returns a plain dict, so build the proper
        # config class from the cleaned entries via from_peft_type
        config = PeftConfig.from_peft_type(**config_dict)
        # Load the adapter again, this time with the patched config
        model = PeftModel.from_pretrained(base_model, peft_model_id, config=config)
    else:
        # If it's a different error, re-raise it
        raise
# Create a text-generation pipeline using the loaded model and tokenizer
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,  # enable sampling so temperature and top_p take effect
    temperature=0.7,
    top_p=0.95,
)
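# chat_pipe(prompt) returns a list of dicts such as [{"generated_text": "..."}];
# by default the generated text echoes the prompt, which the endpoint below strips off.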
@app.post("/generate")
def generate(query: Query):
    prompt = f"Question: {query.text}\nAnswer: "
    response = chat_pipe(prompt)[0]["generated_text"]
    # Extract just the answer part from the generated text
    if "Answer: " in response:
        response = response.split("Answer: ", 1)[1]
    return {"response": response}
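# Example request once the server is running (payload and question are illustrative):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How should I diversify my savings?"}'
# The endpoint responds with JSON of the form {"response": "..."}.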
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)