File size: 2,383 Bytes
de0cb94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9ac1af
de0cb94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dataclasses import dataclass
from typing import List, Dict, Any

import requests


class ClassificationError(Exception):
    """Raised when text could not be classified by the remote endpoint."""


@dataclass
class Classification:
    """One contiguous labelled span of the classified text."""

    entity: str  # entity type with any B-/I- tag prefix removed, e.g. 'Evaluation'
    start: int   # start offset of the span, as reported by the endpoint
    end: int     # end offset of the span, as reported by the endpoint

    def dict(self) -> Dict[str, Any]:
        """Return the span as a plain mapping keyed 'entity', 'start', 'end' (in that order)."""
        return {field_name: getattr(self, field_name) for field_name in ('entity', 'start', 'end')}


class Classificator:
    """
    Client for a remote token-classification endpoint.

    The text is POSTed to the configured endpoint. The per-token predictions
    it returns are merged into contiguous Classification spans.
    """

    # Seconds requests waits on the endpoint when the config has no 'timeout'.
    # Without any timeout, requests.post can block indefinitely.
    DEFAULT_TIMEOUT = 60

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the classificator with the given configuration.

        :param config: mapping read by send_request. Required keys:
            'endpoint_url' (URL the text is POSTed to) and
            'auth_endpoint_token' (sent verbatim as the Authorization header).
            Optional key: 'timeout' (seconds, passed to requests.post;
            defaults to DEFAULT_TIMEOUT, and None disables the timeout).
        """
        self._config = config

    def classify(self, text: str) -> List[Classification]:
        """
        Classify ``text`` and return the merged entity spans.

        :raises ClassificationError: if the endpoint request fails or its
            response is not a list of token predictions.
        """
        raw_data = self.send_request(text)
        return self.post_process(raw_data)

    def send_request(self, text: str) -> List[Dict[str, Any]]:
        """
        POST ``text`` to the endpoint and return the decoded JSON payload.

        The payload is expected to be a list of per-token prediction
        dictionaries. See post_process for the keys that are used.

        :raises ClassificationError: on any failure while sending the request
            or decoding the response (network error, timeout, non-2xx
            status, invalid JSON), or when the payload is not a list. The
            underlying exception is chained as ``__cause__``.
        """
        headers = {
            'Authorization': self._config['auth_endpoint_token'],
            'Content-Type': 'application/json',
        }
        try:
            response = requests.post(
                self._config['endpoint_url'],
                headers=headers,
                json={'inputs': text},
                timeout=self._config.get('timeout', self.DEFAULT_TIMEOUT),
            )
            # Surface HTTP errors here, rather than handing an error body on
            # to post_process as if it were predictions.
            response.raise_for_status()
            payload = response.json()
        except Exception as err:
            raise ClassificationError('Classification failed') from err

        if not isinstance(payload, list):
            raise ClassificationError(
                f'Classification failed: expected a list of token predictions, got {payload!r:.200}'
            )
        return payload

    @staticmethod
    def post_process(raw_data: List[Dict[str, Any]]) -> List[Classification]:
        """
        Merge per-token predictions into contiguous entity spans.

        raw_data is a list of dictionaries such as
            {'entity': 'B-Evaluation', 'score': 0.86011535, 'index': 1, 'word': 'Things', 'start': 0, 'end': 6}
        Only the 'entity', 'start' and 'end' keys are read. Items are assumed
        to be in text order.

        Tags are decoded with the BIO scheme:
        - 'B-X' always starts a new span of type X, so adjacent entities of
          the same type stay separate;
        - 'I-X' extends the current span if it has type X, otherwise it
          starts a new one (tolerates an I- tag with no preceding B-);
        - a tag without a B-/I- prefix is used whole as the type and treated
          like an I- tag;
        - 'O' (outside) closes the current span and produces nothing.

        The result is a list of classifications such as
            Classification(entity='Evaluation', start=0, end=6)
        """
        classifications = []
        current = None  # the open span that a following same-type tag may extend
        for item in raw_data:
            tag = item['entity']
            if tag == 'O':
                # Not part of any entity: nothing may be merged across it.
                current = None
                continue
            if tag.startswith(('B-', 'I-')):
                begins, entity = tag.startswith('B-'), tag[2:]
            else:
                begins, entity = False, tag
            if not begins and current is not None and current.entity == entity:
                current.end = item['end']
            else:
                current = Classification(entity=entity, start=item['start'], end=item['end'])
                classifications.append(current)
        return classifications