ducdatit2002's picture
Upload folder using huggingface_hub
e09333c verified
raw
history blame
8.2 kB
# demo_phobert_gradio.py
# -*- coding: utf-8 -*-
import gradio as gr
import torch
import re
import json
import emoji
import numpy as np
from underthesea import word_tokenize
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForSequenceClassification
)
###############################################################################
# TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
###############################################################################
emoji_mapping = {
"😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
"🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
"🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
"😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
"🤑": "[satisfaction]",
"🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
"😏": "[sarcasm]",
"😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
"😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
"😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
"🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
"🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
"😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
"😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
"😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
"😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
"😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}
###############################################################################
# HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
processed_sentence = []
for char in sentence:
if char in emoji_mapping:
processed_sentence.append(emoji_mapping[char])
elif not emoji.is_emoji(char):
processed_sentence.append(char)
return ''.join(processed_sentence)
def remove_profanity(sentence):
profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
words = sentence.split()
filtered = [w for w in words if w.lower() not in profane_words]
return ' '.join(filtered)
def remove_special_characters(sentence):
return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)
def normalize_whitespace(sentence):
return ' '.join(sentence.split())
def remove_repeated_characters(sentence):
return re.sub(r"(.)\1{2,}", r"\1", sentence)
def replace_numbers(sentence):
return re.sub(r"\d+", "[number]", sentence)
def tokenize_underthesea(sentence):
tokens = word_tokenize(sentence)
return " ".join(tokens)
# Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng.
try:
with open("abbreviations.json", "r", encoding="utf-8") as f:
abbreviations = json.load(f)
except:
abbreviations = {}
def preprocess_sentence(sentence):
# hạ thấp
sentence = sentence.lower()
# thay thế emoji
sentence = replace_emojis(sentence, emoji_mapping)
# loại bỏ từ nhạy cảm
sentence = remove_profanity(sentence)
# bỏ ký tự đặc biệt
sentence = remove_special_characters(sentence)
# chuẩn hoá khoảng trắng
sentence = normalize_whitespace(sentence)
# thay thế viết tắt
words = sentence.split()
replaced = []
for w in words:
if w in abbreviations:
replaced.append(" ".join(abbreviations[w]))
else:
replaced.append(w)
sentence = " ".join(replaced)
# bỏ bớt kí tự lặp
sentence = remove_repeated_characters(sentence)
# thay số thành [number]
sentence = replace_numbers(sentence)
# tokenize tiếng Việt
sentence = tokenize_underthesea(sentence)
return sentence
###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "/home/datpham/datpham/thesis-ngtram/phobert_results/checkpoint-17350" # Đường dẫn đến folder checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)
# Mapping id to label theo thứ tự bạn cung cấp
custom_id2label = {
0: 'Anger',
1: 'Disgust',
2: 'Enjoyment',
3: 'Fear',
4: 'Other',
5: 'Sadness',
6: 'Surprise'
}
# Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng
if hasattr(config, "id2label") and config.id2label:
# Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping
if all(label.startswith("LABEL_") for label in config.id2label.values()):
id2label = custom_id2label
else:
id2label = {int(k): v for k, v in config.id2label.items()}
else:
id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label
print("id2label loaded:", id2label)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()
###############################################################################
# HÀM PREDICT
###############################################################################
# Mapping từ label đến thông điệp tương ứng
label2message = {
'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}
def predict_text(text: str) -> str:
"""Tiền xử lý, token hoá và chạy model => trả về label và thông điệp."""
text_proc = preprocess_sentence(text)
inputs = tokenizer(
[text_proc],
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**inputs)
pred_id = outputs.logits.argmax(dim=-1).item()
if pred_id in id2label:
label = id2label[pred_id]
message = label2message.get(label, "")
if message:
return f"Dự đoán cảm xúc: {label}. {message}"
else:
return f"Dự đoán cảm xúc: {label}."
else:
return f"Nhãn không xác định (id={pred_id})"
###############################################################################
# GRADIO APP
###############################################################################
def run_demo(input_text):
predicted_emotion = predict_text(input_text)
return predicted_emotion
demo = gr.Interface(
fn=run_demo,
inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"),
outputs=gr.Textbox(label="Kết quả"),
title="PhoBERT Emotion Classification",
description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc."
)
if __name__ == "__main__":
demo.launch(share=True)