JohnDoee commited on
Commit
b147674
·
1 Parent(s): ada2814

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +38 -10
main.py CHANGED
@@ -1,29 +1,57 @@
1
  import os
2
- from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import pipeline
 
5
 
6
  # Set custom cache directory to avoid permission issues
7
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
8
 
9
  app = FastAPI()
10
 
11
- # Explicitly specify a model (avoid default selection)
12
- MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
13
- sentiment_pipeline = pipeline("sentiment-analysis", model=MODEL_NAME)
14
 
15
  class SentimentRequest(BaseModel):
16
  text: str
17
 
18
  class SentimentResponse(BaseModel):
19
- label: str
20
- score: float
 
 
 
 
 
 
 
 
21
 
22
  @app.get("/")
23
  def home():
24
  return {"message": "Sentiment Analysis API is running!"}
25
 
26
- @app.post("/predict/", response_model=SentimentResponse)
27
- def predict(request: SentimentRequest):
28
- result = sentiment_pipeline(request.text)
29
- return SentimentResponse(label=result[0]['label'], score=result[0]['score'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import pipeline
5
+ import langdetect
6
 
7
  # Set custom cache directory to avoid permission issues
8
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
9
 
10
  app = FastAPI()
11
 
12
+ # Load sentiment analysis models
13
+ multilingual_model = pipeline("sentiment-analysis", model="tabularisai/multilingual-sentiment-analysis")
14
+ english_model = pipeline("sentiment-analysis", model="siebert/sentiment-roberta-large-english")
15
 
16
  class SentimentRequest(BaseModel):
17
  text: str
18
 
19
  class SentimentResponse(BaseModel):
20
+ original_text: str
21
+ language_detected: str
22
+ sentiment: str
23
+ confidence_score: float
24
+
25
+ def detect_language(text: str) -> str:
26
+ try:
27
+ return langdetect.detect(text)
28
+ except:
29
+ return "unknown"
30
 
31
  @app.get("/")
32
  def home():
33
  return {"message": "Sentiment Analysis API is running!"}
34
 
35
+ @app.post("/analyze/", response_model=SentimentResponse)
36
+ def analyze_sentiment(request: SentimentRequest):
37
+ if not request.text:
38
+ raise HTTPException(status_code=400, detail="No text provided")
39
+
40
+ text = request.text
41
+ language = detect_language(text)
42
+
43
+ # Choose the appropriate model based on language
44
+ if language == "en":
45
+ result = english_model(text)
46
+ else:
47
+ result = multilingual_model(text)
48
+
49
+ sentiment = result[0]["label"].lower()
50
+ score = result[0]["score"]
51
+
52
+ return SentimentResponse(
53
+ original_text=text,
54
+ language_detected=language,
55
+ sentiment=sentiment,
56
+ confidence_score=score
57
+ )