gyesibiney commited on
Commit
b7ac8b4
·
1 Parent(s): 1c1ef2c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -18
main.py CHANGED
@@ -1,6 +1,4 @@
1
- #from fastapi import FastAPI, HTTPException, Query
2
- #import pandas as pd
3
- from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
6
 
@@ -14,32 +12,33 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  # Create a sentiment analysis pipeline
15
  sentiment = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
16
 
 
 
 
 
 
 
17
  # Define a request body model
18
  class SentimentRequest(BaseModel):
19
  text: str
20
 
21
  # Define a response model
22
  class SentimentResponse(BaseModel):
23
- sentiment: str
24
  score: float
25
 
26
- # Create an endpoint for sentiment analysis
27
- @app.post("/sentiment/")
28
- async def analyze_sentiment(request: SentimentRequest):
29
- input_text = request.text
30
- result = sentiment(input_text)
31
  sentiment_label = result[0]["label"]
32
  sentiment_score = result[0]["score"]
33
-
34
- if sentiment_label == "LABEL_1":
35
- sentiment_label = "positive"
36
- elif sentiment_label == "LABEL_0":
37
- sentiment_label = "neutral"
38
- else:
39
- sentiment_label = "negative"
40
-
41
- return SentimentResponse(sentiment=sentiment_label.capitalize(), score=sentiment_score)
42
 
43
  if __name__ == "__main__":
44
  import uvicorn
45
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
+ from fastapi import FastAPI, Query
 
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
 
 
12
  # Create a sentiment analysis pipeline
13
  sentiment = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
14
 
15
+ # Create a dictionary to map sentiment labels to binary values
16
+ sentiment_label_mapping = {
17
+ "LABEL_1": 1, # Positive
18
+ "LABEL_0": 0, # Negative
19
+ }
20
+
21
  # Define a request body model
22
  class SentimentRequest(BaseModel):
23
  text: str
24
 
25
  # Define a response model
26
  class SentimentResponse(BaseModel):
27
+ sentiment: int # 1 for positive, 0 for negative
28
  score: float
29
 
30
+ # Create an endpoint for sentiment analysis with query parameter
31
+ @app.get("/sentiment/")
32
+ async def analyze_sentiment(text: str = Query(..., description="Input text for sentiment analysis")):
33
+ result = sentiment(text)
 
34
  sentiment_label = result[0]["label"]
35
  sentiment_score = result[0]["score"]
36
+
37
+ sentiment_value = sentiment_label_mapping.get(sentiment_label, -1) # Default to -1 for unknown labels
38
+
39
+ return SentimentResponse(sentiment=sentiment_value, score=sentiment_score)
 
 
 
 
 
40
 
41
  if __name__ == "__main__":
42
  import uvicorn
43
  uvicorn.run(app, host="0.0.0.0", port=8000)
44
+