assignment / app /main.py
sameer9's picture
Add FastAPI app and Docker configuration
e5b2387
raw
history blame
1.55 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from app.model_utils import load_model_and_tokenizer, generate_summary
from app.classifier import train_classifier, classify_text
# Application instance; the route decorators further down register on it.
app = FastAPI()
# Load model and tokenizer for the /rag endpoint.
# NOTE: this call sits at module level, so it runs once when the module is
# imported, and the resulting objects are shared by every request.
# NOTE(review): the name looks like a Hugging Face Hub model id — confirm
# against load_model_and_tokenizer.
model_name = "sshleifer/distilbart-cnn-6-6" # Example model
model, tokenizer = load_model_and_tokenizer(model_name)
# Dummy data and classifier for the /classification endpoint.
# Each entry is a (text, label) pair. The list is handed to train_classifier
# at import time, which returns a (classifier, vectorizer) pair that the
# endpoint later passes to classify_text.
# NOTE(review): five hand-written samples are placeholder data only.
dummy_data = [
    ("I feel very sad and hopeless.", "Depression"),
    ("I have trouble sleeping at night.", "Insomnia"),
    ("I am constantly worrying about everything.", "Anxiety"),
    ("I feel energetic and happy.", "Happiness"),
    ("My mood swings a lot and I feel irritable.", "Mood Disorder")
]
classifier, vectorizer = train_classifier(dummy_data)
class Prompt(BaseModel):
    """Request body accepted by the /rag endpoint."""

    # Text that the endpoint forwards to generate_summary.
    prompt: str
class ClassificationInput(BaseModel):
    """Request body accepted by the /classification endpoint."""

    # Text that the endpoint forwards to classify_text.
    data: str
@app.post("/rag")
def rag_endpoint(prompt: Prompt):
    """Generate a summary for the submitted prompt.

    Args:
        prompt: Request body; its ``prompt`` field is passed to
            ``generate_summary`` together with the module-level ``model``
            and ``tokenizer``.

    Returns:
        A dict ``{"summary": <value returned by generate_summary>}``.

    Raises:
        HTTPException: Re-raised unchanged when the helper itself raised
            one; for any other exception, status 500 with the original
            error text as ``detail``.
    """
    try:
        # Keep the try body to the single call that can fail.
        summary = generate_summary(prompt.prompt, model, tokenizer)
    except HTTPException:
        # A deliberate HTTP error must keep its own status code rather than
        # being re-wrapped as a generic 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}
@app.post("/classification")
def classification_endpoint(input: ClassificationInput):
    """Classify the submitted text with the module-level classifier.

    Args:
        input: Request body; its ``data`` field is passed to
            ``classify_text`` together with the module-level ``classifier``
            and ``vectorizer``. The parameter name shadows the ``input``
            builtin but is kept so the handler signature stays unchanged.

    Returns:
        A dict ``{"category": <value returned by classify_text>}``.

    Raises:
        HTTPException: Re-raised unchanged when the helper itself raised
            one; for any other exception, status 500 with the original
            error text as ``detail``.
    """
    try:
        # Keep the try body to the single call that can fail.
        category = classify_text(input.data, classifier, vectorizer)
    except HTTPException:
        # A deliberate HTTP error must keep its own status code rather than
        # being re-wrapped as a generic 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category}
# Script entry point: start a uvicorn server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # Imported here so uvicorn is only required for direct execution.
    import uvicorn
    # NOTE(review): host 0.0.0.0 listens on all interfaces, presumably so the
    # server is reachable from outside a container — confirm against the
    # deployment setup.
    uvicorn.run(app, host="0.0.0.0", port=8000)