|
import logging

from fastapi import FastAPI, HTTPException

from schemas.input_schemas import CosineSimilarityInput, MessageInput, EmbeddingInput
from services.sms_service import predict_label, compute_cosine_similarity, compute_embeddings
|
|
|
# The ASGI application object; every @app.get / @app.post route in this
# module is registered on this single instance.
app: FastAPI = FastAPI()
|
|
|
|
|
@app.get("/")
async def home():
    """Root endpoint: return a fixed welcome payload.

    Takes no input and performs no work beyond building the response, so it
    can be used to check that the service is up and routing requests.
    """
    greeting = "Welcome to SMS Classification API"
    return {"message": greeting}
|
|
|
|
|
@app.post("/cosine_similarity")
async def get_cosine_similarity(input_data: CosineSimilarityInput):
    """Return the cosine similarity between two texts.

    Delegates to ``compute_cosine_similarity(input_data.text1, input_data.text2)``
    and returns its awaited result unchanged.

    Raises:
        HTTPException: re-raised untouched if the service raised one itself;
            any other exception is logged with its traceback and converted
            into a 500 response whose detail is
            ``"Error computing similarity: <message>"``.
    """
    try:
        return await compute_cosine_similarity(input_data.text1, input_data.text2)
    except HTTPException:
        # Preserve any status code the service chose deliberately (e.g. a 4xx)
        # instead of masking it behind a generic 500.
        raise
    except Exception as e:
        # Keep the full traceback in the server logs; the client only gets
        # the formatted detail string below.
        logging.getLogger(__name__).exception("Error computing similarity")
        # NOTE(review): echoing str(e) to clients may expose internal details;
        # kept for backward compatibility of the response body.
        raise HTTPException(
            status_code=500, detail=f"Error computing similarity: {e}"
        ) from e
|
|
|
|
|
@app.post("/predict_label")
async def classify_message(input_data: MessageInput):
    """Predict a label for a single message.

    Delegates to ``predict_label(input_data.message)`` and returns its awaited
    result unchanged.

    Raises:
        HTTPException: re-raised untouched if the service raised one itself;
            any other exception is logged with its traceback and converted
            into a 500 response whose detail is
            ``"Error predicting label: <message>"``.
    """
    try:
        return await predict_label(input_data.message)
    except HTTPException:
        # Preserve any status code the service chose deliberately (e.g. a 4xx)
        # instead of masking it behind a generic 500.
        raise
    except Exception as e:
        # Keep the full traceback in the server logs; the client only gets
        # the formatted detail string below.
        logging.getLogger(__name__).exception("Error predicting label")
        # NOTE(review): echoing str(e) to clients may expose internal details;
        # kept for backward compatibility of the response body.
        raise HTTPException(
            status_code=500, detail=f"Error predicting label: {e}"
        ) from e
|
|
|
|
|
@app.post("/compute_embeddings")
async def get_embeddings(input_data: EmbeddingInput):
    """Compute embeddings for a message.

    Delegates to ``compute_embeddings(input_data.message)`` and returns its
    awaited result unchanged.

    Raises:
        HTTPException: re-raised untouched if the service raised one itself;
            any other exception is logged with its traceback and converted
            into a 500 response whose detail is
            ``"Error computing embeddings: <message>"``.
    """
    try:
        return await compute_embeddings(input_data.message)
    except HTTPException:
        # Preserve any status code the service chose deliberately (e.g. a 4xx)
        # instead of masking it behind a generic 500.
        raise
    except Exception as e:
        # Keep the full traceback in the server logs; the client only gets
        # the formatted detail string below.
        logging.getLogger(__name__).exception("Error computing embeddings")
        # NOTE(review): echoing str(e) to clients may expose internal details;
        # kept for backward compatibility of the response body.
        raise HTTPException(
            status_code=500, detail=f"Error computing embeddings: {e}"
        ) from e
|
|