iSpr commited on
Commit
bd9ed9e
·
1 Parent(s): cc5a679

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -36,19 +36,19 @@ def md_loading():
36
  # model.load_state_dict(ckpt['model_state_dict'])
37
 
38
  # device = torch.device("cuda" if torch.cuda.is_available() and not False else "cpu")
39
- device = torch.device("cpu")
40
 
41
- model.to(device)
42
 
43
  label_tbl = np.load('./label_table.npy')
44
  loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
45
 
46
  print('ready')
47
 
48
- return tokenizer, model, label_tbl, loc_tbl, device
49
 
50
  # 모델 로드
51
- tokenizer, model, label_tbl, loc_tbl, device = md_loading()
52
 
53
 
54
  # 데이터 셋 준비용
@@ -162,7 +162,7 @@ if st.button('확인'):
162
  # Predict
163
  for batch in range(test_dataloader):
164
  # Add batch to GPU
165
- batch = tuple(t.to(device) for t in batch)
166
 
167
  # Unpack the inputs from our dataloader
168
  test_input_ids, test_attention_mask = batch
 
36
  # model.load_state_dict(ckpt['model_state_dict'])
37
 
38
  # device = torch.device("cuda" if torch.cuda.is_available() and not False else "cpu")
39
+ # device = torch.device("cpu")
40
 
41
+ # model.to(device)
42
 
43
  label_tbl = np.load('./label_table.npy')
44
  loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
45
 
46
  print('ready')
47
 
48
+ return tokenizer, model, label_tbl, loc_tbl
49
 
50
  # 모델 로드
51
+ tokenizer, model, label_tbl, loc_tbl = md_loading()
52
 
53
 
54
  # 데이터 셋 준비용
 
162
  # Predict
163
  for batch in range(test_dataloader):
164
  # Add batch to GPU
165
+ # batch = tuple(t.to(device) for t in batch)
166
 
167
  # Unpack the inputs from our dataloader
168
  test_input_ids, test_attention_mask = batch