# Page status banner captured together with this source (not code):
# Spaces: Runtime error
import streamlit as st | |
import pandas as pd | |
import sentencepiece | |
# ๋ชจ๋ธ ์ค๋นํ๊ธฐ | |
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer | |
from torch.utils.data import DataLoader, Dataset | |
import numpy as np | |
import pandas as pd | |
import torch | |
import os | |
from tqdm import tqdm | |
# Theme settings kept for reference only; presumably intended for the
# Streamlit config file rather than this script:
# [theme]
# base="dark"
# primaryColor="purple"

# Page title shown at the top of the app.
page_title = 'ํ๊ตญํ์ค์ฐ์ ๋ถ๋ฅ ์๋์ฝ๋ฉ ์๋น์ค'
st.header(page_title)
# Cache the loaded model and tables so they are not rebuilt on every
# Streamlit rerun (every widget interaction re-executes this script).
# Streamlit >= 1.18 provides st.cache_resource; older releases only have
# st.cache, which needs allow_output_mutation=True because a torch model
# and a DataFrame are returned.
if hasattr(st, 'cache_resource'):
    _cache_model_resources = st.cache_resource
else:
    _cache_model_resources = st.cache(allow_output_mutation=True)


@_cache_model_resources
def md_loading():
    """Build the fine-tuned XLM-RoBERTa classifier and load the lookup tables.

    Returns:
        tuple: ``(tokenizer, model, label_tbl, loc_tbl, device)``.
        ``model`` is an ``XLMRobertaForSequenceClassification`` with 493
        output labels, its weights replaced by ``./en_ko_4mix_proto.bin`` and
        moved to ``device`` (CPU). ``label_tbl`` is the array stored in
        ``./label_table.npy`` (presumably class index -> code; verify) and
        ``loc_tbl`` is the DataFrame read from ``./kisc_table.csv``.
    """
    device = torch.device("cpu")
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    # num_labels must match the classifier head stored in the checkpoint,
    # otherwise load_state_dict below fails on the shape mismatch.
    model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)
    model_checkpoint = 'en_ko_4mix_proto.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    # map_location keeps the load working on a CPU-only host even if the
    # checkpoint was saved from a GPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(output_model_file, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
    print('ready')
    return tokenizer, model, label_tbl, loc_tbl, device


# Load the model (from the cache after the first run).
tokenizer, model, label_tbl, loc_tbl, device = md_loading()
# Dataset preparation.
# Every encoded input is truncated or padded to exactly this many tokens.
max_len = 64


