File size: 3,662 Bytes
e3d46c8
 
8158997
 
70318b5
cbac022
6cf9356
e3d46c8
0788ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56b99e8
0788ae6
e3d46c8
56b99e8
0788ae6
a85495d
 
 
 
5320ec6
ae0b0ef
5320ec6
 
6cf9356
2823250
093cd61
68d6aa3
 
 
 
 
8158997
68d6aa3
 
 
 
 
8158997
68d6aa3
 
 
 
b28d4fd
68d6aa3
 
 
b28d4fd
bceabb4
 
035678c
bceabb4
68d6aa3
bceabb4
 
 
68d6aa3
c85cea8
0788ae6
ed6dd13
f7ac541
23bc4d7
 
bceabb4
23f016c
7619618
68d6aa3
 
 
0788ae6
8158997
9fb8c95
1af96b9
f5a8947
5e15a3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt
import numpy as np
#from PIL import Image


#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
#text = st.text_area('Please type/copy/paste the Dutch article')

#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis'
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech']

#if text:
#   out = pipe(text)
#   st.json(out)
   
   
   # load tokenizer and model, create trainer
  #model_name = "RuudVelo/dutch_news_classifier_bert_finetuned"
  #tokenizer = AutoTokenizer.from_pretrained(model_name)
  #model = AutoModelForSequenceClassification.from_pretrained(model_name)
  #trainer = Trainer(model=model)  
  #print(filename, type(filename))
  #print(filename.name)
  
from transformers import BertForSequenceClassification, BertTokenizer

# Single source of truth for the checkpoint: the tokenizer and the classifier
# are both loaded from this same identifier.
MODEL_ID = "RuudVelo/dutch_news_clf_bert_finetuned"

# NOTE(review): both loads sit at module top level, so they run on every
# execution of this script; consider caching them if reruns prove slow.
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
model = BertForSequenceClassification.from_pretrained(MODEL_ID)


# Page header: title, one-line description and a small cover image.
# The asterisk in the description points at the footnote written at the
# bottom of the page ("* The predefined categories are ...").
st.title("Dutch news article classification")
st.write(
    "This app classifies a Dutch news article into one of 9 pre-defined* "
    "article categories"
)
st.image("dataset-cover_articles.jpeg", width=150)

# Free-text input; whatever the user entered is bound to `text` and consumed
# by the Submit handler further down.
text = st.text_area(
    "Please type/copy/paste text of the Dutch article and click Submit"
)

#if text:
#   encoding = tokenizer(text, return_tensors="pt")
#   outputs = model(**encoding)
#   predictions = outputs.logits.argmax(-1)
#   probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
   
##   fig = plt.figure()
#   ax = fig.add_axes([0,0,1,1])
#   labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis',
# 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech']
#   probs_plot = probabilities[0].cpu().detach().numpy()
   
#   ax.barh(labels_plot,probs_plot )
#   st.pyplot(fig)


#input = st.text_input('Context')

# Classify the entered article when the user clicks Submit, then show a
# per-category probability chart and the top prediction.
if st.button('Submit'):
    # Category names in the order used to index the model output.
    # NOTE(review): assumes output index i corresponds to labels_plot[i]
    # (the chart has always relied on this ordering) - confirm against the
    # checkpoint's label mapping.
    labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie',
                   'Koningshuis', 'Opmerkelijk', 'Politiek',
                   'Regionaal nieuws', 'Tech']

    if not text or not text.strip():
        # Nothing to classify: skip the model call instead of reporting a
        # meaningless prediction for an empty input.
        st.warning('Please enter the text of a Dutch article before clicking Submit.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the model's position-embedding limit; longer
            # inputs cannot be processed by the forward pass.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only, so no autograd graph is needed.
            with torch.no_grad():
                outputs = model(**encoding)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # One article was submitted, so take row 0; scale to percent.
            probs_plot = probabilities[0].cpu().numpy() * 100
            # Plain Python int so it can index both the label list and the
            # NumPy array without relying on tensor-to-index coercion.
            number = int(probs_plot.argmax())

            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(labels_plot, probs_plot)
            ax.set_title("Predicted article category probability")
            ax.set_xlabel("Probability (%)")
            ax.set_ylabel("Predicted category")
            st.pyplot(fig)
            # pyplot keeps figures registered until they are closed; release
            # this one so repeated submissions do not accumulate figures.
            plt.close(fig)

            st.write('Category: {} | Probability: {:.1f}%'.format(
                labels_plot[number], float(probs_plot[number])))
   #     output = genQuestion(option, input)
   #     print(output)
   # st.write(output)   
#encoding = tokenizer(text, return_tensors="pt")
#import numpy as np
# Footnote for the asterisk in the page description, followed by pointers to
# the training data and the model card.
st.write("* The predefined categories are Binnenland, Buitenland, Cultuur & Media, Economie, Koningshuis, Opmerkelijk, Politiek, Regionaal nieuws and Tech")
st.write("The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles")
# NOTE(review): this link names "dutch_news_classifier_bert_finetuned" while
# the checkpoint loaded above is "dutch_news_clf_bert_finetuned" - confirm
# which model card is intended.
st.write('The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_classifier_bert_finetuned')