|
|
|
|
|
|
|
import gradio as gr |
|
import torch |
|
import re |
|
import json |
|
import emoji |
|
import numpy as np |
|
from underthesea import word_tokenize |
|
|
|
from transformers import ( |
|
AutoConfig, |
|
AutoTokenizer, |
|
AutoModelForSequenceClassification |
|
) |
|
|
|
|
|
|
|
|
|
emoji_mapping = { |
|
"😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]", |
|
"🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]", |
|
"🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]", |
|
"😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]", |
|
"🤑": "[satisfaction]", |
|
"🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]", |
|
"😏": "[sarcasm]", |
|
"😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]", |
|
"😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]", |
|
"😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]", |
|
"🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]", |
|
"🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]", |
|
"😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]", |
|
"😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]", |
|
"😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]", |
|
"😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]", |
|
"😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]" |
|
} |
|
|
|
|
|
|
|
|
|
def replace_emojis(sentence, emoji_mapping): |
|
processed_sentence = [] |
|
for char in sentence: |
|
if char in emoji_mapping: |
|
processed_sentence.append(emoji_mapping[char]) |
|
elif not emoji.is_emoji(char): |
|
processed_sentence.append(char) |
|
return ''.join(processed_sentence) |
|
|
|
def remove_profanity(sentence): |
|
profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"] |
|
words = sentence.split() |
|
filtered = [w for w in words if w.lower() not in profane_words] |
|
return ' '.join(filtered) |
|
|
|
def remove_special_characters(sentence): |
|
return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence) |
|
|
|
def normalize_whitespace(sentence): |
|
return ' '.join(sentence.split()) |
|
|
|
def remove_repeated_characters(sentence): |
|
return re.sub(r"(.)\1{2,}", r"\1", sentence) |
|
|
|
def replace_numbers(sentence): |
|
return re.sub(r"\d+", "[number]", sentence) |
|
|
|
def tokenize_underthesea(sentence): |
|
tokens = word_tokenize(sentence) |
|
return " ".join(tokens) |
|
|
|
|
|
try: |
|
with open("abbreviations.json", "r", encoding="utf-8") as f: |
|
abbreviations = json.load(f) |
|
except: |
|
abbreviations = {} |
|
|
|
def preprocess_sentence(sentence): |
|
|
|
sentence = sentence.lower() |
|
|
|
sentence = replace_emojis(sentence, emoji_mapping) |
|
|
|
sentence = remove_profanity(sentence) |
|
|
|
sentence = remove_special_characters(sentence) |
|
|
|
sentence = normalize_whitespace(sentence) |
|
|
|
words = sentence.split() |
|
replaced = [] |
|
for w in words: |
|
if w in abbreviations: |
|
replaced.append(" ".join(abbreviations[w])) |
|
else: |
|
replaced.append(w) |
|
sentence = " ".join(replaced) |
|
|
|
sentence = remove_repeated_characters(sentence) |
|
|
|
sentence = replace_numbers(sentence) |
|
|
|
sentence = tokenize_underthesea(sentence) |
|
return sentence |
|
|
|
|
|
|
|
|
|
checkpoint_dir = "/home/datpham/datpham/thesis-ngtram/phobert_results/checkpoint-17350" |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
print("Loading config...") |
|
config = AutoConfig.from_pretrained(checkpoint_dir) |
|
|
|
|
|
custom_id2label = { |
|
0: 'Anger', |
|
1: 'Disgust', |
|
2: 'Enjoyment', |
|
3: 'Fear', |
|
4: 'Other', |
|
5: 'Sadness', |
|
6: 'Surprise' |
|
} |
|
|
|
|
|
if hasattr(config, "id2label") and config.id2label: |
|
|
|
if all(label.startswith("LABEL_") for label in config.id2label.values()): |
|
id2label = custom_id2label |
|
else: |
|
id2label = {int(k): v for k, v in config.id2label.items()} |
|
else: |
|
id2label = custom_id2label |
|
|
|
print("id2label loaded:", id2label) |
|
|
|
print("Loading tokenizer...") |
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) |
|
|
|
print("Loading model...") |
|
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config) |
|
model.to(device) |
|
model.eval() |
|
|
|
|
|
|
|
|
|
|
|
label2message = { |
|
'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.', |
|
'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.', |
|
'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!', |
|
'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.', |
|
'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.', |
|
'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.', |
|
'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.' |
|
} |
|
|
|
def predict_text(text: str) -> str: |
|
"""Tiền xử lý, token hoá và chạy model => trả về label và thông điệp.""" |
|
text_proc = preprocess_sentence(text) |
|
inputs = tokenizer( |
|
[text_proc], |
|
padding=True, |
|
truncation=True, |
|
max_length=256, |
|
return_tensors="pt" |
|
).to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
pred_id = outputs.logits.argmax(dim=-1).item() |
|
|
|
if pred_id in id2label: |
|
label = id2label[pred_id] |
|
message = label2message.get(label, "") |
|
if message: |
|
return f"Dự đoán cảm xúc: {label}. {message}" |
|
else: |
|
return f"Dự đoán cảm xúc: {label}." |
|
else: |
|
return f"Nhãn không xác định (id={pred_id})" |
|
|
|
|
|
|
|
|
|
def run_demo(input_text): |
|
predicted_emotion = predict_text(input_text) |
|
return predicted_emotion |
|
|
|
demo = gr.Interface( |
|
fn=run_demo, |
|
inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"), |
|
outputs=gr.Textbox(label="Kết quả"), |
|
title="PhoBERT Emotion Classification", |
|
description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc." |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(share=True) |
|
|