"""Streamlit app that classifies a Dutch news article into one of nine topics.

The text entered by the user is run through the fine-tuned BERT model
``RuudVelo/dutch_news_clf_bert_finetuned``. The softmax probability of every
topic is shown as a horizontal bar chart.
"""
import streamlit as st
from transformers import pipeline  # noqa: F401  (unused; kept from the original imports)
import torch
import matplotlib.pyplot as plt
from transformers import BertForSequenceClassification, BertTokenizer

MODEL_NAME = "RuudVelo/dutch_news_clf_bert_finetuned"

# Display names for the classifier outputs, in logit order.
# NOTE(review): assumes this order matches the model's id2label mapping — confirm.
LABELS = [
    "Binnenland",
    "Buitenland",
    "Cultuur & Media",
    "Economie",
    "Koningshuis",
    "Opmerkelijk",
    "Politiek",
    "Regionaal nieuws",
    "Tech",
]

# BERT position embeddings cover at most 512 tokens; longer input must be
# truncated or the forward pass fails on a full-length article.
MAX_TOKENS = 512

# Streamlit re-executes this whole script on every interaction, so the model
# would otherwise be reloaded on every click. st.cache_resource exists since
# Streamlit 1.18; fall back to the older API, or to no caching, on older versions.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_resource
def load_model_and_tokenizer():
    """Load and return the ``(model, tokenizer)`` pair for ``MODEL_NAME``."""
    clf_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
    clf_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return clf_model, clf_tokenizer


def classify(text, clf_model, clf_tokenizer):
    """Return the softmax probability of each label for ``text``.

    Args:
        text: Article text to classify. It is truncated to ``MAX_TOKENS`` tokens.
        clf_model: Sequence-classification model returned by ``load_model_and_tokenizer``.
        clf_tokenizer: Matching tokenizer.

    Returns:
        A 1-D numpy array with one probability per output class of the model.
    """
    encoding = clf_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_TOKENS
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = clf_model(**encoding).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()


def plot_probabilities(probs):
    """Return a matplotlib figure with a horizontal bar per label in ``LABELS``."""
    fig, ax = plt.subplots()
    ax.barh(LABELS, probs)
    ax.set_xlim(0, 1)
    ax.set_xlabel("Probability")
    # Leave room for the category names on the y axis.
    fig.tight_layout()
    return fig


model, tokenizer = load_model_and_tokenizer()

st.title("Dutch news article classification")

article_text = st.text_area("Please type/copy/paste text of the Dutch article")

if st.button("Submit"):
    if not article_text.strip():
        st.warning("Please enter the text of a Dutch news article.")
    else:
        with st.spinner("Classifying the article..."):
            probabilities = classify(article_text, model, tokenizer)
            fig = plot_probabilities(probabilities)
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this
        # each click would leak one figure for the life of the server process.
        plt.close(fig)