Andrei-Iulian SĂCELEANU commited on
Commit
e0a167e
1 Parent(s): 9b0f71a

mixmatch fix2

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -34,17 +34,22 @@ def ssl_predict(in_text, model_type):
34
  return_tensors="tf"
35
  )
36
 
 
37
  if model_type == "fixmatch":
38
  model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
39
  model.load_weights("./checkpoints/fixmatch_tune")
 
 
40
  elif model_type == "freematch":
41
  model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
42
  model.cls_head.load_weights("./checkpoints/freematch_tune")
 
 
43
  elif model_type == "mixmatch":
44
  model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
45
  model.cls_head.load_weights("./checkpoints/mixmatch")
 
46
 
47
- preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
48
  probs = list(preds[0].numpy())
49
 
50
  d = {}
 
34
  return_tensors="tf"
35
  )
36
 
37
+ preds = None
38
  if model_type == "fixmatch":
39
  model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
40
  model.load_weights("./checkpoints/fixmatch_tune")
41
+ preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
42
+
43
  elif model_type == "freematch":
44
  model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
45
  model.cls_head.load_weights("./checkpoints/freematch_tune")
46
+ preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
47
+
48
  elif model_type == "mixmatch":
49
  model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
50
  model.cls_head.load_weights("./checkpoints/mixmatch")
51
+ preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
52
 
 
53
  probs = list(preds[0].numpy())
54
 
55
  d = {}