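# NOTE: this file was captured from a Spaces page (presumably Hugging Face
# Spaces) whose status banner read "Runtime error" at the time.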
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.model_utils import load_model_and_tokenizer, generate_summary
from app.classifier import train_classifier, classify_text
app = FastAPI()

# Model + tokenizer used by the /rag endpoint. This runs at import time,
# so the weights are loaded once per process rather than per request.
model_name = "sshleifer/distilbart-cnn-6-6"  # example checkpoint
model, tokenizer = load_model_and_tokenizer(model_name)

# Toy (text, label) pairs used to build the classifier behind the
# /classification endpoint; train_classifier hands back the classifier
# together with the vectorizer that classify_text needs alongside it.
dummy_data = [
    ("I feel very sad and hopeless.", "Depression"),
    ("I have trouble sleeping at night.", "Insomnia"),
    ("I am constantly worrying about everything.", "Anxiety"),
    ("I feel energetic and happy.", "Happiness"),
    ("My mood swings a lot and I feel irritable.", "Mood Disorder"),
]
classifier, vectorizer = train_classifier(dummy_data)
class Prompt(BaseModel):
    """Request body for the /rag endpoint.

    ``prompt`` is the text handed to ``generate_summary``.
    """

    prompt: str
class ClassificationInput(BaseModel):
    """Request body for the /classification endpoint."""

    # Free text passed straight to classify_text.
    data: str
@app.post("/rag")
def rag_endpoint(prompt: Prompt):
    """Summarise ``prompt.prompt`` with the module-level model and tokenizer.

    Registered as ``POST /rag``. Without the decorator the function existed but
    was never exposed as a route. It is a plain ``def``, so FastAPI runs the
    model call in its threadpool instead of blocking the event loop.

    Args:
        prompt: Request body carrying the text to summarise.

    Returns:
        ``{"summary": <value returned by generate_summary>}``.

    Raises:
        HTTPException: status 500 with ``str(error)`` as ``detail`` if
            summary generation raises anything.
    """
    try:
        summary = generate_summary(prompt.prompt, model, tokenizer)
    except Exception as e:
        # Keep the full traceback in the server logs; the client only gets str(e).
        logging.getLogger(__name__).exception("Summary generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}
@app.post("/classification")
def classification_endpoint(input: ClassificationInput):
    """Classify ``input.data`` with the module-level classifier and vectorizer.

    Registered as ``POST /classification``. Without the decorator the function
    existed but was never exposed as a route.

    Args:
        input: Request body carrying the text to classify. The name shadows
            the ``input`` builtin but is kept so existing callers still work.

    Returns:
        ``{"category": <value returned by classify_text>}``.

    Raises:
        HTTPException: status 500 with ``str(error)`` as ``detail`` if
            classification raises anything.
    """
    try:
        category = classify_text(input.data, classifier, vectorizer)
    except Exception as e:
        # Keep the full traceback in the server logs; the client only gets str(e).
        logging.getLogger(__name__).exception("Text classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on every interface, port 8000.
    import uvicorn

    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)