File size: 2,763 Bytes
e3d46c8
 
8158997
 
e3d46c8
0788ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56b99e8
0788ae6
e3d46c8
56b99e8
0788ae6
a85495d
 
 
 
68d6aa3
093cd61
68d6aa3
 
 
 
 
8158997
68d6aa3
 
 
 
 
8158997
68d6aa3
 
 
 
 
 
 
 
9299174
bceabb4
 
 
68d6aa3
bceabb4
 
 
68d6aa3
bceabb4
0788ae6
bceabb4
 
68d6aa3
 
 
0788ae6
8158997
 
 
 
 
 
 
0788ae6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt

#pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
#text = st.text_area('Please type/copy/paste the Dutch article')

#labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis'
# 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech']

#if text:
#   out = pipe(text)
#   st.json(out)
   
   
   # load tokenizer and model, create trainer
  #model_name = "RuudVelo/dutch_news_classifier_bert_finetuned"
  #tokenizer = AutoTokenizer.from_pretrained(model_name)
  #model = AutoModelForSequenceClassification.from_pretrained(model_name)
  #trainer = Trainer(model=model)  
  #print(filename, type(filename))
  #print(filename.name)
  
from transformers import BertForSequenceClassification, BertTokenizer

# Hugging Face Hub id of the fine-tuned Dutch news classifier.
MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Streamlit re-executes this whole script on every widget interaction, so the
# model must be cached to avoid re-instantiating it on each rerun.
# st.cache_resource exists in Streamlit >= 1.18; older releases only offer the
# legacy st.cache, where allow_output_mutation=True skips re-hashing the
# returned (large, mutable) model on every call.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def _load_model_and_tokenizer(model_name):
    """Load the BERT sequence classifier and its tokenizer for *model_name*.

    Returns a ``(model, tokenizer)`` tuple. The result is cached by Streamlit,
    so later reruns of the script reuse the same objects.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(model_name)
    loaded_tokenizer = BertTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)


# Title
st.title("Dutch news article classification")

# Class names shown on the probability chart, in display order.
# NOTE(review): assumed to match the model's label ids 0..8
# (model.config.id2label) — confirm against the model card.
LABELS = [
    "Binnenland",
    "Buitenland",
    "Cultuur & Media",
    "Economie",
    "Koningshuis",
    "Opmerkelijk",
    "Politiek",
    "Regionaal nieuws",
    "Tech",
]

# Named article_text rather than `input` to avoid shadowing the builtin.
article_text = st.text_input('Context')

if st.button('Submit'):
    if not article_text.strip():
        # Classifying an empty string only yields meaningless probabilities.
        st.warning("Please enter the text of a Dutch news article.")
    else:
        with st.spinner('Generating a response...'):
            # BERT only has position embeddings for a fixed number of tokens;
            # longer articles must be truncated or the forward pass fails.
            encoding = tokenizer(
                article_text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                outputs = model(**encoding)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs_plot = probabilities[0].cpu().numpy()

            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.barh(LABELS, probs_plot)
            st.pyplot(fig)
            # Streamlit reruns this script on every interaction; close the
            # figure so pyplot's global figure registry does not keep growing.
            plt.close(fig)
   #     output = genQuestion(option, input)
   #     print(output)
   # st.write(output)   
#encoding = tokenizer(text, return_tensors="pt")
#import numpy as np

#arr = np.random.normal(1, 1, size=100)
#fig, ax = plt.subplots()
#ax.hist(arr, bins=20)

#st.pyplot(fig)

# forward pass
#outputs = model(**encoding)
#predictions = outputs.logits.argmax(-1)