# Hugging Face Spaces page residue: the hosted Space reported "Runtime error".
import functools
import os

import lightning.pytorch as pl
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from annotated_text import annotated_text
from transformers import AutoTokenizer
from transformers import DistilBertModel
class Classifier(pl.LightningModule):
    """3-way text classifier: one linear layer over frozen DistilBERT features.

    The DistilBERT encoder (obtained via the module-level ``get_bert``) is
    used purely as a frozen feature extractor; only ``ln1`` is trained.
    Class indices map to labels in the caller: 0 = Leadership,
    1 = Diversity, 2 = Integrity.
    """

    def __init__(self):
        super().__init__()
        # 512 tokens x 768 hidden dims, flattened, mapped to 3 classes.
        self.ln1 = torch.nn.Linear(512 * 768, 3)
        self.criterion = nn.CrossEntropyLoss()

    def _embed(self, input_ids, attention_mask):
        """Return L2-normalized, flattened DistilBERT features, shape (B, 512*768)."""
        with torch.no_grad():  # encoder is frozen; never backprop into BERT
            hidden = get_bert()(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state
        x = hidden.reshape(-1, 512 * 768)
        # keepdim=True lets the per-row norm broadcast for any batch size;
        # the original (B, D) / (B,) division only broadcast when B == 1
        # and raised a shape error during batched training.
        return x / torch.linalg.norm(x, 2, 1, keepdim=True)

    def training_step(self, batch, batch_idx):
        """One optimization step.

        ``batch`` packs ``x`` as [input_ids | attention_mask] along dim 1
        (first 512 columns are token ids, the remainder the mask) and ``y``
        as class indices.
        """
        x, y = batch
        logits = self.ln1(self._embed(x[:, :512], x[:, 512:]))
        loss = self.criterion(logits, y)
        self.log("my_loss", loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Adam over the model parameters (BERT stays frozen via no_grad)."""
        return optim.Adam(self.parameters(), lr=1e-3)

    def preprocess(self, x):
        """Tokenize a list of strings into fixed-length (512) tensors."""
        if not hasattr(self, "_tokenizer"):
            # Cache the tokenizer: the original reloaded it on every call.
            self._tokenizer = AutoTokenizer.from_pretrained(
                "distilbert-base-uncased", use_fast=True
            )
        # truncation=True guards against inputs longer than 512 tokens,
        # which would otherwise break the fixed-size reshape downstream.
        return self._tokenizer(x, padding="max_length", truncation=True,
                               return_tensors="pt")

    def forward(self, x):
        """Return logits of shape (B, 3) for tokenized input ``x``.

        ``x`` is a tokenizer output providing ``input_ids`` and
        ``attention_mask`` (the original unpacked it with ``**x``).
        """
        # (Removed leftover debug print of self.ln1.)
        return self.ln1(self._embed(x["input_ids"], x["attention_mask"]))
@functools.lru_cache(maxsize=1)
def get_bert():
    """Return a shared, pretrained DistilBERT encoder.

    Cached with ``lru_cache``: the original reloaded the model weights from
    disk on every call, and this function is invoked once per training step
    and once per forward pass.
    """
    return DistilBertModel.from_pretrained("distilbert-base-uncased")
def get_classifier():
    """Download (at most once) and load the trained classifier checkpoint.

    Fetches ``model_params`` from Google Drive via gdown only when it is
    not already present in the working directory, then restores the
    LightningModule from it.
    """
    ckpt = os.path.join(os.getcwd(), "model_params")
    if not os.path.exists(ckpt):
        # The original re-ran the download on every call.
        os.system('gdown 1GxhHvg3lwlGpA7So06v3l43U8pSASy9L')
    return Classifier.load_from_checkpoint(ckpt)
def get_annotated_text(text):
    """Classify each sentence of ``text`` and return annotated_text tuples.

    Splits on '.', skips fragments that are empty after stripping spaces,
    and tags each remaining sentence with its predicted class label; a
    bare "." follows every tagged sentence so punctuation is preserved in
    the rendered output.
    """
    # argmax over 3 logits always lands in {0, 1, 2}.
    labels = {0: "Leadership", 1: "Diversity", 2: "Integrity"}
    model = get_classifier()
    annotated = []
    for sentence in text.split("."):
        if sentence.strip(' ') == '':
            continue
        pred = model(model.preprocess([sentence])).argmax().item()
        annotated.append((sentence, labels[pred]))
        annotated.append(".")
    return tuple(annotated)
# --- Streamlit UI: collect text, classify sentence-by-sentence, render. ---
st.title("Code of Conduct Classifier")
input_text = st.text_area("enter code of conduct text")
st.title("annotated text")
print(input_text)
annotations = get_annotated_text(input_text)
annotated_text(*annotations)