mgbam committed on
Commit
c402b7d
·
verified ·
1 Parent(s): 0e87c05

Update app/sentiment.py

Browse files
Files changed (1) hide show
  1. app/sentiment.py +108 -36
app/sentiment.py CHANGED
@@ -1,49 +1,121 @@
1
  """
2
- Sentiment analysis module using Hugging Face Inference API to avoid local model downloads.
 
 
 
 
3
  """
4
- import os
5
- import hashlib
6
  import logging
7
- from functools import lru_cache
 
 
8
  import httpx
9
 
10
- # Environment variables (set HF_API_TOKEN in your Space's Settings)
11
- HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")
12
- API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"
13
- HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"}
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # In-memory cache for latest sentiment
16
- class SentimentCache:
17
- latest_id: int = 0
18
- latest_result: dict = {}
19
 
20
- @classmethod
21
- def _hash(cls, text: str) -> str:
22
- return hashlib.sha256(text.encode()).hexdigest()
23
 
24
- @classmethod
25
- @lru_cache(maxsize=128)
26
- def _analyze(cls, text: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  try:
28
- response = httpx.post(API_URL, headers=HEADERS, json={"inputs": text}, timeout=20)
 
 
 
 
 
29
  response.raise_for_status()
30
  data = response.json()
31
- # Expecting list of {label, score}
32
- if isinstance(data, list) and data:
33
- return data[0]
34
- raise ValueError("Unexpected response format: %s" % data)
35
- except Exception as e:
36
- logging.error("❌ Sentiment API error: %s", e)
37
- return {"label": "ERROR", "score": 0.0}
38
-
39
- @classmethod
40
- def compute(cls, text: str):
41
- """Trigger sentiment inference via API and update latest result."""
42
- res = cls._analyze(text)
43
- cls.latest_id += 1
44
- cls.latest_result = {
 
 
 
 
 
 
 
 
 
 
 
 
45
  "text": text,
46
- "label": res.get("label"),
47
- "score": round(res.get("score", 0.0), 4)
48
  }
49
- logging.info("✅ Sentiment computed: %s", cls.latest_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Provides a robust, asynchronous SentimentAnalyzer class.
3
+
4
+ This module communicates with the Hugging Face Inference API to perform sentiment
5
+ analysis without requiring local model downloads. It's designed for use within
6
+ an asynchronous application like FastAPI.
7
  """
8
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from typing import TypedDict

import httpx
14
 
15
+ # --- Configuration & Models ---
16
+
17
+ # Configure logging for this module
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Define the expected structure of a result payload for type hinting
21
+ class SentimentResult(TypedDict):
22
+ id: int
23
+ text: str
24
+ result: dict[str, str | float]
25
+
26
+
27
+ # --- Main Class: SentimentAnalyzer ---
28
+ class SentimentAnalyzer:
29
+ """
30
+ Manages sentiment analysis requests to the Hugging Face Inference API.
31
 
32
+ This class handles asynchronous API communication, manages a result queue for
33
+ Server-Sent Events (SSE), and encapsulates all related state and logic.
34
+ """
 
35
 
36
+ HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"
 
 
37
 
38
+ def __init__(self, client: httpx.AsyncClient, api_token: str | None = None):
39
+ """
40
+ Initializes the SentimentAnalyzer.
41
+
42
+ Args:
43
+ client: An instance of httpx.AsyncClient for making API calls.
44
+ api_token: The Hugging Face API token.
45
+ """
46
+ self.client = client
47
+ self.api_token = api_token or os.getenv("HF_API_TOKEN")
48
+
49
+ if not self.api_token:
50
+ raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
51
+
52
+ self.headers = {"Authorization": f"Bearer {self.api_token}"}
53
+
54
+ # A queue is the ideal structure for a producer-consumer pattern,
55
+ # where the API endpoint is the producer and SSE streamers are consumers.
56
+ self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue()
57
+
58
+ async def compute_and_publish(self, text: str, request_id: int) -> None:
59
+ """
60
+ Performs sentiment analysis via an external API and places the result
61
+ into a queue for consumption by SSE streams.
62
+
63
+ Args:
64
+ text: The input text to analyze.
65
+ request_id: A unique identifier for this request.
66
+ """
67
+ analysis_result = {"label": "ERROR", "score": 0.0, "error": "Unknown failure"}
68
  try:
69
+ response = await self.client.post(
70
+ self.HF_API_URL,
71
+ headers=self.headers,
72
+ json={"inputs": text, "options": {"wait_for_model": True}},
73
+ timeout=20.0
74
+ )
75
  response.raise_for_status()
76
  data = response.json()
77
+
78
+ # Validate the expected response structure from the Inference API
79
+ if isinstance(data, list) and data and isinstance(data[0], list) and data[0]:
80
+ # The model returns a list containing a list of results
81
+ res = data[0][0]
82
+ analysis_result = {"label": res.get("label"), "score": round(res.get("score", 0.0), 4)}
83
+ logger.info("βœ… Sentiment computed for request #%d", request_id)
84
+ else:
85
+ raise ValueError(f"Unexpected API response format: {data}")
86
+
87
+ except httpx.HTTPStatusError as e:
88
+ error_msg = f"API returned status {e.response.status_code}"
89
+ logger.error("❌ Sentiment API error for request #%d: %s", request_id, error_msg)
90
+ analysis_result["error"] = error_msg
91
+ except httpx.RequestError as e:
92
+ error_msg = f"Network request failed: {e}"
93
+ logger.error("❌ Sentiment network error for request #%d: %s", request_id, error_msg)
94
+ analysis_result["error"] = error_msg
95
+ except (ValueError, KeyError) as e:
96
+ error_msg = f"Failed to parse API response: {e}"
97
+ logger.error("❌ Sentiment parsing error for request #%d: %s", request_id, error_msg)
98
+ analysis_result["error"] = error_msg
99
+
100
+ # Always publish a result to the queue, even if it's an error state
101
+ payload: SentimentResult = {
102
+ "id": request_id,
103
  "text": text,
104
+ "result": analysis_result
 
105
  }
106
+ await self.result_queue.put(payload)
107
+
108
+ async def stream_results(self) -> SentimentResult:
109
+ """
110
+ An async generator that yields new results as they become available.
111
+ This is the consumer part of the pattern.
112
+ """
113
+ while True:
114
+ try:
115
+ # This efficiently waits until an item is available in the queue
116
+ result = await self.result_queue.get()
117
+ yield result
118
+ self.result_queue.task_done()
119
+ except asyncio.CancelledError:
120
+ logger.info("Result stream has been cancelled.")
121
+ break