class TVT_Dataset(Dataset):
    """Torch ``Dataset`` that tokenizes one DataFrame row per item for inference.

    The text columns of a row are joined with ``' <s> '`` and encoded with the
    module-level ``tokenizer`` into ``max_len`` token ids plus attention mask.

    Args:
        df: DataFrame containing every column listed in ``_TEXT_COLS``.
    """

    # Columns concatenated, in this order, into the single model input text.
    _TEXT_COLS = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']

    def __init__(self, df):
        self.df_data = df

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` 1-D tensors for row ``index``."""
        # The DataLoader passes positions 0..len-1 (see __len__), so select
        # the row positionally; label-based .loc breaks on a non-RangeIndex.
        row = self.df_data.iloc[index][self._TEXT_COLS]
        # Missing cells become '' so str.join does not fail on NaN/None.
        fields = ['' if pd.isna(value) else str(value) for value in row]
        encoded_dict = tokenizer(
            ' <s> '.join(fields),
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')
        # return_tensors='pt' yields a batch of one; drop that leading dim.
        padded_token_list = encoded_dict['input_ids'][0]
        att_mask = encoded_dict['attention_mask'][0]
        # Inference only: no label tensor is returned (a training variant
        # would also return the 'NEW_CD' target).
        return padded_token_list, att_mask

    def __len__(self):
        return len(self.df_data)
# Text input box. Only this first field is shown in the UI; its value later
# fills the CMPNY_NM column. The other fields below are disabled and are
# passed to the model as empty strings.
business = st.text_input('')
# business_work = st.text_input('์ฌ์ ์ฒด ํ๋์ผ')
# work_department = st.text_input('๊ทผ๋ฌด๋ถ์')
# work_position = st.text_input('์ง์ฑ ')
# what_do_i = st.text_input('๋ด๊ฐ ํ๋ ์ผ')
business_work = work_department = work_position = what_do_i = ''
# Data preparation: build the test dataset.
# Columns the model consumes, in the order they are joined into one input text.
input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']


def preprocess_dataset(dataset):
    """Return the model-input columns of *dataset* with a fresh 0..n-1 index.

    Missing values are replaced by empty strings so each row can be joined
    into one text. The caller's DataFrame is left unmodified.

    Args:
        dataset: DataFrame containing every column in ``input_col_type``.

    Returns:
        A new DataFrame with only the ``input_col_type`` columns.

    Raises:
        KeyError: if any ``input_col_type`` column is missing.
    """
    # fillna returns a new frame; the result must be kept, or NaNs survive.
    return dataset.reset_index(drop=True)[input_col_type].fillna('')
# Debug helper (disabled): show the raw model input.
# st.write(md_input)

# Run the classifier when the confirm button is clicked.
if st.button('ํ์ธ'):
    # --- Prepare data -------------------------------------------------------
    # md_input: values fed to the model, aligned with input_col_type.
    md_input = [str(business), str(business_work), str(what_do_i), str(work_position), str(work_department)]
    test_dataset = pd.DataFrame(dict(zip(input_col_type, md_input)), index=[0])
    test_dataset = preprocess_dataset(test_dataset)
    print(len(test_dataset))
    print(test_dataset)
    print('base_data_loader ์ฌ์ฉ ์์ ์ ')
    test_data = TVT_Dataset(test_dataset)
    train_batch_size = 48
    # Split the data into batches of train_batch_size rows.
    test_dataloader = DataLoader(test_data, batch_size=train_batch_size, shuffle=False)

    # --- Run the model ------------------------------------------------------
    model.eval()
    # Keep the class probabilities of every batch, not only the last one.
    prob_batches = []
    # Inference only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            test_input_ids, test_attention_mask = (t.to(device) for t in batch)
            outputs = model(test_input_ids, token_type_ids=None, attention_mask=test_attention_mask)
            prob_batches.append(torch.softmax(outputs.logits, dim=1))
    # Shape (rows, num_labels); row 0 is the single user input built above.
    pred_ = torch.cat(prob_batches)

    # --- Top-k predictions --------------------------------------------------
    # Never request more classes than the model outputs.
    k = min(10, pred_.size(1))
    # A single topk call yields both the probabilities and class indices.
    topk_values, topk_idx = torch.topk(pred_[0], k)
    num_ans_topk = label_tbl[topk_idx.cpu().numpy()]
    percent_ans_topk = topk_values.cpu().numpy()

    # Look up the display name of each predicted code (first match wins).
    name_col = loc_tbl['ํญ๋ชฉ๋ช ']
    code_col = loc_tbl['์ฝ๋']
    str_ans_topk_list = []
    for code in num_ans_topk:
        matches = name_col[code_col == code]
        # A code absent from the table shows an empty name instead of
        # crashing the page with an IndexError.
        str_ans_topk_list.append(matches.iloc[0] if len(matches) else '')
    percent_ans_topk_list = [p * 100 for p in percent_ans_topk]

    # --- Show results -------------------------------------------------------
    ans_topk_df = pd.DataFrame({
        'NO': range(1, k + 1),
        '์ธ๋ถ๋ฅ ์ฝ๋': num_ans_topk,
        '์ธ๋ถ๋ฅ ๋ช ์นญ': str_ans_topk_list,
        'ํ๋ฅ ': percent_ans_topk_list,
    })
    ans_topk_df = ans_topk_df.set_index('NO')
    st.write(ans_topk_df.style.bar(subset='ํ๋ฅ ', align='left', color='blue'))