IBounhas commited on
Commit
f45632f
·
1 Parent(s): 8cba4ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -59,20 +59,20 @@ with open("model.pt", "wb") as f:
59
  f.write(response.content)
60
  f.close()
61
 
62
- # model_load_path = "model.pt"
63
- # model = BertForSTS()
64
- # model.load_state_dict(torch.load(model_load_path))
65
- # model.to(device)
66
-
67
- # def predict_similarity(sentence_pair):
68
- # test_input = tokenizer(sentence_pair, padding='max_length', max_length = param_max_length, truncation=True, return_tensors="pt").to(device)
69
- # test_input['input_ids'] = test_input['input_ids']
70
- # test_input['attention_mask'] = test_input['attention_mask']
71
- # del test_input['token_type_ids']
72
- # output = model(test_input)
73
- # sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()*2-1
74
-
75
- # return sim
76
 
77
  # Create a Gradio interface with a text input zone
78
  iface = gr.Interface(
 
59
  f.write(response.content)
60
  f.close()
61
 
62
+ model_load_path = "model.pt"
63
+ model = BertForSTS()
64
+ model.load_state_dict(torch.load(model_load_path))
65
+ model.to(device)
66
+
67
+ def predict_similarity(sentence_pair):
68
+ test_input = tokenizer(sentence_pair, padding='max_length', max_length = param_max_length, truncation=True, return_tensors="pt").to(device)
69
+ test_input['input_ids'] = test_input['input_ids']
70
+ test_input['attention_mask'] = test_input['attention_mask']
71
+ del test_input['token_type_ids']
72
+ output = model(test_input)
73
+ sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()*2-1
74
+
75
+ return sim
76
 
77
  # Create a Gradio interface with a text input zone
78
  iface = gr.Interface(