Andrei-Iulian SĂCELEANU committed on
Commit
db8dccd
1 Parent(s): e0a167e

added all methods for text

Browse files
.gitattributes CHANGED
@@ -39,3 +39,7 @@ checkpoints/mixmatch.index filter=lfs diff=lfs merge=lfs -text
39
  checkpoints/fixmatch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
40
  checkpoints/fixmatch_tune.index filter=lfs diff=lfs merge=lfs -text
41
  checkpoints/freematch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
39
  checkpoints/fixmatch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
40
  checkpoints/fixmatch_tune.index filter=lfs diff=lfs merge=lfs -text
41
  checkpoints/freematch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
42
+ checkpoints/label_prop.index filter=lfs diff=lfs merge=lfs -text
43
+ checkpoints/contrastive.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
44
+ checkpoints/contrastive.index filter=lfs diff=lfs merge=lfs -text
45
+ checkpoints/label_prop.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -50,6 +50,16 @@ def ssl_predict(in_text, model_type):
50
  model.cls_head.load_weights("./checkpoints/mixmatch")
51
  preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
52
 
 
 
 
 
 
 
 
 
 
 
53
  probs = list(preds[0].numpy())
54
 
55
  d = {}
 
50
  model.cls_head.load_weights("./checkpoints/mixmatch")
51
  preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
52
 
53
+ elif model_type == "contrastive_reg":
54
+ model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
55
+ model.cls_head.load_weights("./checkpoints/contrastive")
56
+ preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
57
+
58
+ elif model_type == "label_propagation":
59
+ model = LPModel()
60
+ model.cls_head.load_weights("./checkpoints/label_prop")
61
+ preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
62
+
63
  probs = list(preds[0].numpy())
64
 
65
  d = {}
checkpoints/contrastive.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1cbd1032502d2c8ec25b9c2ec989223d36027381c66ddf32746f3772af03ec8
3
+ size 461147140
checkpoints/contrastive.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:638d167c9d26ba37d1ba2d76ae7863e0edbc3c7b98213f53b9b310613c6e8598
3
+ size 14764
checkpoints/label_prop.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33c1a8e6cada47450ebf64d466cbd5de26ba15c9d2bfec0a0ed607eaad0d92c1
3
+ size 461147088
checkpoints/label_prop.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7847020dcdc7c07eee2ca947869fa9763f07e9aec75bf07574c0c9dc167ca9e
3
+ size 14764
models.py CHANGED
@@ -61,4 +61,25 @@ class MixMatch(tf.keras.Model):
61
  embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
62
  augs = self.augment(embeds,training=training)
63
 
64
- return self.cls_head(augs,training=training)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
62
  augs = self.augment(embeds,training=training)
63
 
64
+ return self.cls_head(augs,training=training)
65
+
66
class LPModel(tf.keras.Model):
    """Label-propagation classifier: a frozen-architecture BERT encoder
    followed by a small dense classification head.

    The encoder's pooled output is fed straight into the head; the head's
    weights are what gets loaded from the ``label_prop`` checkpoint.
    """

    def __init__(self, bert_model="readerbench/RoBERT-base", num_classes=4, **kwargs):
        """Build the encoder and the classification head.

        Args:
            bert_model: Hugging Face model id for the ``TFAutoModel`` encoder.
            num_classes: number of output classes for the softmax layer.
            **kwargs: forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(bert_model)
        self.num_classes = num_classes

        # Dense head mirroring the other SSL models in this file:
        # 256 -> dropout -> 64 -> softmax over num_classes.
        head_layers = [
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ]
        self.cls_head = tf.keras.Sequential(head_layers)

    def call(self, inputs, training):
        """Run a forward pass.

        Args:
            inputs: pair ``(input_ids, attention_mask)`` of token tensors.
            training: Keras training flag, propagated to encoder and head.

        Returns:
            Softmax class probabilities from the classification head.
        """
        token_ids, attention_mask = inputs

        pooled = self.bert(
            input_ids=token_ids, attention_mask=attention_mask, training=training
        ).pooler_output

        return self.cls_head(pooled, training=training)