# hv68 — commit b22c372: "Update app.py"
import streamlit as st
from annotated_text import annotated_text
import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from transformers import DistilBertModel
from transformers import AutoTokenizer
import lightning.pytorch as pl
class Classifier(pl.LightningModule):
    """Three-way text classifier on top of frozen DistilBERT features.

    Inputs are encoded by the module-level ``get_bert()`` model without
    gradients. The ``last_hidden_state`` is flattened to one ``512 * 768``
    vector per example and scaled to unit L2 norm. A single linear layer
    (``ln1``) maps it to 3 logits; ``ln1`` is the only trainable part.
    """

    def __init__(self):
        super().__init__()
        # One logit per class from the flattened (512 tokens x 768 dims) features.
        self.ln1 = torch.nn.Linear(512 * 768, 3)
        self.criterion = nn.CrossEntropyLoss()
        # Loaded lazily by ``preprocess`` and reused. It is a plain attribute,
        # not a submodule, so it is not part of the checkpointed state dict.
        self._tokenizer = None

    def _embed(self, **bert_inputs):
        """Return unit-L2-norm flattened DistilBERT features, one row per example.

        ``bert_inputs`` are passed unchanged as keyword arguments to the model
        returned by ``get_bert()``. Expects exactly 512 tokens per example;
        other lengths make the reshape fail or merge rows.
        """
        with torch.no_grad():
            hidden = get_bert()(**bert_inputs).last_hidden_state
            features = hidden.reshape(-1, 512 * 768)
            # keepdim=True keeps the norms as a (batch, 1) column, so each row
            # is divided by its own norm. A plain (batch,) vector would
            # broadcast against the feature axis instead and break for batch > 1.
            return features / torch.linalg.norm(features, ord=2, dim=1, keepdim=True)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch.

        Columns ``[:512]`` of ``x`` are used as ``input_ids`` and columns
        ``[512:]`` as ``attention_mask``.
        """
        x, y = batch
        logits = self.ln1(self._embed(input_ids=x[:, :512], attention_mask=x[:, 512:]))
        loss = self.criterion(logits, y)
        self.log("my_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) over this module's parameters, i.e. only ``ln1``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def preprocess(self, x):
        """Tokenize ``x`` into PyTorch tensors suitable for ``forward``.

        Every row is padded to, and truncated at, 512 tokens. This keeps the
        flattened size equal to the ``512 * 768`` input of ``ln1``.
        """
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
        return self._tokenizer(
            x,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )

    def forward(self, x):
        """Return class logits with one row of 3 per example.

        ``x`` is unpacked as keyword arguments to DistilBERT, e.g. the mapping
        returned by ``preprocess``.
        """
        return self.ln1(self._embed(**x))
# Cache decorator for the heavyweight models below.
# ``st.cache_resource`` (newer Streamlit) returns the same object on every call
# without hashing it. Older Streamlit only has ``st.cache``, which by default
# hashes the returned object on each call to detect mutation; passing
# ``allow_output_mutation=True`` turns that off.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def get_bert():
    """Return the pretrained ``distilbert-base-uncased`` encoder (cached)."""
    return DistilBertModel.from_pretrained("distilbert-base-uncased")


# Google Drive id of the trained checkpoint, and its local file name.
_CHECKPOINT_DRIVE_ID = "1GxhHvg3lwlGpA7So06v3l43U8pSASy9L"
_CHECKPOINT_NAME = "model_params"


@_cache_model
def get_classifier():
    """Return the trained ``Classifier``, downloading its checkpoint if missing.

    The checkpoint is stored as ``model_params`` in the current working
    directory and fetched with the ``gdown`` CLI only if that file is absent.

    Raises:
        FileNotFoundError: if ``gdown`` is not installed.
        subprocess.CalledProcessError: if the download fails.
    """
    import subprocess

    checkpoint = os.path.join(os.getcwd(), _CHECKPOINT_NAME)
    if not os.path.exists(checkpoint):
        # Argument list (no shell) and check=True: a failed download raises here
        # instead of surfacing later as a confusing checkpoint-loading error.
        subprocess.run(["gdown", _CHECKPOINT_DRIVE_ID, "-O", checkpoint], check=True)
    return Classifier.load_from_checkpoint(checkpoint)
def get_annotated_text(text):
    """Classify each "."-separated fragment of *text* for ``annotated_text``.

    The text is split on ".". Empty or whitespace-only fragments are skipped,
    and every remaining fragment is classified on its own. Class indices
    0/1/2 map to "Leadership"/"Diversity"/"Integrity".

    Returns:
        A tuple alternating ``(fragment, label)`` pairs and "." separators,
        meant to be splatted into ``annotated_text(*...)``. Fragments keep
        their original surrounding whitespace.
    """
    labels = ("Leadership", "Diversity", "Integrity")
    model = get_classifier()
    parts = []
    for sentence in text.split("."):
        # strip() rather than strip(" "), so newline/tab-only fragments are
        # skipped. Example: the text after a final "." that ends in a line break.
        if not sentence.strip():
            continue
        # Inference only: no autograd graph is needed for the linear head either.
        with torch.no_grad():
            logits = model(model.preprocess([sentence]))
        parts.append((sentence, labels[int(logits.argmax())]))
        parts.append(".")
    return tuple(parts)
# Page layout: a text box for the code-of-conduct text, then the same text
# rendered by annotated_text with one predicted label per "."-separated fragment.
st.title("Code of Conduct Classifier")
input_text = st.text_area("enter code of conduct text")
st.title("annotated text")
annotated_text(*get_annotated_text(input_text))