alex6095 commited on
Commit
3f5668b
·
1 Parent(s): 9ce43b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -6,23 +6,10 @@ import streamlit as st
6
  from transformers import DistilBertModel
7
  from tokenization_kobert import KoBertTokenizer
8
 
9
- @st.cache(allow_output_mutation=True)
10
- def get_model():
11
- bert_model = DistilBertModel.from_pretrained('monologg/distilkobert')
12
- tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
13
-
14
- model = SanctiMoly(freeze_bert=False)
15
- checkpoint = torch.load("./model.pt", map_location=device)
16
- model.load_state_dict(checkpoint['model_state_dict'])
17
-
18
- return model, tokenizer
19
-
20
- model, tokenizer = get_model()
21
-
22
  class SanctiMoly(nn.Module):
23
  """ Holy Moly News BERT """
24
 
25
- def __init__(self, freeze_bert = True):
26
  super(SanctiMoly, self).__init__()
27
  self.encoder = bert_model
28
  # FC-BN-Tanh
@@ -57,6 +44,21 @@ class SanctiMoly(nn.Module):
57
  # print(output.shape)
58
  return output
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  class RegexSubstitution(object):
62
  """Regex substitution class for transform"""
 
6
  from transformers import DistilBertModel
7
  from tokenization_kobert import KoBertTokenizer
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class SanctiMoly(nn.Module):
10
  """ Holy Moly News BERT """
11
 
12
+ def __init__(self, bert_model, freeze_bert = True):
13
  super(SanctiMoly, self).__init__()
14
  self.encoder = bert_model
15
  # FC-BN-Tanh
 
44
  # print(output.shape)
45
  return output
46
 
47
+ @st.cache(allow_output_mutation=True)
48
+ def get_model():
49
+ bert_model = DistilBertModel.from_pretrained('monologg/distilkobert')
50
+ tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
51
+
52
+ model = SanctiMoly(bert_model, freeze_bert=False)
53
+ checkpoint = torch.load("./model.pt", map_location=device)
54
+ model.load_state_dict(checkpoint['model_state_dict'])
55
+
56
+ return model, tokenizer
57
+
58
+ model, tokenizer = get_model()
59
+
60
+
61
+
62
 
63
  class RegexSubstitution(object):
64
  """Regex substitution class for transform"""