Usage example added.
README.md CHANGED
@@ -13,23 +13,72 @@ For a detailed description and experimental results, please refer to the origina
This repository contains a small ELECTRA discriminator finetuned on a corpus of interactive fiction commands labelled with the WordNet synset offset of the verb in each sentence. The original dataset was collected from the lists of actions in the walkthroughs for the games included in the [Jericho](https://github.com/microsoft/jericho) framework and manually annotated. For more information visit https://github.com/aporporato/electra and https://github.com/aporporato/jericho-corpora.
## How to use the discriminator in `transformers`

Removed:

```python
import torch

predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)

[print("%7s" % int(prediction), end="") for prediction in predictions.squeeze().tolist()]
```

Added:

(Heavily based on: https://github.com/huggingface/notebooks/blob/master/examples/text_classification-tf.ipynb)

```python
import math

import tensorflow as tf
from datasets import Dataset, ClassLabel, Features, Value
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, create_optimizer

# This example shows how this model can be used;
# you should fine-tune it on your own corpus of commands, bigger than this one.
dict_train = {
    "idx": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"],
    "sentence": ["e", "get pen", "drop book", "x paper", "i", "south", "get paper", "drop pen", "x book", "inventory",
                 "n", "get book", "drop paper", "examine Pen", "inv", "w"],
    "label": ["v01835496", "v01214265", "v01977701", "v02131279", "v02472495", "v01835496", "v01214265", "v01977701",
              "v02131279", "v02472495", "v01835496", "v01214265", "v01977701", "v02131279", "v02472495", "v01835496"]
}

num_labels = len(set(dict_train["label"]))
features = Features({'idx': Value('uint32'), 'sentence': Value('string'),
                     'label': ClassLabel(names=list(set(dict_train["label"])))})

raw_train_dataset = Dataset.from_dict(dict_train, features=features)

discriminator = TFAutoModelForSequenceClassification.from_pretrained("Aureliano/electra-if", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("Aureliano/electra-if")

tokenize_function = lambda example: tokenizer(example["sentence"], truncation=True)

pre_tokenizer_columns = set(raw_train_dataset.features)
train_dataset = raw_train_dataset.map(tokenize_function, batched=True)
tokenizer_columns = list(set(train_dataset.features) - pre_tokenizer_columns)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")

batch_size = 16
tf_train_dataset = train_dataset.to_tf_dataset(
    columns=tokenizer_columns,
    label_cols=["labels"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator
)

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
num_epochs = 100
batches_per_epoch = math.ceil(len(train_dataset) / batch_size)
total_train_steps = int(batches_per_epoch * num_epochs)

optimizer, schedule = create_optimizer(
    init_lr=1e-5, num_warmup_steps=1, num_train_steps=total_train_steps
)

discriminator.compile(optimizer=optimizer, loss=loss)
discriminator.fit(
    tf_train_dataset,
    epochs=num_epochs
)

text = "get lamp"
encoded_input = tokenizer(text, return_tensors='tf')
output = discriminator(encoded_input)
prediction = tf.nn.softmax(output["logits"][0], -1)
# map the predicted class id back to its synset-offset label via the ClassLabel feature
label = features["label"].int2str(int(tf.math.argmax(prediction)))
print(text, ":", label)
# ideally v01214265 -> take.v.04 -> "get into one's hands, take physically", but probably only with a better dataset
```
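
The labels used above are raw WordNet synset offsets (e.g. `v01214265`). As a small companion sketch, not part of the model card, a predicted label can be mapped back to a human-readable synset with NLTK, assuming the `<pos><8-digit offset>` label format shown in `dict_train` and a local copy of the WordNet data:

```python
# Sketch: resolve a synset-offset label such as "v01214265" to a readable WordNet synset.
# Assumes `pip install nltk` and the "<pos><8-digit offset>" label format used above.
import nltk
from nltk.corpus import wordnet as wn

nltk.download("wordnet")  # one-time download of the WordNet data


def offset_to_synset(label):
    """Map a label like 'v01214265' to the corresponding WordNet synset."""
    pos, offset = label[0], int(label[1:])
    return wn.synset_from_pos_and_offset(pos, offset)


synset = offset_to_synset("v01214265")
print(synset.name(), "-", synset.definition())
# per the comment in the example above: take.v.04 - get into one's hands, take physically
```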
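
As a follow-up usage sketch, again an illustration rather than part of the model card, the fine-tuned classifier can also be applied to a batch of new commands, reusing the `discriminator`, `tokenizer` and `features` objects defined in the example above:

```python
# Sketch: batch inference with the fine-tuned discriminator from the example above.
import tensorflow as tf

commands = ["open the door", "take the sword", "look at the painting"]
batch = tokenizer(commands, padding=True, truncation=True, return_tensors="tf")
logits = discriminator(batch).logits  # shape: (len(commands), num_labels)
class_ids = tf.math.argmax(logits, axis=-1).numpy()
for command, class_id in zip(commands, class_ids):
    # ClassLabel.int2str maps an integer class id back to its synset-offset label
    print(command, "->", features["label"].int2str(int(class_id)))
```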