Spaces:
Running
Running
File size: 4,841 Bytes
fb17746 c402b7d fb17746 c402b7d f611cc3 c402b7d 061fd19 f611cc3 c402b7d f611cc3 c402b7d f611cc3 c402b7d f611cc3 c402b7d 061fd19 c402b7d 061fd19 c402b7d f611cc3 c402b7d f611cc3 c402b7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
"""
Provides a robust, asynchronous SentimentAnalyzer class.
This module communicates with the Hugging Face Inference API to perform sentiment
analysis without requiring local model downloads. It's designed for use within
an asynchronous application like FastAPI.
"""
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from typing import TypedDict

import httpx
# --- Configuration & Models ---
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
# Define the expected structure of a result payload for type hinting
class SentimentResult(TypedDict):
id: int
text: str
result: dict[str, str | float]
# --- Main Class: SentimentAnalyzer ---
class SentimentAnalyzer:
    """
    Manages sentiment analysis requests to the Hugging Face Inference API.

    This class handles asynchronous API communication, manages a result queue for
    Server-Sent Events (SSE), and encapsulates all related state and logic.

    NOTE: ``result_queue`` is a single ``asyncio.Queue``. Each published result is
    removed by exactly ONE ``stream_results`` consumer, so several concurrent SSE
    streams each receive only a subset of the results (this is not a broadcast).
    """

    HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"
    # Per-request timeout in seconds for Inference API calls; override on a
    # subclass or instance if the default is unsuitable.
    REQUEST_TIMEOUT = 20.0

    def __init__(self, client: httpx.AsyncClient, api_token: str | None = None):
        """
        Initializes the SentimentAnalyzer.

        Args:
            client: An instance of httpx.AsyncClient for making API calls.
            api_token: The Hugging Face API token. When omitted (or empty) the
                ``HF_API_TOKEN`` environment variable is used instead.

        Raises:
            ValueError: If neither ``api_token`` nor ``HF_API_TOKEN`` provides a
                non-empty token.
        """
        self.client = client
        self.api_token = api_token or os.getenv("HF_API_TOKEN")
        if not self.api_token:
            raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        # A queue is the ideal structure for a producer-consumer pattern,
        # where the API endpoint is the producer and SSE streamers are consumers.
        self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue()

    @staticmethod
    def _pick_top_prediction(data: object) -> dict:
        """
        Return the highest-scoring prediction from a decoded API response.

        The model is expected to return a list containing a non-empty list of
        result dicts, each carrying ``label`` and ``score`` keys.

        Raises:
            ValueError: If ``data`` does not have that nested, non-empty shape or
                any candidate is not a dict.
            TypeError: If candidate scores cannot be compared with each other.
        """
        if not (isinstance(data, list) and data and isinstance(data[0], list) and data[0]):
            raise ValueError(f"Unexpected API response format: {data}")
        candidates = data[0]
        if not all(isinstance(candidate, dict) for candidate in candidates):
            raise ValueError(f"Unexpected API response format: {data}")
        # Select by score rather than relying on the API's result ordering.
        return max(candidates, key=lambda candidate: candidate.get("score", 0.0))

    async def compute_and_publish(self, text: str, request_id: int) -> None:
        """
        Performs sentiment analysis via an external API and places the result
        into a queue for consumption by SSE streams.

        A payload is published for every HTTP, network, or parsing outcome. On
        success ``result`` holds ``label`` and ``score`` (rounded to 4 places).
        On failure it holds ``label="ERROR"``, ``score=0.0`` and an ``error``
        message.

        Args:
            text: The input text to analyze.
            request_id: A unique identifier for this request.
        """
        analysis_result: dict[str, str | float] = {
            "label": "ERROR",
            "score": 0.0,
            "error": "Unknown failure",
        }
        try:
            response = await self.client.post(
                self.HF_API_URL,
                headers=self.headers,
                json={"inputs": text, "options": {"wait_for_model": True}},
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            # response.json() raises a ValueError subclass on a non-JSON body.
            best = self._pick_top_prediction(response.json())
            analysis_result = {
                "label": best["label"],
                "score": round(float(best.get("score", 0.0)), 4),
            }
            logger.info("Sentiment computed for request #%d", request_id)
        except httpx.HTTPStatusError as e:
            error_msg = f"API returned status {e.response.status_code}"
            logger.error("Sentiment API error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except httpx.RequestError as e:
            error_msg = f"Network request failed: {e}"
            logger.error("Sentiment network error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except (ValueError, KeyError, TypeError) as e:
            # TypeError covers a missing/non-numeric score (e.g. None).
            error_msg = f"Failed to parse API response: {e}"
            logger.error("Sentiment parsing error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg

        # Always publish a result to the queue, even if it's an error state
        payload: SentimentResult = {
            "id": request_id,
            "text": text,
            "result": analysis_result,
        }
        await self.result_queue.put(payload)

    async def stream_results(self) -> AsyncIterator[SentimentResult]:
        """
        An async generator that yields new results as they become available.

        This is the consumer part of the pattern. It loops until the consuming
        task is cancelled or the generator is closed. Cancellation is logged and
        then re-raised so the caller's task really ends up cancelled.
        """
        while True:
            try:
                # Suspends until a producer publishes a result.
                result = await self.result_queue.get()
            except asyncio.CancelledError:
                logger.info("Result stream has been cancelled.")
                raise
            try:
                yield result
            finally:
                # Mark the item done even when the consumer abandons iteration at
                # this yield (e.g. aclose() on client disconnect), so that
                # result_queue.join() cannot hang on an unfinished item.
                self.result_queue.task_done()