File size: 3,128 Bytes
885b434 1af1aed 885b434 5a63293 cb34ab7 5a63293 885b434 4c9facd 0da98ac 4c9facd 885b434 82c6a1e b9bec37 1dd5bbf b9bec37 885b434 1dd5bbf 1af1aed 1dd5bbf f0e5035 82c6a1e f0e5035 82c6a1e 1dd5bbf f0e5035 82c6a1e 1af1aed f0e5035 1dd5bbf 9a2fed2 7fee8c9 15a27ae 0da98ac 7fee8c9 15a27ae 7fee8c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import csv
# Hugging Face Hub id of the fine-tuned news classifier. MODEL_URL is derived
# from it so the displayed link and the loaded weights cannot drift apart.
MODEL_NAME = "dsfsi/PuoBERTa-News"
MODEL_URL = f"https://huggingface.co/{MODEL_NAME}"
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"

# Loaded once at import time; the prediction code below reuses these objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Setswana display names keyed by the classifier's label ids. Labels without
# an entry here are expected to be shown under their raw id.
categories = dict(
    arts_culture_entertainment_and_media="Botsweretshi, setso, boitapoloso le bobegakgang",
    crime_law_and_justice="Bosenyi, molao le bosiamisi",
    disaster_accident_and_emergency_incident="Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
    economy_business_and_finance="Ikonomi, tsa kgwebo le tsa ditšhelete",
    education="Thuto",
    environment="Tikologo",
    health="Boitekanelo",
    politics="Dipolotiki",
    religion_and_belief="Bodumedi le tumelo",
    society="Setšhaba",
)
# Lazily built, then reused. Constructing a pipeline on every request (and for
# every row of an uploaded file) repeats setup work whose result never changes.
_classifier = None


def _get_classifier():
    """Return the shared text-classification pipeline, creating it on first use."""
    global _classifier
    if _classifier is None:
        _classifier = pipeline("text-classification", tokenizer=tokenizer, model=model)
    return _classifier


def prediction(news):
    """Classify one news text and return a score for every label.

    Args:
        news: the news article text to classify.

    Returns:
        A dict mapping category name to score. Labels present in
        ``categories`` are replaced by their Setswana display name;
        unknown labels are kept as-is.
    """
    # top_k=None asks for every label's score; it is the documented
    # replacement for the deprecated return_all_scores=True.
    # truncation=True cuts over-long articles to the tokenizer's maximum
    # length instead of letting the model fail on them (assumes the tokenizer
    # config defines model_max_length - confirm for this checkpoint).
    preds = _get_classifier()(news, top_k=None, truncation=True)
    # For a single string, the pipeline's output nesting has varied across
    # transformers versions ([{...}, ...] vs [[{...}, ...]]); accept both.
    if preds and isinstance(preds[0], list):
        preds = preds[0]
    return {categories.get(pred["label"], pred["label"]): pred["score"] for pred in preds}
def _read_uploaded_text(file):
    """Return ``(name, text)`` for the value passed from a ``gr.File`` input.

    Accepts either a filesystem path string or a file-like object exposing
    ``name`` (which form a given Gradio release passes is an assumption to
    confirm against the installed version).
    """
    if isinstance(file, str):
        name = file
        with open(name, "rb") as handle:
            raw = handle.read()
    else:
        name = getattr(file, "name", "") or ""
        try:
            file.seek(0)
            raw = file.read()
        except (AttributeError, ValueError, OSError):
            # The upload wrapper may already be closed or unreadable; fall
            # back to reopening the underlying file by its path.
            with open(name, "rb") as handle:
                raw = handle.read()
    if isinstance(raw, str):
        return name, raw
    # utf-8-sig also drops a leading BOM (added by some editors/spreadsheet
    # exports), which would otherwise stick to the first news item.
    return name, raw.decode("utf-8-sig")


def _extract_news(name, text):
    """Split decoded upload text into non-blank news items.

    Files whose name ends in ``.csv`` (any case) contribute the first cell of
    each non-empty row; any other file contributes one item per line.
    """
    if name.lower().endswith(".csv"):
        import io  # local import keeps this block self-contained

        # Give csv a real stream opened with newline="" (as the csv docs
        # require) so quoted cells containing line breaks stay one field.
        reader = csv.reader(io.StringIO(text, newline=""))
        candidates = [row[0] for row in reader if row]
    else:
        candidates = text.splitlines()
    return [item for item in candidates if item.strip()]


def _format_predictions(preds):
    """Render a ``{category: score}`` mapping as one string, highest score first."""
    ranked = sorted(preds.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(f"{label}: {score:.4f}" for label, score in ranked)


def file_prediction(file):
    """Classify every news item in an uploaded text or CSV file.

    Args:
        file: the ``gr.File`` upload value (path string or file-like object).

    Returns:
        A list of ``[news_text, ranked_predictions]`` rows, one per non-blank
        news item, matching the two-column ``gr.Dataframe`` output this
        function is wired to.

    Raises:
        gr.Error: if nothing was uploaded, the file is not valid UTF-8, or it
            contains no non-blank news text.
    """
    if file is None:
        raise gr.Error("Please upload a text or CSV file.")
    try:
        name, text = _read_uploaded_text(file)
    except UnicodeDecodeError as err:
        raise gr.Error("The uploaded file must be UTF-8 encoded text.") from err
    news_list = _extract_news(name, text)
    if not news_list:
        raise gr.Error("No news text found in the uploaded file.")
    # gr.Dataframe needs rows of cells; a single preformatted string is not
    # rendered as a table.
    return [[news, _format_predictions(prediction(news))] for news in news_list]
# Free-text tab: paste one article and see the top category probabilities.
gradio_ui = gr.Interface(
    fn=prediction,
    title="Setswana News Classification",
    description=f"Enter Setswana news article to see the category of the news.\n For this classification, the {MODEL_URL} model was used.",
    inputs=gr.Textbox(lines=10, label="Paste some Setswana news here"),
    outputs=gr.Label(num_top_classes=5, label="News categories probabilities"),
    theme="default",
    # Built from WEBSITE_URL so the footer link follows the constant instead
    # of duplicating it.
    article=(
        "<p style='text-align: center'>For our other AI works: "
        f"<a href='{WEBSITE_URL}' target='_blank'>{WEBSITE_URL}</a> | "
        "<a href='https://twitter.com/KodiksBilisim' target='_blank'>Contact us</a></p>"
    ),
)
# File tab: upload a .txt (one article per line) or .csv (article text in the
# first column) and get one prediction row per article.
gradio_file_ui = gr.Interface(
    fn=file_prediction,
    title="Upload File for Setswana News Classification",
    # Plain string: the original f-string had no placeholders.
    description=(
        "Upload a text or CSV file with Setswana news articles. "
        "The first column in the CSV should contain the news text."
    ),
    inputs=gr.File(label="Upload text or CSV file"),
    outputs=gr.Dataframe(headers=["News Text", "Category Predictions"], label="Predictions from file"),
    theme="default",
)
# Present both front-ends (free-text box and file upload) as tabs of a single
# app, then start the Gradio server.
gradio_combined_ui = gr.TabbedInterface(
    interface_list=[gradio_ui, gradio_file_ui],
    tab_names=["Text Input", "File Upload"],
)
gradio_combined_ui.launch()
|