File size: 3,160 Bytes
acb3380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os

from huggingface_hub.inference_api import (
    InferenceApi,  # type: ignore[import] # FIX ME
)

from ctm.messengers.messenger_base import BaseMessenger
from ctm.processors.processor_base import BaseProcessor


@BaseProcessor.register_processor("roberta_text_sentiment_processor")  # type: ignore[no-untyped-call] # FIX ME
class RobertaTextSentimentProcessor(BaseProcessor):
    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def] # FIX ME
        self.init_processor()  # type: ignore[no-untyped-call] # FIX ME

    def init_processor(self):  # type: ignore[no-untyped-def] # FIX ME
        self.model = InferenceApi(
            token=os.environ["HF_TOKEN"],
            repo_id="cardiffnlp/twitter-roberta-base-sentiment-latest",
        )
        self.messenger = BaseMessenger("roberta_text_sentiment_messenger")  # type: ignore[no-untyped-call] # FIX ME
        return

    def update_info(self, feedback: str):  # type: ignore[no-untyped-def] # FIX ME
        self.messenger.add_assistant_message(feedback)

    def ask_info(  # type: ignore[override] # FIX ME
        self,
        query: str,
        context: str = None,  # type: ignore[assignment] # FIX ME
        image_path: str = None,  # type: ignore[assignment] # FIX ME
        audio_path: str = None,  # type: ignore[assignment] # FIX ME
        video_path: str = None,  # type: ignore[assignment] # FIX ME
    ) -> str:
        if self.messenger.check_iter_round_num() == 0:  # type: ignore[no-untyped-call] # FIX ME
            self.messenger.add_user_message(context)

        response = self.model(self.messenger.get_messages())  # type: ignore[no-untyped-call] # FIX ME
        results = response[0]
        # choose the label with the highest score
        pos_score = 0
        neg_score = 0
        neutral_score = 0
        for result in results:
            if result["label"] == "POSITIVE":
                pos_score = result["score"]
            elif result["label"] == "NEGATIVE":
                neg_score = result["score"]
            else:
                neutral_score = result["score"]
        if max(pos_score, neg_score, neutral_score) == pos_score:
            return "This text is positive."
        elif max(pos_score, neg_score, neutral_score) == neg_score:
            return "This text is negative."
        else:
            return "This text is neutral."


if __name__ == "__main__":
    processor = BaseProcessor("roberta_text_sentiment_processor")  # type: ignore[no-untyped-call] # FIX ME
    image_path = "../ctmai-test1.png"
    text: str = (
        "In a shocking turn of events, Hugging Face has released a new version of Transformers "
        "that brings several enhancements and bug fixes. Users are thrilled with the improvements "
        "and are finding the new version to be significantly better than the previous one. "
        "The Hugging Face team is thankful for the community's support and continues to work "
        "towards making the library the best it can be."
    )
    label = processor.ask_info(query=None, context=text, image_path=image_path)  # type: ignore[no-untyped-call] # FIX ME
    print(label)