tharu22 commited on
Commit
ed1f32e
·
1 Parent(s): 563d648
Files changed (1) hide show
  1. main.py +16 -54
main.py CHANGED
@@ -1,73 +1,35 @@
1
  # app/main.py
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- import numpy as np
5
- from services.sms_service import classify_sms, load_trained_model
6
- from schemas.input_schemas import CosineSimilarityInput, CosineSimilarityOutput
7
- from schemas.input_schemas import EmbeddingInput, EmbeddingOutput
8
 
9
- # Initialize FastAPI
10
  app = FastAPI()
11
 
12
- # Load the models from the 'models' folder
13
- model, vectorizer = load_trained_model()
14
-
15
- # Function to compute cosine similarity
16
- def cosine_similarity(vec1, vec2):
17
- """
18
- Compute cosine similarity between two vectors.
19
- """
20
- norm1 = np.linalg.norm(vec1)
21
- norm2 = np.linalg.norm(vec2)
22
- if norm1 == 0 or norm2 == 0:
23
- return 0.0 # Prevent division by zero
24
- return np.dot(vec1, vec2) / (norm1 * norm2)
25
-
26
  # 🚀 1️⃣ Homepage Endpoint
27
  @app.get("/")
28
  async def home():
29
  return {"message": "Welcome to SMS Classification API"}
30
 
31
- # 📩 2️⃣ SMS Classification Endpoint
32
- class MessageInput(BaseModel):
33
- message: str
34
-
35
- @app.post("/predict_label/")
36
- async def classify_sms_endpoint(input_data: MessageInput):
37
- """
38
- Classify an SMS as either 'Transaction' or 'Offer'.
39
- """
40
  try:
41
- return classify_sms(input_data.message, model, vectorizer)
42
  except Exception as e:
43
- raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
44
 
45
- # 🔢 3️⃣ Cosine Similarity Endpoint
46
- @app.post("/cosine_similarity/", response_model=CosineSimilarityOutput)
47
- async def compute_similarity(input_data: CosineSimilarityInput):
48
- """
49
- Compute cosine similarity between two input texts.
50
- """
51
  try:
52
- # Transform the input texts using the TF-IDF vectorizer
53
- text1_vectorized = vectorizer.transform([input_data.text1])
54
- text2_vectorized = vectorizer.transform([input_data.text2])
55
-
56
- # Compute the cosine similarity between the two text embeddings
57
- similarity = cosine_similarity(text1_vectorized.toarray(), text2_vectorized.toarray())
58
- return CosineSimilarityOutput(cosine_similarity=round(float(similarity), 4))
59
  except Exception as e:
60
- raise HTTPException(status_code=500, detail=f"Error computing similarity: {str(e)}")
61
 
62
- # 🧠 4️⃣ Get Embedding of Text Message
63
- @app.post("/get_embedding/", response_model=EmbeddingOutput)
64
- async def get_embedding(input_data: EmbeddingInput):
65
- """
66
- Get the embedding (vector representation) of an input text message.
67
- """
68
  try:
69
- # Transform the input text using the TF-IDF vectorizer
70
- text_embedding = vectorizer.transform([input_data.message]).toarray().tolist()
71
- return EmbeddingOutput(embedding=text_embedding[0])
72
  except Exception as e:
73
- raise HTTPException(status_code=500, detail=f"Error generating embedding: {str(e)}")
 
1
  # app/main.py
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
+ from services.sms_service import predict_label, compute_cosine_similarity, compute_embeddings
 
 
 
5
 
 
6
  app = FastAPI()
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # 🚀 1️⃣ Homepage Endpoint
9
  @app.get("/")
10
  async def home():
11
  return {"message": "Welcome to SMS Classification API"}
12
 
13
+ # 🔢 2️⃣ Cosine Similarity Endpoint
14
+ @app.post("/cosine_similarity")
15
+ async def get_cosine_similarity(input_data: BaseModel):
 
 
 
 
 
 
16
  try:
17
+ return await compute_cosine_similarity(input_data.text1, input_data.text2)
18
  except Exception as e:
19
+ raise HTTPException(status_code=500, detail=f"Error computing similarity: {str(e)}")
20
 
21
+ # 📩 3️⃣ SMS Classification Endpoint
22
+ @app.post("/predict_label")
23
+ async def classify_message(input_data: BaseModel):
 
 
 
24
  try:
25
+ return await predict_label(input_data.message)
 
 
 
 
 
 
26
  except Exception as e:
27
+ raise HTTPException(status_code=500, detail=f"Error predicting label: {str(e)}")
28
 
29
+ # 📊 4️⃣ Text Embedding Endpoint
30
+ @app.post("/compute_embeddings")
31
+ async def get_embeddings(input_data: BaseModel):
 
 
 
32
  try:
33
+ return await compute_embeddings(input_data.message)
 
 
34
  except Exception as e:
35
+ raise HTTPException(status_code=500, detail=f"Error computing embeddings: {str(e)}")