Spaces: base3_44 (status: Runtime error)
Commit: model change

app.py CHANGED
@@ -1,13 +1,20 @@
 import streamlit as st
 import pandas as pd
 import sentencepiece
+
 # Prepare the model
 from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
+from torch.utils.data import DataLoader, Dataset
+from base_data_loader import TVT_Dataset
 import numpy as np
 import pandas as pd
 import torch
 import os
 
+# [theme]
+# base="dark"
+# primaryColor="purple"
+
 # Title input
 st.header('한국표준산업분류 자동코딩 서비스')
 
@@ -18,38 +25,89 @@ def md_loading():
     # device = torch.device('cpu')
 
     tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+
     model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large', num_labels=493)
-
-    model_checkpoint = '
+
+    model_checkpoint = 'base3_44_7.bin'
     project_path = './'
     output_model_file = os.path.join(project_path, model_checkpoint)
-    ckpt = torch.load(output_model_file
+    ckpt = torch.load(output_model_file)
 
     model.load_state_dict(ckpt['model_state_dict'])
-
-
+
+    device = torch.device("cuda" if torch.cuda.is_available() and not False else "cpu")
+
+    model.to(device)
+
     label_tbl = np.load('./label_table.npy')
     loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
 
     print('ready')
 
-    return tokenizer, model, label_tbl, loc_tbl
+    return tokenizer, model, label_tbl, loc_tbl, device
 
 # Load the model
-tokenizer, model, label_tbl, loc_tbl = md_loading()
+tokenizer, model, label_tbl, loc_tbl, device = md_loading()
+
+
+# For dataset preparation
+max_len = 64 # 64
+
+class TVT_Dataset(Dataset):
+
+    def __init__(self, df):
+        self.df_data = df
+
+    def __getitem__(self, index):
+
+        # Fetch the dataframe columns
+        # sentence = self.df_data.loc[index, 'text']
+        sentence = self.df_data.loc[index, ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']]
+
+        encoded_dict = tokenizer(
+            ' <s> '.join(sentence.to_list()),
+            add_special_tokens = True,
+            max_length = max_len,
+            padding='max_length',
+            truncation=True,
+            return_attention_mask = True,
+            return_tensors = 'pt')
+
+
+        padded_token_list = encoded_dict['input_ids'][0]
+        att_mask = encoded_dict['attention_mask'][0]
+
+        # Convert the numeric label to a tensor
+        # target = torch.tensor(self.df_data.loc[index, 'NEW_CD'])
+        # Bundle input_ids, attention_mask and label into one sample
+        # sample = (padded_token_list, att_mask, target)
+        sample = (padded_token_list, att_mask)
+
+        return sample
+
+    def __len__(self):
+        return len(self.df_data)
+
 
 
 # Text input boxes
-
-
-
-
-
+business = st.text_input('사업체명').replace(',', '')
+business_work = st.text_input('사업체 하는일').replace(',', '')
+what_do_i = st.text_input('내가 하는 일').replace(',', '')
+work_position = st.text_input('직책').replace(',', '')
+work_department = st.text_input('근무부서').replace(',', '')
 
-input_box = st.text_input('입력')
 
-#
-
+# Prepare the data
+
+# Build the test dataset
+input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM', 'NEW_CD']
+
+def preprocess_dataset(dataset):
+    dataset.reset_index(drop=True, inplace=True)
+    dataset.fillna('')
+    return dataset[input_col_type]
+
 
 ## Temporary check
 # st.write(md_input)
@@ -57,37 +115,68 @@ md_input = input_box
 # Button
 if st.button('확인'):
     ## Actions on button click
+
+    ### Prepare the data
+
+    # md_input: define the input values fed to the model
+    # md_input = '|'.join([business, business_work, what_do_i, work_position, work_department])
+    md_input = [business, business_work, what_do_i, work_position, work_department]
+
+    test_dataset = pd.DataFrame({
+        input_col_type[0]: md_input[0],
+        input_col_type[1]: md_input[1],
+        input_col_type[2]: md_input[2],
+        input_col_type[3]: md_input[3],
+        input_col_type[4]: md_input[4]
+    })
+
+    # test_dataset = pd.read_csv(DATA_IN_PATH + test_set_name, sep='|', na_filter=False)
+
+    test_dataset = preprocess_dataset(test_dataset)
+
+    print(len(test_dataset))
+    print(test_dataset)
+
+    print('base_data_loader 사용 시작 전')
+    test_data = TVT_Dataset(test_dataset)
+
+    train_batch_size = 48
+
+    # Split the data into batches of batch_size
+    test_dataloader = DataLoader(test_data,
+                                 batch_size=train_batch_size,
+                                 shuffle=False)
+
+
     ### Run the model
-    query_tokens = md_input
 
-    input_ids = np.zeros(shape=[1, 64])
-    attention_mask = np.zeros(shape=[1, 64])
 
-    #
-
-
-
-    #
-
+    # Put model in evaluation mode
+    model.eval()
+    model.zero_grad()
+
+    # Tracking variables
+    predictions , true_labels = [], []
 
-
+    # Predict
+    for batch in range(test_dataloader):
+        # Add batch to GPU
+        batch = tuple(t.to(device) for t in batch)
 
-
-
+        # Unpack the inputs from our dataloader
+        test_input_ids, test_attention_mask = batch
 
-
-
-
+        # Telling the model not to compute or store gradients, saving memory and
+        # speeding up prediction
+        with torch.no_grad():
+            # Forward pass, calculate logit predictions
+            outputs = model(test_input_ids, token_type_ids=None, attention_mask=test_attention_mask)
 
-
-    input_ids[0, i] = ids[i]
-    attention_mask[0, i] = 1
+            logits = outputs.logits
 
-
-
+        # Move logits and labels to CPU
+        logits = logits.detach().cpu().numpy()
 
-    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
-    logits = outputs.logits
 
 # # For a single prediction
 # arg_idx = torch.argmax(logits, dim=1)
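Notes on this change (editorial, not part of the commit):

The new md_loading() calls torch.load(output_model_file) before any device is chosen, so if base3_44_7.bin was saved from a GPU session this raises "Attempting to deserialize object on a CUDA device" on a CPU-only Space; passing map_location avoids that. The `and not False` in the device line is also a leftover no-op. A minimal sketch of a safer load, assuming the checkpoint layout from the diff ({'model_state_dict': ...}) and reusing the model object built just above:

    import os
    import torch

    # Pick the device first, then map the checkpoint onto it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    output_model_file = os.path.join('./', 'base3_44_7.bin')

    # map_location remaps GPU-saved tensors onto the current device,
    # so loading works on CPU-only hardware as well.
    ckpt = torch.load(output_model_file, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)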
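The file both imports TVT_Dataset from base_data_loader and defines a local TVT_Dataset class; the later local definition shadows the import, so the import is dead code. For reference, here is the encoding step of __getitem__ as a standalone sketch; the five field strings are made-up examples, and the tokenizer settings are taken from the diff:

    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')

    # Hypothetical field values standing in for one dataframe row.
    fields = ['ACME Ltd', 'software publishing', 'backend development', 'manager', 'dev team']

    # The five fields are joined with a literal ' <s> ' marker, then
    # padded and truncated to a fixed length of 64 tokens.
    encoded = tokenizer(' <s> '.join(fields),
                        add_special_tokens=True,
                        max_length=64,
                        padding='max_length',
                        truncation=True,
                        return_attention_mask=True,
                        return_tensors='pt')

    # Shapes are [1, 64]; indexing [0] drops the batch axis, as __getitem__ does.
    padded_token_list = encoded['input_ids'][0]
    att_mask = encoded['attention_mask'][0]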
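The data preparation in the button handler has several problems that surface before the model ever runs. The hunk header shows that the unchanged line `md_input = input_box` survives into the new file, yet this commit deletes the input_box definition, so the app appears to raise NameError at startup. Beyond that: pd.DataFrame built from a dict of five scalar strings raises ValueError ("If using all scalar values, you must pass an index"); preprocess_dataset selects input_col_type, which includes the label column NEW_CD that this frame never gets, raising KeyError; and dataset.fillna('') returns a new frame rather than modifying in place, so it does nothing. A sketch of the same preparation with those issues addressed; this is a suggestion, not part of the commit, and it reuses the five Streamlit inputs defined above:

    import pandas as pd

    feature_cols = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']

    def preprocess_dataset(dataset):
        dataset = dataset.reset_index(drop=True)
        dataset = dataset.fillna('')      # fillna is not in-place by default
        return dataset[feature_cols]      # select only the feature columns

    md_input = [business, business_work, what_do_i, work_position, work_department]

    # Wrapping the row in a list gives pandas one row of data instead of
    # five bare scalars, so no explicit index is needed.
    test_dataset = preprocess_dataset(pd.DataFrame([md_input], columns=feature_cols))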
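The prediction loop is another likely source of the Space's Runtime error badge: range() takes an integer, so `for batch in range(test_dataloader)` raises a TypeError; a DataLoader is iterated directly. model.zero_grad() is also redundant at inference time, and XLM-RoBERTa does not use token_type_ids. A corrected sketch of the loop, keeping the (input_ids, attention_mask) batches that TVT_Dataset yields:

    import numpy as np
    import torch

    model.eval()

    all_logits = []
    for batch in test_dataloader:               # iterate the DataLoader itself
        input_ids, attention_mask = (t.to(device) for t in batch)

        # No gradients are needed for prediction.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        all_logits.append(outputs.logits.cpu().numpy())

    # One [num_rows, 493] array of class scores.
    logits = np.concatenate(all_logits, axis=0)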
# # For a single prediction
# arg_idx = torch.argmax(logits, dim=1)