# demo_phobert_gradio.py # -*- coding: utf-8 -*- import gradio as gr import torch import re import json import emoji import numpy as np from underthesea import word_tokenize from transformers import ( AutoConfig, AutoTokenizer, AutoModelForSequenceClassification ) ############################################################################### # TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN ############################################################################### emoji_mapping = { "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]", "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]", "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]", "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]", "🤑": "[satisfaction]", "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]", "😏": "[sarcasm]", "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]", "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]", "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]", "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]", "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]", "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]", "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]", "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]", "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]", "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]" } ############################################################################### # HÀM XỬ LÝ (COPY TỪ FILE TRAIN) ############################################################################### def replace_emojis(sentence, emoji_mapping): processed_sentence = [] for char in sentence: if char in emoji_mapping: processed_sentence.append(emoji_mapping[char]) elif not emoji.is_emoji(char): processed_sentence.append(char) return ''.join(processed_sentence) def remove_profanity(sentence): profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"] words = sentence.split() filtered = [w for w in words if w.lower() not in profane_words] return ' '.join(filtered) def remove_special_characters(sentence): return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence) def normalize_whitespace(sentence): return ' '.join(sentence.split()) def remove_repeated_characters(sentence): return re.sub(r"(.)\1{2,}", r"\1", sentence) def replace_numbers(sentence): return re.sub(r"\d+", "[number]", sentence) def tokenize_underthesea(sentence): tokens = word_tokenize(sentence) return " ".join(tokens) # Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng. try: with open("abbreviations.json", "r", encoding="utf-8") as f: abbreviations = json.load(f) except: abbreviations = {} def preprocess_sentence(sentence): # hạ thấp sentence = sentence.lower() # thay thế emoji sentence = replace_emojis(sentence, emoji_mapping) # loại bỏ từ nhạy cảm sentence = remove_profanity(sentence) # bỏ ký tự đặc biệt sentence = remove_special_characters(sentence) # chuẩn hoá khoảng trắng sentence = normalize_whitespace(sentence) # thay thế viết tắt words = sentence.split() replaced = [] for w in words: if w in abbreviations: replaced.append(" ".join(abbreviations[w])) else: replaced.append(w) sentence = " ".join(replaced) # bỏ bớt kí tự lặp sentence = remove_repeated_characters(sentence) # thay số thành [number] sentence = replace_numbers(sentence) # tokenize tiếng Việt sentence = tokenize_underthesea(sentence) return sentence ############################################################################### # LOAD CHECKPOINT ############################################################################### checkpoint_dir = "/home/datpham/datpham/thesis-ngtram/phobert_results/checkpoint-17350" # Đường dẫn đến folder checkpoint device = "cuda" if torch.cuda.is_available() else "cpu" print("Loading config...") config = AutoConfig.from_pretrained(checkpoint_dir) # Mapping id to label theo thứ tự bạn cung cấp custom_id2label = { 0: 'Anger', 1: 'Disgust', 2: 'Enjoyment', 3: 'Fear', 4: 'Other', 5: 'Sadness', 6: 'Surprise' } # Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng if hasattr(config, "id2label") and config.id2label: # Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping if all(label.startswith("LABEL_") for label in config.id2label.values()): id2label = custom_id2label else: id2label = {int(k): v for k, v in config.id2label.items()} else: id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label print("id2label loaded:", id2label) print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) print("Loading model...") model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config) model.to(device) model.eval() ############################################################################### # HÀM PREDICT ############################################################################### # Mapping từ label đến thông điệp tương ứng label2message = { 'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.', 'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.', 'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!', 'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.', 'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.', 'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.', 'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.' } def predict_text(text: str) -> str: """Tiền xử lý, token hoá và chạy model => trả về label và thông điệp.""" text_proc = preprocess_sentence(text) inputs = tokenizer( [text_proc], padding=True, truncation=True, max_length=256, return_tensors="pt" ).to(device) with torch.no_grad(): outputs = model(**inputs) pred_id = outputs.logits.argmax(dim=-1).item() if pred_id in id2label: label = id2label[pred_id] message = label2message.get(label, "") if message: return f"Dự đoán cảm xúc: {label}. {message}" else: return f"Dự đoán cảm xúc: {label}." else: return f"Nhãn không xác định (id={pred_id})" ############################################################################### # GRADIO APP ############################################################################### def run_demo(input_text): predicted_emotion = predict_text(input_text) return predicted_emotion demo = gr.Interface( fn=run_demo, inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"), outputs=gr.Textbox(label="Kết quả"), title="PhoBERT Emotion Classification", description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc." ) if __name__ == "__main__": demo.launch(share=True)