File size: 3,795 Bytes
d758c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn

from relogic.pretrainkit.models.relationalsemparse.relational_semparse import RelationalBartForTextToSQL
from relogic.logickit.dataflow.semtransparse.grammar.keywords import SKETCH_KEYWORDS, KEYWORDS
import logging
import os

logger = logging.getLogger(__name__)
WEIGHTS_NAME = "pytorch_model.bin"

class RelationalBARTParser(nn.Module):
  """
  output: tuple: (loss, ) in training
  """
  def __init__(self):
    super().__init__()
    self.bert: RelationalBartForTextToSQL = RelationalBartForTextToSQL.from_pretrained("facebook/bart-large")
    self.bert.model.encoder.use_relation_transformer = True


  def forward(self, *input, **kwargs):

    input_token_ids = kwargs.pop("input_ids")
    # column_spans = kwargs.pop("column_spans")
    input_padding_id = kwargs.pop("input_padding_id")
    attention_mask = (input_token_ids != input_padding_id).long()
    example_info_list= kwargs.pop("example_info_list")
    # relation_ids = None
    if self.training:
      label_ids = kwargs.pop("labels")
      label_padding_id = kwargs.pop("label_padding_id")
      # encoded = self.bert.encoder(input_token_ids)[0].contiguous()
      y_ids = label_ids[:, :-1].contiguous()
      lm_labels = label_ids[:, 1:].clone()
      lm_labels[label_ids[:, 1:] == label_padding_id] = -100
      outputs = self.bert(input_token_ids, example_info_list=example_info_list,
                          attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, )
      return (outputs[0],)

    else:
      label_eos_id = kwargs.pop("label_eos_id")
      label_bos_id = kwargs.pop("label_bos_id")
      label_padding_id = kwargs.pop("label_padding_id")
      generated_ids = self.bert.generate(
        input_ids=input_token_ids,
        example_info_list=example_info_list,
        attention_mask=attention_mask,
        num_beams=1,
        max_length=30,
        length_penalty=2.0,
        early_stopping=True,
        use_cache=True,
        decoder_start_token_id=label_bos_id,
        eos_token_id=label_eos_id,
        pad_token_id=label_padding_id,
        vocab_size=len(KEYWORDS)
      )
      # raise NotImplementedError()
      # label_ids = kwargs.pop("label_ids")
      # label_padding_id = kwargs.pop("label_padding_id")
      # # encoded = self.bert.encoder(input_token_ids)[0].contiguous()
      # y_ids = label_ids[:, :-1].contiguous()
      # lm_labels = label_ids[:, 1:].clone()
      # lm_labels[label_ids[:, 1:] == label_padding_id] = -100
      # outputs = self.bert(input_token_ids, example_info_list=example_info_list,
      #                     attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, )
      # generated_ids = outputs[-1]
      # raise NotImplementedError()
      return (torch.zeros(1), generated_ids)

  def save_pretrained(self, save_directory):
    """ Save a model and its configuration file to a directory, so that it
        can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.

        Arguments:
            save_directory: directory to which to save.
    """
    assert os.path.isdir(
      save_directory
    ), "Saving path should be a directory where the model and configuration can be saved"

    # Only save the model itself if we are using distributed training
    model_to_save = self.module if hasattr(self, "module") else self

    # Attach architecture to the config
    # model_to_save.config.architectures = [model_to_save.__class__.__name__]

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

    torch.save(model_to_save.state_dict(), output_model_file)

    logger.info("Model weights saved in {}".format(output_model_file))