iSpr committed on
Commit
5183c55
·
1 Parent(s): 22b6cc3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ # Model setup
5
+ from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import os
10
+
11
+ # Page title
12
+ st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
13
+
14
+ # Cache the loaded objects so they are not reloaded on every rerun
15
# st.experimental_singleton keeps one shared instance of the returned objects
# for the whole process. st.experimental_memo (used previously) stores the
# return value pickled and gives every caller its own copy, which is meant for
# data, not for a large torch model.
@st.experimental_singleton
def md_loading():
    """Load the tokenizer, the fine-tuned classifier and the lookup tables.

    Returns:
        tuple: ``(tokenizer, model, label_tbl, loc_tbl)`` where

        * ``tokenizer`` is the pretrained ``xlm-roberta-large`` tokenizer,
        * ``model`` is an ``XLMRobertaForSequenceClassification`` with 493
          labels whose weights come from ``./base1_43_11.bin`` (mapped to CPU),
        * ``label_tbl`` is the array stored in ``./label_table.npy``
          (indexed by the model's output position in the calling script),
        * ``loc_tbl`` is the DataFrame read from ``./kisc_table.csv``.
    """
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
    model = XLMRobertaForSequenceClassification.from_pretrained(
        'xlm-roberta-large', num_labels=493
    )

    # Overwrite the pretrained weights with the fine-tuned checkpoint.
    model_checkpoint = 'base1_43_11.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    model.load_state_dict(
        torch.load(output_model_file, map_location=torch.device('cpu'))
    )
    # Inference only: make evaluation mode (dropout disabled) explicit.
    model.eval()

    # Lookup tables used to turn a logit position into a code and its name.
    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl
36
+
37
+ # Load the model
38
# Fetch the tokenizer, model and lookup tables (served from Streamlit's cache
# after the first run).
tokenizer, model, label_tbl, loc_tbl = md_loading()

# Single free-text input. An earlier draft collected five separate fields
# (business name, business activity, department, position, own task); that
# form is no longer used.
# st.text_input requires a label as its first argument; calling it with no
# arguments raises TypeError before the page can render.
# NOTE(review): the label wording is a placeholder - localise as needed.
input_box = st.text_input('Text to classify')

# md_input: the value that is fed to the model.
md_input = input_box
52
+
53
+ ## ์ž„์‹œ ํ™•์ธ
54
+ # st.write(md_input)
55
+
56
+ # Button
57
# NOTE(review): the non-ASCII string literals in this block look mis-encoded
# in this copy of the file (UTF-8 text read through a legacy code page). They
# are reproduced as found - confirm them against the deployed file and the
# header row of kisc_table.csv.

MAX_SEQ_LEN = 64  # fixed width of the id / mask tensors fed to the model
TOP_K = 10        # number of ranked candidates shown to the user


def _encode_query(text, max_len=MAX_SEQ_LEN):
    """Turn ``text`` into ``(input_ids, attention_mask)`` tensors of shape (1, max_len).

    Token ids are truncated to ``max_len``; unused positions stay 0 in both
    tensors. No special tokens are added, matching the original script.
    NOTE(review): presumably this mirrors how the checkpoint was trained - confirm.
    """
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:max_len]

    input_ids = torch.zeros((1, max_len), dtype=torch.long)
    attention_mask = torch.zeros((1, max_len), dtype=torch.long)
    if ids:
        input_ids[0, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[0, :len(ids)] = 1
    return input_ids, attention_mask


def _lookup_name(code):
    """Return the name stored in ``loc_tbl`` for ``code``, or '' if it is absent."""
    matches = loc_tbl.loc[loc_tbl['์ฝ”๋“œ'] == code, 'ํ•ญ๋ชฉ๋ช…']
    return '' if matches.empty else matches.iloc[0]


if st.button('ํ™•์ธ'):
    if not md_input.strip():
        # Nothing to classify: a fully masked input would only produce noise.
        st.warning('Please enter some text to classify.')
    else:
        input_ids, attention_mask = _encode_query(md_input)

        # Inference only - skip building the autograd graph.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Rank the classes and keep the best k (never more than there are).
        k = min(TOP_K, logits.numel())
        topk_idx = torch.topk(logits.flatten(), k).indices.numpy()

        # Logit position -> code -> human-readable name.
        num_ans_topk = label_tbl[topk_idx]
        str_ans_topk_list = [_lookup_name(code) for code in num_ans_topk]

        ans_topk_df = pd.DataFrame({
            'NO': range(1, k + 1),
            '์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ': num_ans_topk,
            '์„ธ๋ถ„๋ฅ˜ ๋ช…์นญ': str_ans_topk_list,
        }).set_index('NO')

        st.dataframe(ans_topk_df)