iSpr committed on
Commit
569997f
ยท
1 Parent(s): 3296632

model change

Files changed (1) hide show
  1. app.py +126 -37
app.py CHANGED
@@ -1,13 +1,20 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import sentencepiece
 
4
  # ๋ชจ๋ธ ์ค€๋น„ํ•˜๊ธฐ
5
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
 
 
6
  import numpy as np
7
  import pandas as pd
8
  import torch
9
  import os
10
 
 
 
 
 
11
  # ์ œ๋ชฉ ์ž…๋ ฅ
12
  st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
13
 
@@ -18,38 +25,89 @@ def md_loading():
18
  # device = torch.device('cpu')
19
 
20
  tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
 
21
  model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large', num_labels=493)
22
-
23
- model_checkpoint = 'base1_43_11.bin'
24
  project_path = './'
25
  output_model_file = os.path.join(project_path, model_checkpoint)
26
- ckpt = torch.load(output_model_file, map_location=torch.device('cpu'))
27
 
28
  model.load_state_dict(ckpt['model_state_dict'])
29
-
30
- ################################## label tbl ์ˆ˜์ •
 
 
 
31
  label_tbl = np.load('./label_table.npy')
32
  loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
33
 
34
  print('ready')
35
 
36
- return tokenizer, model, label_tbl, loc_tbl
37
 
38
  # ๋ชจ๋ธ ๋กœ๋“œ
39
- tokenizer, model, label_tbl, loc_tbl = md_loading()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  # ํ…์ŠคํŠธ input ๋ฐ•์Šค
43
- # business = st.text_input('์‚ฌ์—…์ฒด๋ช…', '์ถฉ์ฒญ์ง€๋ฐฉํ†ต๊ณ„์ฒญ').replace(',', '')
44
- # business_work = st.text_input('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ', 'ํ†ต๊ณ„์„œ๋น„์Šค ์ œ๊ณต ๋ฐ ์ง€์—ญํ†ต๊ณ„ ํ—ˆ๋ธŒ').replace(',', '')
45
- # work_department = st.text_input('๊ทผ๋ฌด๋ถ€์„œ', '์ง€์—ญํ†ต๊ณ„๊ณผ').replace(',', '')
46
- # work_position = st.text_input('์ง์ฑ…', '์ฃผ๋ฌด๊ด€').replace(',', '')
47
- # what_do_i = st.text_input('๋‚ด๊ฐ€ ํ•˜๋Š” ์ผ', 'ํ†ต๊ณ„๋ฐ์ดํ„ฐ์„ผํ„ฐ ์šด์˜').replace(',', '')
48
 
49
- input_box = st.text_input('์ž…๋ ฅ')
50
 
51
- # md_input: ๋ชจ๋ธ์— ์ž…๋ ฅํ•  input ๊ฐ’ ์ •์˜
52
- md_input = input_box
 
 
 
 
 
 
 
 
53
 
54
  ## ์ž„์‹œ ํ™•์ธ
55
  # st.write(md_input)
@@ -57,37 +115,68 @@ md_input = input_box
57
  # ๋ฒ„ํŠผ
58
  if st.button('ํ™•์ธ'):
59
  ## ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์ˆ˜ํ–‰์‚ฌํ•ญ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  ### ๋ชจ๋ธ ์‹คํ–‰
61
- query_tokens = md_input
62
 
63
- input_ids = np.zeros(shape=[1, 64])
64
- attention_mask = np.zeros(shape=[1, 64])
65
 
66
- # seq = '[CLS] '
67
- # try:
68
- # for i in range(5):
69
- # seq += query_tokens[i] + ' '
70
- # except:
71
- # None
72
 
73
- seq = query_tokens
 
 
 
74
 
75
- tokens = tokenizer.tokenize(seq)
76
- ids = tokenizer.convert_tokens_to_ids(tokens)
77
 
78
- length = len(ids)
79
- if length > 64:
80
- length = 64
 
 
81
 
82
- for i in range(length):
83
- input_ids[0, i] = ids[i]
84
- attention_mask[0, i] = 1
85
 
86
- input_ids = torch.from_numpy(input_ids).type(torch.long)
87
- attention_mask = torch.from_numpy(attention_mask).type(torch.long)
88
 
89
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
90
- logits = outputs.logits
91
 
92
  # # ๋‹จ๋… ์˜ˆ์ธก ์‹œ
93
  # arg_idx = torch.argmax(logits, dim=1)
 
1
  import streamlit as st
2
  import pandas as pd
3
  import sentencepiece
4
+
5
  # ๋ชจ๋ธ ์ค€๋น„ํ•˜๊ธฐ
6
  from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from base_data_loader import TVT_Dataset
9
  import numpy as np
10
  import pandas as pd
11
  import torch
12
  import os
13
 
14
# [theme]
# base="dark"
# primaryColor="purple"

# Page title
st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
20
 
 
25
def md_loading():
    """Load the tokenizer, the fine-tuned classifier, and the label tables.

    Returns:
        (tokenizer, model, label_tbl, loc_tbl, device) — model is already
        moved to `device` and has the checkpoint weights loaded.
    """
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
    model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large', num_labels=493)

    model_checkpoint = 'base3_44_7.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)

    # Choose the device first so the checkpoint can be mapped onto it.
    # BUG FIX: torch.load() without map_location fails on CPU-only hosts
    # when the checkpoint was saved from a GPU process.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(output_model_file, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])

    model.to(device)

    # label index table and industry-code lookup table
    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl, device

