|
import re
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from underthesea import word_tokenize
|
|
|
# Identifier handed to both from_pretrained calls below, so the tokenizer and
# the classifier always come from the same checkpoint.
_CHECKPOINT = "VietTung04/videberta-base-topic-classification"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)

# Pick CUDA when torch reports it as available, otherwise fall back to the CPU,
# and move the model's weights onto that device.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
model.to(device)
|
|
|
# Stop-word list, one entry per line; resolved relative to the working directory.
_STOPWORD_PATH = 'vietnamese-stopwords.txt'

# Noise patterns, applied in this order; every match is replaced by one space.
_NOISE_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'http\S+',  # URLs
        r'#\w+',     # hashtags
        r'@\w+',     # mentions
        r'\d+',      # digit runs
        # Anything that is neither a word character, whitespace, nor one of
        # the explicitly listed Vietnamese letters.
        r'[^\w\sđĐàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆìÌỉỈĩĨíÍịỊòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰỳỲỷỶỹỸýÝỵỴ]',
    )
)


@lru_cache(maxsize=None)
def _load_stopwords(path):
    """Read the stop-word file at *path* once and return its entries as a frozenset."""
    with open(path, 'r', encoding='utf-8') as file:
        # Strip each line so stray whitespace cannot prevent a match, and skip
        # blank lines. A set gives O(1) membership tests per token.
        return frozenset(line.strip() for line in file if line.strip())


def preprocess_fn(text):
    """Clean *text*, tokenize it and drop stop-words.

    Steps, in order: replace URLs, hashtags, mentions, digit runs and
    remaining punctuation/symbols with spaces; collapse whitespace and
    lowercase; split into tokens with ``word_tokenize``; remove every token
    found in the stop-word file.

    Args:
        text: Raw input string. ``None`` or an empty string yields ``''``.

    Returns:
        The surviving tokens joined by single spaces.

    Raises:
        OSError: If the stop-word file cannot be opened.
    """
    if not text:
        return ''

    for pattern in _NOISE_PATTERNS:
        text = pattern.sub(' ', text)

    # Collapse whitespace runs before tokenizing.
    tokens = word_tokenize(' '.join(text.split()).lower())

    # The stop-word set is cached after the first call, so the file is not
    # re-read for every request.
    stopwords = _load_stopwords(_STOPWORD_PATH)
    return ' '.join(token for token in tokens if token not in stopwords)
|
|
|
# Maximum number of classes reported per prediction.
_TOP_K = 3


def predict_topic(text):
    """Classify *text* and return its most probable topics.

    The text is run through ``preprocess_fn``, tokenized (truncated to at
    most 512 tokens), moved to ``device`` and scored by ``model`` without
    gradient tracking.

    Args:
        text: Raw input string to classify.

    Returns:
        A dict mapping label name (from ``model.config.id2label``) to its
        softmax probability as a Python float, ordered from most to least
        probable. It holds up to ``_TOP_K`` entries, fewer if the model has
        fewer labels.
    """
    # Only one sequence is encoded per call, so no padding is requested:
    # padding to 512 tokens would only add masked positions to compute.
    encoded = tokenizer(
        preprocess_fn(text),
        truncation=True,
        max_length=512,
        add_special_tokens=True,
        return_tensors='pt',
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # Row 0 is the only row: the batch holds a single text.
    probabilities = torch.softmax(logits, dim=1)[0]

    # Clamp k so a model with fewer than _TOP_K labels does not raise.
    k = min(_TOP_K, probabilities.numel())
    top_probs, top_ids = torch.topk(probabilities, k)

    id2label = model.config.id2label
    return {
        id2label[int(idx)]: float(prob)
        for idx, prob in zip(top_ids.tolist(), top_probs.tolist())
    }
|
|
|
|
|
# UI components: a two-line text box for input and a label widget that shows
# at most three classes.
_input_box = gr.Textbox(lines=2, placeholder="Enter your text here...")
_top_labels = gr.Label(num_top_classes=3)

# Wire predict_topic between the two components.
iface = gr.Interface(
    fn=predict_topic,
    inputs=_input_box,
    outputs=_top_labels,
    title="Text Classification",
    description="Enter text to classify it into different categories and get the probability for each class.",
)

# Start serving the interface as soon as this module is executed.
iface.launch()