File size: 2,066 Bytes
e3d46c8
 
8158997
 
e3d46c8
0788ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56b99e8
0788ae6
e3d46c8
56b99e8
0788ae6
 
 
093cd61
e3d46c8
0788ae6
 
 
8158997
 
 
 
56b99e8
8158997
56b99e8
8158997
56b99e8
8158997
56b99e8
8158997
0788ae6
8158997
0788ae6
 
8158997
 
 
 
 
 
 
0788ae6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
from transformers import pipeline
import torch
import matplotlib.pyplot as plt

from transformers import BertForSequenceClassification, BertTokenizer

MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Display names for the classifier outputs. Index i is assumed to correspond to
# logit i of the fine-tuned model (the bar chart has always relied on this
# alignment) -- TODO confirm against the label encoding used during training.
LABELS = [
    "Binnenland",
    "Buitenland",
    "Cultuur & Media",
    "Economie",
    "Koningshuis",
    "Opmerkelijk",
    "Politiek",
    "Regionaal nieuws",
    "Tech",
]

# Streamlit re-executes this whole script on every widget interaction, so a
# module-level `from_pretrained` would reload the weights on every rerun.
# `st.cache_resource` (Streamlit >= 1.18) keeps one instance per process; older
# installs only provide the legacy `st.cache`.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def load_model_and_tokenizer(model_name=MODEL_NAME):
    """Load the fine-tuned Dutch news BERT classifier and its tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(model_name)
    loaded_tokenizer = BertTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


model, tokenizer = load_model_and_tokenizer()

text = st.text_area('Please type/copy/paste the Dutch article')

# Skip empty or whitespace-only input rather than classifying nothing.
if text and text.strip():
    # BERT has a fixed number of position embeddings; longer articles must be
    # truncated, otherwise the forward pass fails on long input.
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Pure inference: no gradients needed.
    with torch.no_grad():
        outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Single input text -> take row 0 of the batch.
    probs_plot = probabilities[0].cpu().numpy()

    fig, ax = plt.subplots()
    ax.barh(LABELS, probs_plot)
    ax.invert_yaxis()  # list the categories top-down in LABELS order
    ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # rerun of the script would leak one figure.
    plt.close(fig)

    # Show the predicted category by name instead of a raw tensor repr.
    predicted_index = int(predictions[0])
    st.json(
        {
            "label": LABELS[predicted_index],
            "index": predicted_index,
            "probability": float(probs_plot[predicted_index]),
        }
    )