# Load the model once at app start-up
tokenizer, model, label_tbl, loc_tbl, device = md_loading()
51
+
52
+
53
+ # ๋ฐ์ดํ„ฐ ์…‹ ์ค€๋น„์šฉ
54
+ max_len = 64 # 64
55
+
56
+ class TVT_Dataset(Dataset):
57
+
58
+ def __init__(self, df):
59
+ self.df_data = df
60
+
61
+ def __getitem__(self, index):
62
+
63
+ # ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„ ์นผ๋Ÿผ ๋“ค๊ณ ์˜ค๊ธฐ
64
+ # sentence = self.df_data.loc[index, 'text']
65
+ sentence = self.df_data.loc[index, ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']]
66
+
67
+ encoded_dict = tokenizer(
68
+ ' <s> '.join(sentence.to_list()),
69
+ add_special_tokens = True,
70
+ max_length = max_len,
71
+ padding='max_length',
72
+ truncation=True,
73
+ return_attention_mask = True,
74
+ return_tensors = 'pt')
75
+
76
+
77
+ padded_token_list = encoded_dict['input_ids'][0]
78
+ att_mask = encoded_dict['attention_mask'][0]
79
+
80
+ # ์ˆซ์ž๋กœ ๋ณ€ํ™˜๋œ label์„ ํ…์„œ๋กœ ๋ณ€ํ™˜
81
+ # target = torch.tensor(self.df_data.loc[index, 'NEW_CD'])
82
+ # input_ids, attention_mask, label์„ ํ•˜๋‚˜์˜ ์ธํ’‹์œผ๋กœ ๋ฌถ์Œ
83
+ # sample = (padded_token_list, att_mask, target)
84
+ sample = (padded_token_list, att_mask)
85
+
86
+ return sample
87
+
88
+ def __len__(self):
89
+ return len(self.df_data)
90
+
91
 
92
 
93
# Text input boxes — commas are stripped from every field
# (NOTE(review): presumably because ',' is a field separator downstream — confirm)
business = st.text_input('์‚ฌ์—…์ฒด๋ช…').replace(',', '')
business_work = st.text_input('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ').replace(',', '')
what_do_i = st.text_input('๋‚ด๊ฐ€ ํ•˜๋Š” ์ผ').replace(',', '')
work_position = st.text_input('์ง์ฑ…').replace(',', '')
work_department = st.text_input('๊ทผ๋ฌด๋ถ€์„œ').replace(',', '')
99
 
 
100
 
101
# Data preparation

# Columns the model pipeline expects; 'NEW_CD' is the training label and is
# absent from the one-row inference frame built in the button handler.
input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM', 'NEW_CD']

def preprocess_dataset(dataset):
    """Reset the index, blank out NaNs and keep only the expected columns.

    Bug fixes vs. the original:
    - `dataset.fillna('')` was called without keeping the result, so NaNs
      survived into tokenization; the return value is now assigned.
    - Selecting 'NEW_CD' on a label-less inference frame raised KeyError;
      columns missing from `dataset` are now skipped.
    """
    dataset = dataset.reset_index(drop=True)
    dataset = dataset.fillna('')
    present = [col for col in input_col_type if col in dataset.columns]
    return dataset[present]
110
+
111
 
112
  ## ์ž„์‹œ ํ™•์ธ
113
  # st.write(md_input)
 
115
# Confirm button
if st.button('ํ™•์ธ'):
    ## Actions performed on click

    ### Prepare the data

    # md_input: the raw field values to feed the model
    # md_input = '|'.join([business, business_work, what_do_i, work_position, work_department])
    md_input = [business, business_work, what_do_i, work_position, work_department]

    # BUG FIX: the original passed bare scalars with no index, which makes
    # pandas raise "If using all scalar values, you must pass an index".
    # Wrapping each value in a one-element list builds a one-row frame.
    test_dataset = pd.DataFrame({
        input_col_type[0]: [md_input[0]],
        input_col_type[1]: [md_input[1]],
        input_col_type[2]: [md_input[2]],
        input_col_type[3]: [md_input[3]],
        input_col_type[4]: [md_input[4]],
    })

    # test_dataset = pd.read_csv(DATA_IN_PATH + test_set_name, sep='|', na_filter=False)

    test_dataset = preprocess_dataset(test_dataset)

    print(len(test_dataset))
    print(test_dataset)

    print('base_data_loader ์‚ฌ์šฉ ์‹œ์ ์ ')
    test_data = TVT_Dataset(test_dataset)

    train_batch_size = 48

    # Split the data into batches
    test_dataloader = DataLoader(test_data,
                                 batch_size=train_batch_size,
                                 shuffle=False)

    ### Run the model

    # Put model in evaluation mode
    model.eval()
    model.zero_grad()

    # Tracking variables
    predictions, true_labels = [], []

    # Predict
    # BUG FIX: the original wrote `for batch in range(test_dataloader)`,
    # which raises TypeError (range() needs an int) — iterate the
    # DataLoader directly.
    for batch in test_dataloader:
        # Move batch tensors to the model's device
        batch = tuple(t.to(device) for t in batch)

        # Unpack the inputs from our dataloader
        test_input_ids, test_attention_mask = batch

        # No gradients are needed for inference — saves memory and time
        with torch.no_grad():
            # Forward pass, calculate logit predictions
            outputs = model(test_input_ids, token_type_ids=None,
                            attention_mask=test_attention_mask)

        logits = outputs.logits

        # Move logits to CPU
        logits = logits.detach().cpu().numpy()

    # # single-item prediction:
    # arg_idx = torch.argmax(logits, dim=1)