isThisYouLLM commited on
Commit
923779f
·
verified ·
1 Parent(s): 484c277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -12
app.py CHANGED
@@ -108,13 +108,13 @@ def main():
108
  parser = argparse.ArgumentParser()
109
  parser.add_argument('--model_name', type=str, default="Salesforce/codet5p-770m")
110
  parser.add_argument('--path_checkpoint1', type=str, default="checkpoint.bin")
111
-
112
  args = parser.parse_args()
113
 
114
 
115
  model_name = args.model_name
116
  checkpoint1 = args.path_checkpoint1
117
-
118
 
119
  DEVICE = "cpu"
120
 
@@ -138,16 +138,37 @@ def main():
138
 
139
 
140
 
141
- model.load_state_dict(torch.load(checkpoint1,map_location='cpu'))
142
- model = model.eval()
143
- st.title("Human-AI stylometer - Multilingual_multiprovenance")
144
- text = st.text_area("insert your code here")
145
- button = st.button("send")
146
- if button or text:
147
- input = tokenizer([text])
148
- out= model(torch.tensor(input.input_ids),torch.tensor(input.attention_mask))
149
- #st.json(out)
150
- st.write(out["my_class"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
 
 
108
  parser = argparse.ArgumentParser()
109
  parser.add_argument('--model_name', type=str, default="Salesforce/codet5p-770m")
110
  parser.add_argument('--path_checkpoint1', type=str, default="checkpoint.bin")
111
+ parser.add_argument('--path_checkpoint2', type=str, default="multilingual_standard.bin")
112
  args = parser.parse_args()
113
 
114
 
115
  model_name = args.model_name
116
  checkpoint1 = args.path_checkpoint1
117
+ checkpoint2 = args.path_checkpoint2
118
 
119
  DEVICE = "cpu"
120
 
 
138
 
139
 
140
 
141
+
142
+
143
+ selected = option_menu(
144
+ menu_title="Choose your model",
145
+ options=["Multilingual_multiprovenance","Multilingual_standard" ],
146
+ default_index=0,
147
+ orientation="horizontal",
148
+ )
149
+
150
+ if selected=="Multilingual_standard":
151
+ model.load_state_dict(torch.load(checkpoint2,map_location='cpu'))
152
+ model = model.eval()
153
+ st.title("Human-AI stylometer - Multilingual_standard")
154
+ text = st.text_area("insert your code here")
155
+ button = st.button("send")
156
+ if button or text:
157
+ input = tokenizer([text])
158
+ out= model(torch.tensor(input.input_ids),torch.tensor(input.attention_mask))
159
+ #st.json(out)
160
+ st.write(out["my_class"])
161
+ else:
162
+ model.load_state_dict(torch.load(checkpoint1,map_location='cpu'))
163
+ model = model.eval()
164
+ st.title("Human-AI stylometer - Multilingual_multiprovenance")
165
+ text = st.text_area("insert your code here")
166
+ button = st.button("send")
167
+ if button or text:
168
+ input = tokenizer([text])
169
+ out= model(torch.tensor(input.input_ids),torch.tensor(input.attention_mask))
170
+ #st.json(out)
171
+ st.write(out["my_class"])
172
 
173
 
174