# Spaces: Sleeping
# Public names of this module. Every entry must actually be defined here,
# otherwise `from <module> import *` raises AttributeError.
__all__ = ['classify_text', 'categories', 'label', 'examples', 'intf']
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch

# Where everything is loaded from: the tokenizer comes from the base
# checkpoint named by `model_name`; the two classifiers come from local
# directories (presumably fine-tuned from that base -- confirm).
model_name = "NbAiLab/nb-bert-base"
first_model_path = "models/first_model"
second_model_path = "models/second_model"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Both classifiers are loaded once, at import time, in the order
# first -> second, and are then reused by classify_text on every call.
first_model, second_model = (
    AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    for checkpoint_dir in (first_model_path, second_model_path)
)
def classify_text(test_text, selected_model):
    """Score a text against the four reader target groups.

    Args:
        test_text: The text to classify (e.g. an article lead).
        selected_model: Which classifier to use; must be ``'Model 1'`` or
            ``'Model 2'``.

    Returns:
        A dict mapping each target-group name to the softmax probability the
        selected model assigns to it, suitable for a Gradio Label output.

    Raises:
        ValueError: If *selected_model* is not one of the two known choices.
    """
    if selected_model == 'Model 1':
        model = first_model
    elif selected_model == 'Model 2':
        model = second_model
    else:
        raise ValueError("Invalid model selection")

    # Truncate to the tokenizer's maximum length: the BERT-based model has a
    # fixed number of position embeddings, so over-long input would otherwise
    # fail inside the forward pass instead of being classified.
    inputs = tokenizer(test_text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits

    # A single input string gives a batch of one; take that row's probabilities.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # NOTE(review): this order is assumed to match the model's output indices
    # 0..3 -- confirm against the label mapping used during training.
    target_groups = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
    return dict(zip(target_groups, map(float, probabilities)))
# Gradio UI: a free-text box plus a dropdown choosing the classifier; the
# dict of category -> probability returned by classify_text is shown in a Label.
label = gr.Label()
# Target-group names exported by this module. NOTE(review): classify_text
# keeps its own copy of these names, so keep the two in sync.
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"
# Each example is [text, model choice], matching the order of `inputs` below.
examples = [
    ["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
    ["Fotballstadion tok fyr i helgen", 'Model 2'],
    ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
]
intf = gr.Interface(
    fn=classify_text,
    # Default the dropdown to 'Model 1' so submitting without an explicit
    # choice does not send None and trigger "Invalid model selection".
    inputs=["text", gr.Dropdown(['Model 1', 'Model 2'], value='Model 1')],
    outputs=label,
    examples=examples,
    title=app_title,
)
intf.launch(inline=False)