File size: 4,841 Bytes
fb17746
c402b7d
 
 
 
 
fb17746
c402b7d
f611cc3
c402b7d
 
 
061fd19
f611cc3
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f611cc3
c402b7d
 
 
f611cc3
c402b7d
f611cc3
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
061fd19
c402b7d
 
 
 
 
 
061fd19
 
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f611cc3
c402b7d
f611cc3
c402b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Provides a robust, asynchronous SentimentAnalyzer class.

This module communicates with the Hugging Face Inference API to perform sentiment
analysis without requiring local model downloads. It's designed for use within
an asynchronous application like FastAPI.
"""
import asyncio
import logging
import os
from collections.abc import AsyncIterator
from typing import TypedDict

import httpx

# --- Configuration & Models ---

# Module-scoped logger: named after this module so the host application can
# tune its level/handlers independently of other loggers.
logger = logging.getLogger(name=__name__)

# Define the expected structure of a result payload for type hinting
class SentimentResult(TypedDict):
    id: int
    text: str
    result: dict[str, str | float]


# --- Main Class: SentimentAnalyzer ---
class SentimentAnalyzer:
    """
    Manages sentiment analysis requests to the Hugging Face Inference API.

    This class handles asynchronous API communication, manages a result queue for
    Server-Sent Events (SSE), and encapsulates all related state and logic.

    Note: results live in a single ``asyncio.Queue``, so each published result is
    delivered to exactly one active consumer of ``stream_results``. With several
    concurrent streams, results are distributed among them, not broadcast.
    """

    HF_API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"

    def __init__(
        self,
        client: httpx.AsyncClient,
        api_token: str | None = None,
        *,
        timeout: float = 20.0,
    ):
        """
        Initializes the SentimentAnalyzer.

        Args:
            client: An instance of httpx.AsyncClient for making API calls.
            api_token: The Hugging Face API token. When omitted or empty, the
                HF_API_TOKEN environment variable is used instead.
            timeout: Per-request timeout in seconds for Inference API calls.

        Raises:
            ValueError: If no API token is supplied and HF_API_TOKEN is unset/empty.
        """
        self.client = client
        self.api_token = api_token or os.getenv("HF_API_TOKEN")

        if not self.api_token:
            raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        self.timeout = timeout

        # Producer-consumer hand-off: compute_and_publish() puts results in,
        # stream_results() takes them out.
        self.result_queue: asyncio.Queue[SentimentResult] = asyncio.Queue()

    @staticmethod
    def _parse_response(data: object) -> dict[str, str | float]:
        """
        Extract the highest-scoring label from an Inference API payload.

        Accepts either a nested list (``[[{"label": ..., "score": ...}, ...]]``)
        or a flat list of candidate dicts.

        Returns:
            A dict with "label" and "score" (rounded to 4 decimal places).

        Raises:
            ValueError: If the payload has neither shape or the best candidate
                has no label. Non-numeric scores surface as TypeError.
        """
        if isinstance(data, list) and data and isinstance(data[0], list):
            candidates = data[0]
        else:
            candidates = data

        if not (
            isinstance(candidates, list)
            and candidates
            and all(isinstance(c, dict) for c in candidates)
        ):
            raise ValueError(f"Unexpected API response format: {data}")

        # Pick the top score explicitly rather than trusting the API's ordering.
        best = max(candidates, key=lambda c: c.get("score", 0.0))
        label = best.get("label")
        if label is None:
            raise ValueError(f"Missing 'label' in API response: {best}")
        return {"label": label, "score": round(best.get("score", 0.0), 4)}

    async def compute_and_publish(self, text: str, request_id: int) -> None:
        """
        Performs sentiment analysis via an external API and places the result
        into a queue for consumption by SSE streams.

        HTTP-status, network and response-parsing failures are logged and
        published as an error payload (label "ERROR", score 0.0, plus an
        "error" message) instead of being raised.

        Args:
            text: The input text to analyze.
            request_id: A unique identifier for this request.
        """
        analysis_result: dict[str, str | float] = {"label": "ERROR", "score": 0.0, "error": "Unknown failure"}
        try:
            response = await self.client.post(
                self.HF_API_URL,
                headers=self.headers,
                json={"inputs": text, "options": {"wait_for_model": True}},
                timeout=self.timeout,
            )
            response.raise_for_status()
            # response.json() raises a ValueError subclass on invalid JSON.
            analysis_result = self._parse_response(response.json())
            logger.info("✅ Sentiment computed for request #%d", request_id)

        except httpx.HTTPStatusError as e:
            error_msg = f"API returned status {e.response.status_code}"
            logger.error("❌ Sentiment API error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except httpx.RequestError as e:
            error_msg = f"Network request failed: {e}"
            logger.error("❌ Sentiment network error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg
        except (ValueError, KeyError, TypeError, AttributeError, IndexError) as e:
            # Malformed payloads can fail in several ways (bad shape, non-numeric
            # score, ...); all of them must still produce a published error result.
            error_msg = f"Failed to parse API response: {e}"
            logger.error("❌ Sentiment parsing error for request #%d: %s", request_id, error_msg)
            analysis_result["error"] = error_msg

        # Always publish a result to the queue, even if it's an error state
        payload: SentimentResult = {
            "id": request_id,
            "text": text,
            "result": analysis_result,
        }
        await self.result_queue.put(payload)

    async def stream_results(self) -> AsyncIterator[SentimentResult]:
        """
        An async generator that yields new results as they become available.
        This is the consumer part of the pattern.

        Raises:
            asyncio.CancelledError: Propagated (after logging) when the consuming
                task is cancelled, so the cancellation is not lost.
        """
        try:
            while True:
                # Suspends until a producer puts an item on the queue.
                result = await self.result_queue.get()
                try:
                    yield result
                finally:
                    # Mark the item handled even if the consumer closes the
                    # generator while it is suspended at `yield`; otherwise
                    # result_queue.join() would never complete.
                    self.result_queue.task_done()
        except asyncio.CancelledError:
            logger.info("Result stream has been cancelled.")
            raise