File size: 8,201 Bytes
e09333c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# demo_phobert_gradio.py
# -*- coding: utf-8 -*-

import gradio as gr
import torch
import re
import json
import emoji
import numpy as np
from underthesea import word_tokenize

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification
)

###############################################################################
# TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
###############################################################################
emoji_mapping = {
    "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
    "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
    "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
    "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
    "🤑": "[satisfaction]",
    "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
    "😏": "[sarcasm]",
    "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
    "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
    "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
    "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
    "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
    "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
    "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
    "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
    "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
    "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}

###############################################################################
# HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
    processed_sentence = []
    for char in sentence:
        if char in emoji_mapping:
            processed_sentence.append(emoji_mapping[char])
        elif not emoji.is_emoji(char):
            processed_sentence.append(char)
    return ''.join(processed_sentence)

def remove_profanity(sentence):
    profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
    words = sentence.split()
    filtered = [w for w in words if w.lower() not in profane_words]
    return ' '.join(filtered)

def remove_special_characters(sentence):
    return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)

def normalize_whitespace(sentence):
    return ' '.join(sentence.split())

def remove_repeated_characters(sentence):
    return re.sub(r"(.)\1{2,}", r"\1", sentence)

def replace_numbers(sentence):
    return re.sub(r"\d+", "[number]", sentence)

def tokenize_underthesea(sentence):
    tokens = word_tokenize(sentence)
    return " ".join(tokens)

# Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except:
    abbreviations = {}

def preprocess_sentence(sentence):
    # hạ thấp
    sentence = sentence.lower()
    # thay thế emoji
    sentence = replace_emojis(sentence, emoji_mapping)
    # loại bỏ từ nhạy cảm
    sentence = remove_profanity(sentence)
    # bỏ ký tự đặc biệt
    sentence = remove_special_characters(sentence)
    # chuẩn hoá khoảng trắng
    sentence = normalize_whitespace(sentence)
    # thay thế viết tắt
    words = sentence.split()
    replaced = []
    for w in words:
        if w in abbreviations:
            replaced.append(" ".join(abbreviations[w]))
        else:
            replaced.append(w)
    sentence = " ".join(replaced)
    # bỏ bớt kí tự lặp
    sentence = remove_repeated_characters(sentence)
    # thay số thành [number]
    sentence = replace_numbers(sentence)
    # tokenize tiếng Việt
    sentence = tokenize_underthesea(sentence)
    return sentence

###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "/home/datpham/datpham/thesis-ngtram/phobert_results/checkpoint-17350"  # Đường dẫn đến folder checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)

# Mapping id to label theo thứ tự bạn cung cấp
custom_id2label = {
    0: 'Anger',
    1: 'Disgust',
    2: 'Enjoyment',
    3: 'Fear',
    4: 'Other',
    5: 'Sadness',
    6: 'Surprise'
}

# Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng
if hasattr(config, "id2label") and config.id2label:
    # Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping
    if all(label.startswith("LABEL_") for label in config.id2label.values()):
        id2label = custom_id2label
    else:
        id2label = {int(k): v for k, v in config.id2label.items()}
else:
    id2label = custom_id2label  # Sử dụng mapping mặc định nếu config không có id2label

print("id2label loaded:", id2label)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()

###############################################################################
# HÀM PREDICT
###############################################################################
# Mapping từ label đến thông điệp tương ứng
label2message = {
    'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
    'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
    'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}

def predict_text(text: str) -> str:
    """Tiền xử lý, token hoá và chạy model => trả về label và thông điệp."""
    text_proc = preprocess_sentence(text)
    inputs = tokenizer(
        [text_proc],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        pred_id = outputs.logits.argmax(dim=-1).item()

    if pred_id in id2label:
        label = id2label[pred_id]
        message = label2message.get(label, "")
        if message:
            return f"Dự đoán cảm xúc: {label}. {message}"
        else:
            return f"Dự đoán cảm xúc: {label}."
    else:
        return f"Nhãn không xác định (id={pred_id})"

###############################################################################
# GRADIO APP
###############################################################################
def run_demo(input_text):
    predicted_emotion = predict_text(input_text)
    return predicted_emotion

demo = gr.Interface(
    fn=run_demo,
    inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"),
    outputs=gr.Textbox(label="Kết quả"),
    title="PhoBERT Emotion Classification",
    description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc."
)

if __name__ == "__main__":
    demo.launch(share=True)