CryptoSentinel_AI / app /sentiment.py
mgbam's picture
Update app/sentiment.py
c402b7d verified
raw
history blame
4.84 kB
"""
Provides a robust, asynchronous SentimentAnalyzer class.
This module communicates with the Hugging Face Inference API to perform sentiment
analysis without requiring local model downloads. It's designed for use within
an asynchronous application like FastAPI.
"""
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from typing import TypedDict

import httpx
# --- Configuration & Models ---
# One logger per module, named after the module itself; this module attaches
# no handlers of its own.
logger = logging.getLogger(name=__name__)
# Expected structure of the payloads published by the analyzer, for type hinting.
class _SentimentScoreRequired(TypedDict):
    """Keys present in every analysis outcome, whether it succeeded or failed."""

    label: str  # model label on success, "ERROR" on failure
    score: float  # rounded to 4 decimals on success, 0.0 on failure


class SentimentScore(_SentimentScoreRequired, total=False):
    """Outcome of one sentiment analysis; ``error`` is present only on failure."""

    error: str  # human-readable reason the analysis failed


class SentimentResult(TypedDict):
    """A single payload placed on the analyzer's result queue."""

    id: int  # caller-supplied request identifier
    text: str  # the text that was analyzed
    result: SentimentScore
# --- Main Class: SentimentAnalyzer ---
class SentimentAnalyzer:
    """
    Manages sentiment analysis requests to the Hugging Face Inference API.

    This class handles asynchronous API communication, manages a result queue for
    Server-Sent Events (SSE), and encapsulates all related state and logic.

    Note: every result goes through one ``asyncio.Queue``, and ``Queue.get``
    removes the item it returns. If several ``stream_results`` consumers run at
    the same time, each result therefore reaches exactly one of them; results
    are not broadcast.
    """

    HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"

    def __init__(
        self,
        client: httpx.AsyncClient,
        api_token: str | None = None,
        timeout: float = 20.0,
    ):
        """
        Initializes the SentimentAnalyzer.

        Args:
            client: An instance of httpx.AsyncClient for making API calls. This
                class never closes it.
            api_token: The Hugging Face API token. When omitted (or empty), the
                ``HF_API_TOKEN`` environment variable is used instead.
            timeout: Timeout in seconds passed to each Inference API request.
                Defaults to 20 seconds.

        Raises:
            ValueError: If no token is available from either source.
        """
        self.client = client
        self.api_token = api_token or os.getenv("HF_API_TOKEN")
        if not self.api_token:
            raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        self.timeout = timeout
        # A queue is the ideal structure for a producer-consumer pattern,
        # where the API endpoint is the producer and SSE streamers are consumers.
        self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue()

    async def compute_and_publish(self, text: str, request_id: int) -> None:
        """
        Performs sentiment analysis via an external API and places the result
        into a queue for consumption by SSE streams.

        Exactly one payload is published per call. HTTP-status, network and
        response-parsing failures do not raise. Instead, the published result
        has ``label="ERROR"``, ``score=0.0`` and an ``error`` message.

        Args:
            text: The input text to analyze.
            request_id: A unique identifier for this request.
        """
        analysis_result: dict[str, str | float] = {
            "label": "ERROR",
            "score": 0.0,
            "error": "Unknown failure",
        }
        try:
            response = await self.client.post(
                self.HF_API_URL,
                headers=self.headers,
                json={"inputs": text, "options": {"wait_for_model": True}},
                timeout=self.timeout,
            )
            response.raise_for_status()
            analysis_result = self._parse_response(response.json())
            logger.info("✅ Sentiment computed for request #%d", request_id)
        except httpx.HTTPStatusError as e:
            error_msg = f"API returned status {e.response.status_code}"
            logger.error("❌ Sentiment API error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except httpx.RequestError as e:
            error_msg = f"Network request failed: {e}"
            logger.error("❌ Sentiment network error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except (ValueError, KeyError, TypeError) as e:
            # Invalid JSON surfaces as json.JSONDecodeError (a ValueError
            # subclass); structural problems are raised by _parse_response.
            error_msg = f"Failed to parse API response: {e}"
            logger.error("❌ Sentiment parsing error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg

        # Always publish a result to the queue, even if it's an error state
        payload: SentimentResult = {
            "id": request_id,
            "text": text,
            "result": analysis_result,
        }
        await self.result_queue.put(payload)

    @staticmethod
    def _parse_response(data: object) -> dict[str, str | float]:
        """
        Extract the highest-scoring label from an Inference API JSON payload.

        Accepts the nested shape ``[[{"label": ..., "score": ...}, ...]]`` as
        well as a flat ``[{"label": ..., "score": ...}, ...]`` list.

        Returns:
            ``{"label": <str>, "score": <float rounded to 4 places>}``.

        Raises:
            ValueError: If the payload matches neither shape, or an entry lacks
                a string ``label`` or a numeric ``score``.
        """
        candidates = data
        # Unwrap the outer list of the nested shape.
        if isinstance(candidates, list) and candidates and isinstance(candidates[0], list):
            candidates = candidates[0]
        if not isinstance(candidates, list) or not candidates:
            raise ValueError(f"Unexpected API response format: {data}")

        scored: list[tuple[str, float]] = []
        for entry in candidates:
            if not isinstance(entry, dict):
                raise ValueError(f"Unexpected API response format: {data}")
            label, score = entry.get("label"), entry.get("score")
            # bool is an int subclass, so reject it explicitly as a score.
            if not isinstance(label, str) or isinstance(score, bool) or not isinstance(score, (int, float)):
                raise ValueError(f"Malformed classification entry: {entry}")
            scored.append((label, float(score)))

        # Choose the best label by score instead of relying on list order.
        best_label, best_score = max(scored, key=lambda pair: pair[1])
        return {"label": best_label, "score": round(best_score, 4)}

    async def stream_results(self) -> AsyncIterator[SentimentResult]:
        """
        An async generator that yields new results as they become available.
        This is the consumer part of the pattern.

        It runs until cancelled. Cancellation is logged and then re-raised, so
        the consuming task really ends up cancelled instead of finishing
        normally.
        """
        try:
            while True:
                # Suspend until a producer publishes a result.
                result = await self.result_queue.get()
                try:
                    yield result
                finally:
                    # Balance get() even when the consumer stops (break/aclose)
                    # while suspended at the yield, so the queue's unfinished-task
                    # count stays correct for any Queue.join() caller.
                    self.result_queue.task_done()
        except asyncio.CancelledError:
            logger.info("Result stream has been cancelled.")
            raise