tharu22 commited on
Commit
ba19e1d
·
1 Parent(s): ed1f32e
Files changed (1) hide show
  1. main.py +4 -5
main.py CHANGED
@@ -1,7 +1,6 @@
1
- # app/main.py
2
  from fastapi import FastAPI, HTTPException
3
- from pydantic import BaseModel
4
  from services.sms_service import predict_label, compute_cosine_similarity, compute_embeddings
 
5
 
6
  app = FastAPI()
7
 
@@ -12,7 +11,7 @@ async def home():
12
 
13
  # 🔢 2️⃣ Cosine Similarity Endpoint
14
  @app.post("/cosine_similarity")
15
- async def get_cosine_similarity(input_data: BaseModel):
16
  try:
17
  return await compute_cosine_similarity(input_data.text1, input_data.text2)
18
  except Exception as e:
@@ -20,7 +19,7 @@ async def get_cosine_similarity(input_data: BaseModel):
20
 
21
  # 📩 3️⃣ SMS Classification Endpoint
22
  @app.post("/predict_label")
23
- async def classify_message(input_data: BaseModel):
24
  try:
25
  return await predict_label(input_data.message)
26
  except Exception as e:
@@ -28,7 +27,7 @@ async def classify_message(input_data: BaseModel):
28
 
29
  # 📊 4️⃣ Text Embedding Endpoint
30
  @app.post("/compute_embeddings")
31
- async def get_embeddings(input_data: BaseModel):
32
  try:
33
  return await compute_embeddings(input_data.message)
34
  except Exception as e:
 
 
1
  from fastapi import FastAPI, HTTPException
 
2
  from services.sms_service import predict_label, compute_cosine_similarity, compute_embeddings
3
+ from schemas.input_schemas import CosineSimilarityInput, MessageInput, EmbeddingInput
4
 
5
  app = FastAPI()
6
 
 
11
 
12
  # 🔢 2️⃣ Cosine Similarity Endpoint
13
  @app.post("/cosine_similarity")
14
+ async def get_cosine_similarity(input_data: CosineSimilarityInput):
15
  try:
16
  return await compute_cosine_similarity(input_data.text1, input_data.text2)
17
  except Exception as e:
 
19
 
20
  # 📩 3️⃣ SMS Classification Endpoint
21
  @app.post("/predict_label")
22
+ async def classify_message(input_data: MessageInput):
23
  try:
24
  return await predict_label(input_data.message)
25
  except Exception as e:
 
27
 
28
  # 📊 4️⃣ Text Embedding Endpoint
29
  @app.post("/compute_embeddings")
30
+ async def get_embeddings(input_data: EmbeddingInput):
31
  try:
32
  return await compute_embeddings(input_data.message)
33
  except Exception as e: