# NOTE(review): removed web-UI scrape artifact lines ("Spaces:" / "Running" / "Running")
# that preceded the module docstring.
""" | |
Provides a robust, asynchronous SentimentAnalyzer class. | |
This module communicates with the Hugging Face Inference API to perform sentiment | |
analysis without requiring local model downloads. It's designed for use within | |
an asynchronous application like FastAPI. | |
""" | |
import asyncio
import logging
import os
from collections.abc import AsyncIterator
# Optional/Union are used (instead of `X | None`) for Python 3.9 compatibility.
from typing import Optional, TypedDict, Union

import httpx
# --- Configuration & Models ---

# Module-level logger for this module; handlers and levels are expected to be
# configured by the host application, not here.
logger = logging.getLogger(__name__)
# Expected structure of one result payload, used for type hinting.
class SentimentResult(TypedDict):
    """A single sentiment-analysis payload published to the result queue."""

    # Unique identifier of the originating request.
    id: int
    # The text that was analyzed.
    text: str
    # Analysis outcome: "label"/"score" on success, plus an "error" message on
    # failure. Union (not `str | float`) keeps Python 3.9 compatibility.
    result: dict[str, Union[str, float]]
# --- Main Class: SentimentAnalyzer ---
class SentimentAnalyzer:
    """
    Manages sentiment analysis requests to the Hugging Face Inference API.

    This class handles asynchronous API communication, manages a result queue
    for Server-Sent Events (SSE), and encapsulates all related state and logic.
    """

    # Hosted inference endpoint; no local model download is required.
    HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"

    # `Optional[str]` (not `str | None`) keeps Python 3.9 compatibility.
    def __init__(self, client: httpx.AsyncClient, api_token: Optional[str] = None):
        """
        Initializes the SentimentAnalyzer.

        Args:
            client: An instance of httpx.AsyncClient for making API calls.
            api_token: The Hugging Face API token. Falls back to the
                HF_API_TOKEN environment variable when omitted.

        Raises:
            ValueError: If no token is supplied and HF_API_TOKEN is unset.
        """
        self.client = client
        self.api_token = api_token or os.getenv("HF_API_TOKEN")
        if not self.api_token:
            raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        # A queue is the ideal structure for a producer-consumer pattern,
        # where the API endpoint is the producer and SSE streamers are consumers.
        self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue()

    async def compute_and_publish(self, text: str, request_id: int) -> None:
        """
        Performs sentiment analysis via an external API and places the result
        into a queue for consumption by SSE streams.

        A payload is always published, even on failure; error states carry an
        "error" key so consumers can surface the problem.

        Args:
            text: The input text to analyze.
            request_id: A unique identifier for this request.
        """
        # Default to an error payload; overwritten on a successful parse.
        analysis_result: dict[str, Union[str, float]] = {
            "label": "ERROR",
            "score": 0.0,
            "error": "Unknown failure",
        }
        try:
            response = await self.client.post(
                self.HF_API_URL,
                headers=self.headers,
                # wait_for_model avoids 503s while the hosted model cold-starts.
                json={"inputs": text, "options": {"wait_for_model": True}},
                timeout=20.0,
            )
            response.raise_for_status()
            data = response.json()
            # The model returns a list containing a list of result dicts;
            # validate that full shape (including the inner dict) before
            # indexing, so a malformed element raises ValueError below rather
            # than an uncaught AttributeError.
            if (
                isinstance(data, list) and data
                and isinstance(data[0], list) and data[0]
                and isinstance(data[0][0], dict)
            ):
                res = data[0][0]
                analysis_result = {
                    "label": res.get("label", "UNKNOWN"),
                    "score": round(res.get("score", 0.0), 4),
                }
                # NOTE: garbled mojibake prefix removed from log messages.
                logger.info("Sentiment computed for request #%d", request_id)
            else:
                raise ValueError(f"Unexpected API response format: {data}")
        except httpx.HTTPStatusError as e:
            error_msg = f"API returned status {e.response.status_code}"
            logger.error("Sentiment API error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except httpx.RequestError as e:
            error_msg = f"Network request failed: {e}"
            logger.error("Sentiment network error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except (ValueError, KeyError) as e:
            error_msg = f"Failed to parse API response: {e}"
            logger.error("Sentiment parsing error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        # Always publish a result to the queue, even if it's an error state.
        payload: SentimentResult = {
            "id": request_id,
            "text": text,
            "result": analysis_result,
        }
        await self.result_queue.put(payload)

    async def stream_results(self) -> AsyncIterator[SentimentResult]:
        """
        An async generator that yields new results as they become available.

        This is the consumer side of the producer/consumer pattern; it waits
        cooperatively until the queue has an item and exits quietly when the
        surrounding task is cancelled.
        """
        while True:
            try:
                # Efficiently waits until an item is available in the queue.
                result = await self.result_queue.get()
                # Mark the item consumed BEFORE handing it to the caller, so
                # queue bookkeeping (e.g. Queue.join) stays correct even if
                # the consumer stops iterating at the yield point.
                self.result_queue.task_done()
                yield result
            except asyncio.CancelledError:
                logger.info("Result stream has been cancelled.")
                break