
# Python T5 base model

Pre-trained model on the CodeSearchNet Python dataset using a span-masking objective. The training objective and model were introduced in this paper and first released in this repository. PyT5 was pre-trained with the git-t5 framework, built on top of JAX/Flax, on a TPU v3-8 node.
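
With span masking (the T5-style denoising objective), contiguous spans of the input are replaced with sentinel tokens and the model learns to reconstruct only the masked spans. The snippet below is a minimal illustration of that input/target format; the actual spans chosen during pre-training are random, so treat it as a sketch rather than the exact corruption used for this model:

```python
# Illustration only: the general shape of T5-style span corruption.
original = "def add(a, b):\n    return a + b"

# The corrupted input replaces randomly chosen spans with sentinel tokens ...
corrupted_input = "def add(a, <extra_id_0>\n    return a <extra_id_1> b"

# ... and the target lists only the masked spans, in order, ending with EOS.
target = "<extra_id_0> b):<extra_id_1> +</s>"
```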

## How to use

You can use this model to denoise span-masked sequences. Note that you'll need some boilerplate code to add the noise to your sequences.

First, install the git-t5 pip package:

```bash
pip install git-t5
```

Add the following code to encode an input text:

```python
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from transformers import PreTrainedTokenizerBase

from git_t5.data import DataCollatorForT5MLM


def encode(
    tokenizer: PreTrainedTokenizerBase,
    text: str,
    noise_density: float = 0.15,
    mean_noise_span_length: float = 3.0,
    extra_tokens_per_span_inputs: int = 1,
    extra_tokens_per_span_targets: int = 1,
    seed: Optional[int] = None,
) -> Tuple[Dict[str, torch.Tensor], int]:
    def compute_lengths(tokens_length: int) -> Tuple[int, int]:
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    encoding = tokenizer(
        text,
        truncation=False,
        return_attention_mask=False,
        return_length=True,
    )

    input_length = encoding.pop("length")
    input_length = input_length[0]
    input_length, target_length = compute_lengths(input_length)

    np.random.seed(seed)

    data_collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=noise_density,
        mean_noise_span_length=mean_noise_span_length,
        input_length=input_length,
        target_length=target_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        sentinel_token_id=tokenizer.convert_tokens_to_ids("<extra_id_0>"),
    )

    batch = data_collator([encoding])  # type: ignore
    batch = {key: torch.tensor(val) for key, val in batch.items()}

    return batch, target_length
```

Next, download the model and tokenizer:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("formermagic/pyt5-base")
tokenizer = AutoTokenizer.from_pretrained("formermagic/pyt5-base")
```
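
Optionally, you can put the model into evaluation mode and move it to a GPU before generating. This step is not part of the original card, just a common-practice suggestion:

```python
import torch

# Optional: run inference on a GPU if one is available. If you do this,
# remember to move the batch returned by encode() to the same device
# before calling generate().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
```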

Finally, encode your input and generate the output sequence:

```python
text = """
def alias(self, annotationtype, set, fallback=False):
    if inspect.isclass(annotationtype): annotationtype = annotationtype.ANNOTATIONTYPE
    if annotationtype in self.set_alias and set in self.set_alias[annotationtype]:
        return self.set_alias[annotationtype][set]
    elif fallback:
        return set
    else:
        raise KeyError("No alias for set " + set)
"""

batch, max_length = encode(tokenizer, text, seed=22)
outputs = model.generate(batch["input_ids"], max_length=max_length, num_beams=1)
print(tokenizer.batch_decode(outputs[..., 1:]))
print(tokenizer.batch_decode(batch["labels"]))
```

You should see the following output:

```
['<extra_id_0>, fallback=<extra_id_1> inspect<extra_id_2>.set_alias<extra_id_3> return self.set<extra_id_4>) def fallback']
['<extra_id_0>, fallback=<extra_id_1> inspect<extra_id_2>.set_alias<extra_id_3> return self.set<extra_id_4>) </s></s>']
```

As you can see, the predicted result is very close to the target sequence.
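
If you want to turn the predicted spans back into a denoised source string, you can splice each predicted span into the corrupted input in place of its sentinel. The `splice` helper below is not part of git-t5; it is just an illustrative sketch:

```python
import re

def splice(corrupted: str, prediction: str) -> str:
    # The text following <extra_id_i> in the prediction is the model's guess
    # for the i-th masked span; substitute it back into the corrupted input.
    spans = re.split(r"<extra_id_\d+>", prediction)[1:]
    for i, span in enumerate(spans):
        corrupted = corrupted.replace(f"<extra_id_{i}>", span.replace("</s>", ""), 1)
    return corrupted

corrupted = tokenizer.batch_decode(batch["input_ids"])[0]
prediction = tokenizer.batch_decode(outputs[..., 1:])[0]
print(splice(corrupted, prediction))
```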