File size: 3,460 Bytes
e3d46c8 8158997 cbac022 6cf9356 e3d46c8 0788ae6 56b99e8 0788ae6 e3d46c8 56b99e8 0788ae6 a85495d 5320ec6 ae0b0ef 5320ec6 6cf9356 b28d4fd 093cd61 68d6aa3 8158997 68d6aa3 8158997 68d6aa3 b28d4fd 68d6aa3 b28d4fd bceabb4 68d6aa3 bceabb4 68d6aa3 bceabb4 0788ae6 ed6dd13 f7ac541 23bc4d7 bceabb4 68d6aa3 0788ae6 8158997 f5a8947 1af96b9 f5a8947 5e15a3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt
#from PIL import Image
#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
#text = st.text_area('Please type/copy/paste the Dutch article')
#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis'
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech']
#if text:
# out = pipe(text)
# st.json(out)
# load tokenizer and model, create trainer
#model_name = "RuudVelo/dutch_news_classifier_bert_finetuned"
#tokenizer = AutoTokenizer.from_pretrained(model_name)
#model = AutoModelForSequenceClassification.from_pretrained(model_name)
#trainer = Trainer(model=model)
#print(filename, type(filename))
#print(filename.name)
from transformers import BertForSequenceClassification, BertTokenizer
# Streamlit re-executes this whole script on every widget interaction, so the
# model and tokenizer are loaded through a cached loader instead of calling
# from_pretrained() on each rerun.
# `st.cache_resource` is not present in older Streamlit releases; fall back to
# a no-op decorator there, which keeps the previous uncached behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned Dutch news classifier and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the
        ``RuudVelo/dutch_news_clf_bert_finetuned`` checkpoint.
    """
    checkpoint = "RuudVelo/dutch_news_clf_bert_finetuned"
    return (
        BertForSequenceClassification.from_pretrained(checkpoint),
        BertTokenizer.from_pretrained(checkpoint),
    )


model, tokenizer = _load_model_and_tokenizer()
# Page header: title, one-line description and the cover image.
st.title(body="Dutch news article classification")
intro_line = "This app classifies a Dutch news article into one of 9 pre-defined* article categories"
st.write(intro_line)
st.image(image='dataset-cover_articles.jpeg', width=150)

# Free-text input; `text` holds the article that is classified on submit.
text = st.text_area(label='Please type/copy/paste text of the Dutch article')
#if text:
# encoding = tokenizer(text, return_tensors="pt")
# outputs = model(**encoding)
# predictions = outputs.logits.argmax(-1)
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
## fig = plt.figure()
# ax = fig.add_axes([0,0,1,1])
# labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis',
# 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech']
# probs_plot = probabilities[0].cpu().detach().numpy()
# ax.barh(labels_plot,probs_plot )
# st.pyplot(fig)
#input = st.text_input('Context')
# Classify the article only when the user presses the button.
if st.button('Submit'):
    if not text.strip():
        # Without this guard an empty input would still be run through the
        # model and plotted as if it were a real prediction.
        st.warning('Please enter the text of a Dutch article first.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the model's position-embedding limit; longer
            # articles would otherwise fail inside the forward pass.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = model(**encoding)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs_plot = probabilities[0].cpu().numpy()

            # NOTE(review): assumes this order matches the model's class
            # indices -- confirm against model.config.id2label.
            labels_plot = [
                'Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie',
                'Koningshuis', 'Opmerkelijk', 'Politiek', 'Regionaal nieuws',
                'Tech',
            ]

            # Horizontal bar chart with one bar per category.
            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(labels_plot, probs_plot)
            ax.set_title("Predicted article category probability")
            ax.set_xlabel("Probability")
            ax.set_ylabel("Predicted category")
            st.pyplot(fig)
            # pyplot keeps every figure it creates registered until closed;
            # release it so figures do not pile up across reruns.
            plt.close(fig)
# output = genQuestion(option, input)
# print(output)
# st.write(output)
#encoding = tokenizer(text, return_tensors="pt")
#import numpy as np
# Footnote for the "pre-defined*" marker in the page introduction.
# The asterisk is escaped because st.write renders strings as Markdown, where
# a line starting with " * " becomes a bullet instead of a literal asterisk.
st.write("\\* The predefined categories are Binnenland, Buitenland, Cultuur & Media, Economie, Koningshuis, Opmerkelijk, Politiek, Regionaal nieuws and Tech")
st.write("The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles")
# NOTE(review): this URL names the repo `dutch_news_classifier_bert_finetuned`,
# while the model is loaded from `dutch_news_clf_bert_finetuned` -- confirm
# which model card is the intended target.
st.write('The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_classifier_bert_finetuned')
|