Spaces:
Paused
Paused
Andrei-Iulian SĂCELEANU
commited on
Commit
•
e0a167e
1
Parent(s):
9b0f71a
mixmatch fix2
Browse files
app.py
CHANGED
@@ -34,17 +34,22 @@ def ssl_predict(in_text, model_type):
|
|
34 |
return_tensors="tf"
|
35 |
)
|
36 |
|
|
|
37 |
if model_type == "fixmatch":
|
38 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
39 |
model.load_weights("./checkpoints/fixmatch_tune")
|
|
|
|
|
40 |
elif model_type == "freematch":
|
41 |
model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
|
42 |
model.cls_head.load_weights("./checkpoints/freematch_tune")
|
|
|
|
|
43 |
elif model_type == "mixmatch":
|
44 |
model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
|
45 |
model.cls_head.load_weights("./checkpoints/mixmatch")
|
|
|
46 |
|
47 |
-
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
48 |
probs = list(preds[0].numpy())
|
49 |
|
50 |
d = {}
|
|
|
34 |
return_tensors="tf"
|
35 |
)
|
36 |
|
37 |
+
preds = None
|
38 |
if model_type == "fixmatch":
|
39 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
40 |
model.load_weights("./checkpoints/fixmatch_tune")
|
41 |
+
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
42 |
+
|
43 |
elif model_type == "freematch":
|
44 |
model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
|
45 |
model.cls_head.load_weights("./checkpoints/freematch_tune")
|
46 |
+
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
47 |
+
|
48 |
elif model_type == "mixmatch":
|
49 |
model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
|
50 |
model.cls_head.load_weights("./checkpoints/mixmatch")
|
51 |
+
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
52 |
|
|
|
53 |
probs = list(preds[0].numpy())
|
54 |
|
55 |
d = {}
|