iSpr committed on
Commit
22b6cc3
·
1 Parent(s): 4788c51

Delete app_2020.py

Browse files
Files changed (1) hide show
  1. app_2020.py +0 -140
app_2020.py DELETED
@@ -1,140 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
-
4
- # 모델 준비하기
5
- from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
6
- import numpy as np
7
- import pandas as pd
8
- import torch
9
- import os
10
-
11
- # 제목 입력
12
- st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
13
-
14
- # 재로드 안하도록
15
- @st.experimental_memo(max_entries=20)
16
- def md_loading():
17
- ## cpu
18
- # device = torch.device('cpu')
19
-
20
- tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
21
- model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large', num_labels=493)
22
-
23
- model_checkpoint = 'base1_43_11.bin'
24
- project_path = './'
25
- output_model_file = os.path.join(project_path, model_checkpoint)
26
-
27
- model.load_state_dict(torch.load(output_model_file, map_location=torch.device('cpu')))
28
-
29
- ################################## label tbl 수정
30
- label_tbl = np.load('./label_table.npy')
31
- loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
32
-
33
- print('ready')
34
-
35
- return tokenizer, model, label_tbl, loc_tbl
36
-
37
- # 모델 로드
38
- tokenizer, model, label_tbl, loc_tbl = md_loading()
39
-
40
-
41
- # 텍스트 input 박스
42
- # business = st.text_input('์‚ฌ์—…์ฒด๋ช…', '์ถฉ์ฒญ์ง€๋ฐฉํ†ต๊ณ„์ฒญ').replace(',', '')
43
- # business_work = st.text_input('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ', 'ํ†ต๊ณ„์„œ๋น„์Šค ์ œ๊ณต ๋ฐ ์ง€์—ญํ†ต๊ณ„ ํ—ˆ๋ธŒ').replace(',', '')
44
- # work_department = st.text_input('๊ทผ๋ฌด๋ถ€์„œ', '์ง€์—ญํ†ต๊ณ„๊ณผ').replace(',', '')
45
- # work_position = st.text_input('์ง์ฑ…', '์ฃผ๋ฌด๊ด€').replace(',', '')
46
- # what_do_i = st.text_input('๋‚ด๊ฐ€ ํ•˜๋Š” ์ผ', 'ํ†ต๊ณ„๋ฐ์ดํ„ฐ์„ผํ„ฐ ์šด์˜').replace(',', '')
47
-
48
- input_box = st.text_input()
49
-
50
- # md_input: 모델에 입력할 input 값 정의
51
- md_input = input_box
52
-
53
- ## 임시 확인
54
- # st.write(md_input)
55
-
56
- # 버튼
57
- if st.button('ํ™•์ธ'):
58
- ## 버튼 클릭 시 수행사항
59
- ### 모델 실행
60
- query_tokens = md_input
61
-
62
- input_ids = np.zeros(shape=[1, 64])
63
- attention_mask = np.zeros(shape=[1, 64])
64
-
65
- # seq = '[CLS] '
66
- # try:
67
- # for i in range(5):
68
- # seq += query_tokens[i] + ' '
69
- # except:
70
- # None
71
-
72
- seq = query_tokens
73
-
74
- tokens = tokenizer.tokenize(seq)
75
- ids = tokenizer.convert_tokens_to_ids(tokens)
76
-
77
- length = len(ids)
78
- if length > 64:
79
- length = 64
80
-
81
- for i in range(length):
82
- input_ids[0, i] = ids[i]
83
- attention_mask[0, i] = 1
84
-
85
- input_ids = torch.from_numpy(input_ids).type(torch.long)
86
- attention_mask = torch.from_numpy(attention_mask).type(torch.long)
87
-
88
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
89
- logits = outputs.logits
90
-
91
- # # 단독 예측 시
92
- # arg_idx = torch.argmax(logits, dim=1)
93
- # print('arg_idx:', arg_idx)
94
-
95
- # num_ans = label_tbl[arg_idx]
96
- # str_ans = loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == num_ans].values
97
-
98
- # 상위 k번째까지 예측 시
99
- k = 10
100
- topk_idx = torch.topk(logits.flatten(), k).indices
101
-
102
- num_ans_topk = label_tbl[topk_idx]
103
- str_ans_topk = [loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == k] for k in num_ans_topk]
104
-
105
- # print(num_ans, str_ans)
106
- # print(num_ans_topk)
107
-
108
- # print('์‚ฌ์—…์ฒด๋ช…:', query_tokens[0])
109
- # print('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ:', query_tokens[1])
110
- # print('๊ทผ๋ฌด๋ถ€์„œ:', query_tokens[2])
111
- # print('์ง์ฑ…:', query_tokens[3])
112
- # print('๋‚ด๊ฐ€ ํ•˜๋Š”์ผ:', query_tokens[4])
113
- # print('์‚ฐ์—…์ฝ”๋“œ ๋ฐ ๋ถ„๋ฅ˜:', num_ans, str_ans)
114
-
115
- # ans = ''
116
- # ans1, ans2, ans3 = '', '', ''
117
-
118
- ## 모델 결과값 출력
119
- # st.write("์‚ฐ์—…์ฝ”๋“œ ๋ฐ ๋ถ„๋ฅ˜:", num_ans, str_ans[0])
120
- # st.write("์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ")
121
- # for i in range(k):
122
- # st.write(str(i+1) + '์ˆœ์œ„:', num_ans_topk[i], str_ans_topk[i].iloc[0])
123
-
124
- # print(num_ans)
125
- # print(str_ans, type(str_ans))
126
-
127
- str_ans_topk_list = []
128
- for i in range(k):
129
- str_ans_topk_list.append(str_ans_topk[i].iloc[0])
130
-
131
- # print(str_ans_topk_list)
132
-
133
- ans_topk_df = pd.DataFrame({
134
- 'NO': range(1, k+1),
135
- '์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ': num_ans_topk,
136
- '์„ธ๋ถ„๋ฅ˜ ๋ช…์นญ': str_ans_topk_list
137
- })
138
- ans_topk_df = ans_topk_df.set_index('NO')
139
-
140
- st.dataframe(ans_topk_df)