RuudVelo commited on
Commit
8158997
1 Parent(s): 0788ae6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
 
3
 
4
  #pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
5
  #text = st.text_area('Please type/copy/paste the Dutch article')
@@ -34,10 +36,28 @@ if text:
34
  encoding = tokenizer(text, return_tensors="pt")
35
  outputs = model(**encoding)
36
  predictions = outputs.logits.argmax(-1)
 
 
 
 
 
 
 
 
 
 
 
37
  #out = pipe(text)
38
- st.json(predictions)
39
 
40
  #encoding = tokenizer(text, return_tensors="pt")
 
 
 
 
 
 
 
41
 
42
  # forward pass
43
  #outputs = model(**encoding)
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
 
6
  #pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned")
7
  #text = st.text_area('Please type/copy/paste the Dutch article')
 
36
  encoding = tokenizer(text, return_tensors="pt")
37
  outputs = model(**encoding)
38
  predictions = outputs.logits.argmax(-1)
39
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
40
+
41
+ fig = plt.figure()
42
+ ax = fig.add_axes([0,0,1,1])
43
+ langs = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis',
44
+ 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech']
45
+ students = probabilities[0].cpu().detach().numpy()
46
+
47
+ ax.barh(langs,students)
48
+ st.pyplot(fig)
49
+ #plt.show()
50
  #out = pipe(text)
51
+ #st.json(predictions)
52
 
53
  #encoding = tokenizer(text, return_tensors="pt")
54
+ #import numpy as np
55
+
56
+ #arr = np.random.normal(1, 1, size=100)
57
+ #fig, ax = plt.subplots()
58
+ #ax.hist(arr, bins=20)
59
+
60
+ #st.pyplot(fig)
61
 
62
  # forward pass
63
  #outputs = model(**encoding)