File size: 8,201 Bytes
e09333c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# demo_phobert_gradio.py
# -*- coding: utf-8 -*-
import gradio as gr
import torch
import re
import json
import emoji
import numpy as np
from underthesea import word_tokenize
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForSequenceClassification
)
###############################################################################
# TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
###############################################################################
emoji_mapping = {
"😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
"🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
"🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
"😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
"🤑": "[satisfaction]",
"🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
"😏": "[sarcasm]",
"😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
"😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
"😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
"🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
"🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
"😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
"😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
"😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
"😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
"😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}
###############################################################################
# HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
processed_sentence = []
for char in sentence:
if char in emoji_mapping:
processed_sentence.append(emoji_mapping[char])
elif not emoji.is_emoji(char):
processed_sentence.append(char)
return ''.join(processed_sentence)
def remove_profanity(sentence):
profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
words = sentence.split()
filtered = [w for w in words if w.lower() not in profane_words]
return ' '.join(filtered)
def remove_special_characters(sentence):
return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)
def normalize_whitespace(sentence):
return ' '.join(sentence.split())
def remove_repeated_characters(sentence):
return re.sub(r"(.)\1{2,}", r"\1", sentence)
def replace_numbers(sentence):
return re.sub(r"\d+", "[number]", sentence)
def tokenize_underthesea(sentence):
tokens = word_tokenize(sentence)
return " ".join(tokens)
# Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng.
try:
with open("abbreviations.json", "r", encoding="utf-8") as f:
abbreviations = json.load(f)
except:
abbreviations = {}
def preprocess_sentence(sentence):
# hạ thấp
sentence = sentence.lower()
# thay thế emoji
sentence = replace_emojis(sentence, emoji_mapping)
# loại bỏ từ nhạy cảm
sentence = remove_profanity(sentence)
# bỏ ký tự đặc biệt
sentence = remove_special_characters(sentence)
# chuẩn hoá khoảng trắng
sentence = normalize_whitespace(sentence)
# thay thế viết tắt
words = sentence.split()
replaced = []
for w in words:
if w in abbreviations:
replaced.append(" ".join(abbreviations[w]))
else:
replaced.append(w)
sentence = " ".join(replaced)
# bỏ bớt kí tự lặp
sentence = remove_repeated_characters(sentence)
# thay số thành [number]
sentence = replace_numbers(sentence)
# tokenize tiếng Việt
sentence = tokenize_underthesea(sentence)
return sentence
###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "/home/datpham/datpham/thesis-ngtram/phobert_results/checkpoint-17350" # Đường dẫn đến folder checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)
# Mapping id to label theo thứ tự bạn cung cấp
custom_id2label = {
0: 'Anger',
1: 'Disgust',
2: 'Enjoyment',
3: 'Fear',
4: 'Other',
5: 'Sadness',
6: 'Surprise'
}
# Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng
if hasattr(config, "id2label") and config.id2label:
# Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping
if all(label.startswith("LABEL_") for label in config.id2label.values()):
id2label = custom_id2label
else:
id2label = {int(k): v for k, v in config.id2label.items()}
else:
id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label
print("id2label loaded:", id2label)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()
###############################################################################
# HÀM PREDICT
###############################################################################
# Mapping từ label đến thông điệp tương ứng
label2message = {
'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}
def predict_text(text: str) -> str:
"""Tiền xử lý, token hoá và chạy model => trả về label và thông điệp."""
text_proc = preprocess_sentence(text)
inputs = tokenizer(
[text_proc],
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**inputs)
pred_id = outputs.logits.argmax(dim=-1).item()
if pred_id in id2label:
label = id2label[pred_id]
message = label2message.get(label, "")
if message:
return f"Dự đoán cảm xúc: {label}. {message}"
else:
return f"Dự đoán cảm xúc: {label}."
else:
return f"Nhãn không xác định (id={pred_id})"
###############################################################################
# GRADIO APP
###############################################################################
def run_demo(input_text):
predicted_emotion = predict_text(input_text)
return predicted_emotion
demo = gr.Interface(
fn=run_demo,
inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"),
outputs=gr.Textbox(label="Kết quả"),
title="PhoBERT Emotion Classification",
description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc."
)
if __name__ == "__main__":
demo.launch(share=True)
|