Aureliano committed
Commit 32192ff · 1 parent: 506ab66

Usage example added.

Files changed (1): README.md (+61 -12)
README.md CHANGED
@@ -13,23 +13,72 @@ For a detailed description and experimental results, please refer to the origina
  This repository contains a small ELECTRA discriminator fine-tuned on a corpus of interactive fiction commands labelled with the WordNet synset offset of the verb in the sentence. The original dataset was collected from the lists of actions in the walkthroughs of the games included in the [Jericho](https://github.com/microsoft/jericho) framework and manually annotated. For more information visit https://github.com/aporporato/electra and https://github.com/aporporato/jericho-corpora.

  ## How to use the discriminator in `transformers`
+ (Heavily based on: https://github.com/huggingface/notebooks/blob/master/examples/text_classification-tf.ipynb)

  ```python
- from transformers import ElectraForPreTraining, ElectraTokenizerFast
- import torch
-
- discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
- tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
-
- sentence = "The quick brown fox jumps over the lazy dog"
- fake_sentence = "The quick brown fox fake over the lazy dog"
-
- fake_tokens = tokenizer.tokenize(fake_sentence)
- fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
- discriminator_outputs = discriminator(fake_inputs)
- predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
-
- [print("%7s" % token, end="") for token in fake_tokens]
- [print("%7s" % int(prediction), end="") for prediction in predictions.squeeze().tolist()]
+ import math
+
+ import tensorflow as tf
+ from datasets import Dataset, ClassLabel, Features, Value
+ from transformers import TFAutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, create_optimizer
+
+ # This example shows how this model can be used:
+ # you should fine-tune the model on your specific corpus of commands, bigger than this one
+ dict_train = {
+     "idx": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"],
+     "sentence": ["e", "get pen", "drop book", "x paper", "i", "south", "get paper", "drop pen", "x book", "inventory",
+                  "n", "get book", "drop paper", "examine Pen", "inv", "w"],
+     "label": ["v01835496", "v01214265", "v01977701", "v02131279", "v02472495", "v01835496", "v01214265", "v01977701",
+               "v02131279", "v02472495", "v01835496", "v01214265", "v01977701", "v02131279", "v02472495", "v01835496"]
+ }
+
+ num_labels = len(set(dict_train["label"]))
+ features = Features({'idx': Value('uint32'), 'sentence': Value('string'),
+                      'label': ClassLabel(names=list(set(dict_train["label"])))})
+
+ raw_train_dataset = Dataset.from_dict(dict_train, features=features)
+
+ discriminator = TFAutoModelForSequenceClassification.from_pretrained("Aureliano/electra-if", num_labels=num_labels)
+ tokenizer = AutoTokenizer.from_pretrained("Aureliano/electra-if")
+
+ tokenize_function = lambda example: tokenizer(example["sentence"], truncation=True)
+
+ pre_tokenizer_columns = set(raw_train_dataset.features)
+ train_dataset = raw_train_dataset.map(tokenize_function, batched=True)
+ tokenizer_columns = list(set(train_dataset.features) - pre_tokenizer_columns)
+
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
+
+ batch_size = 16
+ tf_train_dataset = train_dataset.to_tf_dataset(
+     columns=tokenizer_columns,
+     label_cols=["labels"],
+     shuffle=True,
+     batch_size=batch_size,
+     collate_fn=data_collator
+ )
+
+ loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ num_epochs = 100
+ batches_per_epoch = math.ceil(len(train_dataset) / batch_size)
+ total_train_steps = int(batches_per_epoch * num_epochs)
+
+ optimizer, schedule = create_optimizer(
+     init_lr=1e-5, num_warmup_steps=1, num_train_steps=total_train_steps
+ )
+
+ discriminator.compile(optimizer=optimizer, loss=loss)
+ discriminator.fit(
+     tf_train_dataset,
+     epochs=num_epochs
+ )
+
+ text = "get lamp"
+ encoded_input = tokenizer(text, return_tensors='tf')
+ output = discriminator(encoded_input)
+ prediction = tf.nn.softmax(output["logits"][0], -1)
+ # map the predicted class id back to its label string through the ClassLabel feature
+ label = features["label"].int2str(int(tf.math.argmax(prediction)))
+ print(text, ":", label)
+ # ideally [v01214265 -> take.v.04 -> "get into one's hands, take physically"], but probably only with a better dataset
+
  ```
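
For reference, here is a minimal sketch (not part of the committed README; it assumes NLTK with the WordNet corpus downloaded) of how a predicted label such as `v01214265` could be mapped back to a human-readable WordNet synset:

```python
# Hypothetical helper, assuming NLTK (>= 3.5) with the WordNet data available
# via nltk.download("wordnet").
from nltk.corpus import wordnet as wn

label = "v01214265"                     # model label: POS letter + 8-digit synset offset
pos, offset = label[0], int(label[1:])  # -> "v", 1214265
synset = wn.synset_from_pos_and_offset(pos, offset)
print(synset.name(), "->", synset.definition())
# expected: take.v.04 -> "get into one's hands, take physically"
```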