"""Gradio demo that classifies Vietnamese text into topics.

The script loads a pretrained sequence-classification model, cleans and
tokenizes the input text, and serves a small web UI that shows the most
probable topic labels together with their probabilities.
"""

import functools
import re

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from underthesea import word_tokenize

MODEL_NAME = "VietTung04/videberta-base-topic-classification"
STOPWORD_PATH = 'vietnamese-stopwords.txt'
TOP_K = 3          # number of labels reported to the UI
MAX_LENGTH = 512   # token length the input is truncated/padded to

# Load model directly
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Check if GPU is available and set the device accordingly
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Cleanup patterns, applied in this order; every match is replaced by a
# single space so that neighbouring words are not glued together.
_CLEANUP_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'http\S+',  # URLs
        r'#\w+',     # hashtags
        r'@\w+',     # mentions
        r'\d+',      # numbers
        # special characters
        r'[^\w\sđĐàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆìÌỉỈĩĨíÍịỊòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰỳỲỷỶỹỸýÝỵỴ]',
    )
)


@functools.lru_cache(maxsize=1)
def _load_stopwords():
    """Return the stopword list as a frozenset, reading the file only once.

    The file is read lazily on the first call, so a missing file still
    surfaces as ``FileNotFoundError`` when text is first preprocessed rather
    than at import time. Later edits to the file are not picked up until the
    process restarts.
    """
    with open(STOPWORD_PATH, 'r', encoding='utf-8') as file:
        return frozenset(file.read().splitlines())


def preprocess_fn(text):
    """Clean, tokenize and stopword-filter a piece of Vietnamese text.

    Args:
        text: Raw input string.

    Returns:
        The lower-cased text with URLs, hashtags, mentions, numbers and
        special characters removed, tokenized with ``word_tokenize`` and
        stripped of stopwords, joined back together with single spaces.

    Raises:
        FileNotFoundError: If the stopword file cannot be opened.
    """
    stopwords = _load_stopwords()

    for pattern in _CLEANUP_PATTERNS:
        text = pattern.sub(' ', text)

    # Collapse whitespace runs before handing the text to the tokenizer.
    normalized = ' '.join(text.split()).lower()

    # Tokenize Vietnamese text, then drop stop words
    tokens = word_tokenize(normalized)
    return ' '.join(word for word in tokens if word not in stopwords)


def predict_topic(text):
    """Return the most probable topic labels for ``text``.

    Args:
        text: Raw input string; it is passed through ``preprocess_fn`` first.

    Returns:
        A dict mapping label name to probability (``float``) for the
        ``TOP_K`` highest-scoring classes, or for all classes when the model
        has fewer than ``TOP_K`` labels.
    """
    inputs = tokenizer(
        preprocess_fn(text),
        truncation=True,
        padding='max_length',
        max_length=MAX_LENGTH,
        add_special_tokens=True,
        return_tensors='pt',
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    # Indices sorted by descending probability. Slicing (instead of indexing
    # a fixed range of 3) keeps this safe for models with fewer labels.
    top_indices = probabilities.argsort()[::-1][:TOP_K]

    id2label = model.config.id2label
    return {
        id2label[int(idx)]: float(probabilities[idx]) for idx in top_indices
    }


# Define the Gradio interface
iface = gr.Interface(
    fn=predict_topic,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Label(num_top_classes=TOP_K),
    title="Text Classification",
    description="Enter text to classify it into different categories and get the probability for each class.",
)

# Launch the interface
iface.launch()