Spaces:
Paused
Paused
Andrei-Iulian SĂCELEANU
commited on
Commit
•
db8dccd
1
Parent(s):
e0a167e
added all methods for text
Browse files- .gitattributes +4 -0
- app.py +10 -0
- checkpoints/contrastive.data-00000-of-00001 +3 -0
- checkpoints/contrastive.index +3 -0
- checkpoints/label_prop.data-00000-of-00001 +3 -0
- checkpoints/label_prop.index +3 -0
- models.py +22 -1
.gitattributes
CHANGED
@@ -39,3 +39,7 @@ checkpoints/mixmatch.index filter=lfs diff=lfs merge=lfs -text
|
|
39 |
checkpoints/fixmatch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
40 |
checkpoints/fixmatch_tune.index filter=lfs diff=lfs merge=lfs -text
|
41 |
checkpoints/freematch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
39 |
checkpoints/fixmatch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
40 |
checkpoints/fixmatch_tune.index filter=lfs diff=lfs merge=lfs -text
|
41 |
checkpoints/freematch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
42 |
+
checkpoints/label_prop.index filter=lfs diff=lfs merge=lfs -text
|
43 |
+
checkpoints/contrastive.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
44 |
+
checkpoints/contrastive.index filter=lfs diff=lfs merge=lfs -text
|
45 |
+
checkpoints/label_prop.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -50,6 +50,16 @@ def ssl_predict(in_text, model_type):
|
|
50 |
model.cls_head.load_weights("./checkpoints/mixmatch")
|
51 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
probs = list(preds[0].numpy())
|
54 |
|
55 |
d = {}
|
|
|
50 |
model.cls_head.load_weights("./checkpoints/mixmatch")
|
51 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
52 |
|
53 |
+
elif model_type == "contrastive_reg":
|
54 |
+
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
55 |
+
model.cls_head.load_weights("./checkpoints/contrastive")
|
56 |
+
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
57 |
+
|
58 |
+
elif model_type == "label_propagation":
|
59 |
+
model = LPModel()
|
60 |
+
model.cls_head.load_weights("./checkpoints/label_prop")
|
61 |
+
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
62 |
+
|
63 |
probs = list(preds[0].numpy())
|
64 |
|
65 |
d = {}
|
checkpoints/contrastive.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1cbd1032502d2c8ec25b9c2ec989223d36027381c66ddf32746f3772af03ec8
|
3 |
+
size 461147140
|
checkpoints/contrastive.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:638d167c9d26ba37d1ba2d76ae7863e0edbc3c7b98213f53b9b310613c6e8598
|
3 |
+
size 14764
|
checkpoints/label_prop.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33c1a8e6cada47450ebf64d466cbd5de26ba15c9d2bfec0a0ed607eaad0d92c1
|
3 |
+
size 461147088
|
checkpoints/label_prop.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7847020dcdc7c07eee2ca947869fa9763f07e9aec75bf07574c0c9dc167ca9e
|
3 |
+
size 14764
|
models.py
CHANGED
@@ -61,4 +61,25 @@ class MixMatch(tf.keras.Model):
|
|
61 |
embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
|
62 |
augs = self.augment(embeds,training=training)
|
63 |
|
64 |
-
return self.cls_head(augs,training=training)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
|
62 |
augs = self.augment(embeds,training=training)
|
63 |
|
64 |
+
return self.cls_head(augs,training=training)
|
65 |
+
|
66 |
+
class LPModel(tf.keras.Model):
|
67 |
+
"""label propagation"""
|
68 |
+
def __init__(self,bert_model="readerbench/RoBERT-base",num_classes=4,**kwargs):
|
69 |
+
super(LPModel,self).__init__(**kwargs)
|
70 |
+
self.bert = TFAutoModel.from_pretrained(bert_model)
|
71 |
+
self.num_classes = num_classes
|
72 |
+
|
73 |
+
self.cls_head = tf.keras.Sequential([
|
74 |
+
tf.keras.layers.Dense(256,activation="relu"),
|
75 |
+
tf.keras.layers.Dropout(0.2),
|
76 |
+
tf.keras.layers.Dense(64,activation="relu"),
|
77 |
+
tf.keras.layers.Dense(self.num_classes, activation="softmax")
|
78 |
+
])
|
79 |
+
|
80 |
+
def call(self, inputs, training):
|
81 |
+
ids, mask = inputs
|
82 |
+
|
83 |
+
embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
|
84 |
+
|
85 |
+
return self.cls_head(embeds, training=training)
|