import torch
import torch.nn as nn
from functools import partial, cache
from argparse import Namespace
from typing import List, Tuple, Dict, Union, Optional, Literal
from itertools import chain
import random

from transformers import T5Tokenizer


class Graph():
    """
    A graph class.

    :param g: A list of tuples, where each tuple is a triple (head, r, tail).
    """
    def __init__(self, g: List[Tuple[str, str, str]] = []):
        self.g = g
        self.concepts = self.get_concepts()  # list of all concepts in the graph
        self.relations = self.get_relations()  # list of all relations in the graph
        self.relations_multiple = self.get_relations_multiple()  # list of all relations in the graph, including duplicate relations

    @property
    def g(self) -> List[Tuple[str, str, str]]:
        return self._g

    @g.setter
    def g(self, g: List[Tuple[str, str, str]]):
        self._g = g

    def num_triplets(self) -> int:
        """ Get the number of triplets in the graph. """
        return len(self.g)

    def get_concepts(self) -> List[str]:
        """ Get the concepts in the graph. """
        concepts = list(set([triplet[i] for triplet in self.g for i in [0, 2]]))
        concepts.sort()  # not necessary but makes debugging easier
        return concepts

    def get_relations(self) -> List[str]:
        """ Get the relations in the graph. """
        relations = list(set(self.get_relations_multiple()))
        relations.sort()  # not necessary but makes debugging easier
        return relations

    def get_relations_multiple(self) -> List[str]:
        """ Get the relations in the graph, including duplicate relations. """
        relations = [triplet[1] for triplet in self.g]
        return relations

    def __str__(self):
        out_str = '\n'.join([str(triplet) for triplet in self.g])
        return out_str


class Data(Namespace):
    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)


def get_dummy_graph(num_triplets: int = 3) -> Graph:
    g = [
        ("dog", "IsA", "animal"),
        ("cat", "IsA", "animal"),
        ("black poodle", "IsA", "dog"),
        ("black cat", "IsA", "cat"),
    ]
    assert num_triplets <= 4, "num_triplets must be <= 4"
    g = g[:num_triplets]
    g = Graph(g)
    return g


def r2nl(r: str) -> str:
    """ Convert a relation to a natural language string. Can be used to implement necessary changes in the data. """
    return r


def _get_str2tok(g: Graph, tokenizer: T5Tokenizer) -> dict[str, list[int]]:
    """ Get a dictionary that maps strings to tokens. """
    # tokenize concepts and relations
    c_tok = tokenizer([r2nl(c) for c in g.concepts], padding=False)['input_ids']
    r_tok = tokenizer([r2nl(r) for r in g.relations], padding=False)['input_ids']
    tokens = c_tok + r_tok
    node_names = g.concepts + g.relations  # these are not necessarily all nodes in the Levi graph, as relations can occur more than once
    assert len(tokens) == len(node_names), f"{len(tokens) = }, {len(node_names) = }"

    # remove end-of-sequence token
    tokens = [toks[:-1] if toks[-1] == tokenizer.eos_token_id else toks for toks in tokens]

    # create a dictionary mapping concepts and relations to their tokenized forms
    str2tok = {node: tok for node, tok in zip(node_names, tokens)}
    str2tok[''] = [tokenizer.eos_token_id]
    return str2tok
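
# Example (sketch, not exercised by the module itself): building a small graph and inspecting its
# Levi-graph tokenization. Assumes a T5 tokenizer checkpoint such as 't5-small' is available.
#
#   tokenizer = T5Tokenizer.from_pretrained('t5-small')
#   g = get_dummy_graph(num_triplets=3)
#   print(g.concepts)            # ['animal', 'black poodle', 'cat', 'dog']
#   print(g.relations)           # ['IsA']
#   print(g.relations_multiple)  # ['IsA', 'IsA', 'IsA']
#   str2tok = _get_str2tok(g, tokenizer)  # e.g. {'animal': [...], ..., 'IsA': [...], '': [eos_id]}
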
def _get_graphT5_input_sequence(g: Graph, str2tok: dict, use_eos: bool) -> Tuple[torch.Tensor, dict, torch.Tensor, dict]:
    # get input sequence (i.e. the sequence that will be fed into the model for this graph)
    all_nodes = g.relations_multiple + g.concepts  # list of all concepts and relations that will be in the final sequence (i.e. all nodes of the Levi graph). The order is first all relations (in the order in which they appear in g.g), then all concepts (in alphabetical order, though here the order is not important)
    if use_eos:
        all_nodes.append('')

    all_tokens = [str2tok[node] for node in all_nodes]  # list of length #nodes, where each element is a list of token ids
    indices = {node: [] for node in all_nodes}  # dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have length 1 for concepts and, for relations, are as long as the number of occurrences of the relation in the graph.
    # WARNING: this assumes that concepts and relations have different names. This is not always the case for REBEL. For concept_indices this is fixed.

    num_relation_tokens = sum([len(token) for token in all_tokens[:len(g.relations_multiple)]])  # number of tokens that belong to relations
    num_concept_tokens = sum([len(token) for token in all_tokens[len(g.relations_multiple):len(g.relations_multiple) + len(g.concepts)]])  # number of tokens that belong to concepts
    num_eos_tokens = 1 if use_eos else 0

    is_concept = torch.tensor([False] * num_relation_tokens + [True] * num_concept_tokens + [False] * num_eos_tokens, dtype=torch.bool)  # tensor of length sequence_length, where each element is True if the corresponding token belongs to a concept and False otherwise
    index_counter = 0
    assert len(all_nodes) == len(all_tokens), (all_nodes, all_tokens)
    for node, token in zip(all_nodes, all_tokens):
        indices[node].append((index_counter, index_counter + len(token)))
        # assert is_concept[index_counter:index_counter+len(token)].all() == (node in g.concepts), f"{is_concept = }, {node = }, {g.concepts = }, {index_counter = }, {len(token) = }, {is_concept[index_counter:index_counter+len(token)] = }"
        index_counter += len(token)

    concept_indices = {node: [indices[node][-1]] for node in g.concepts}  # take [-1] (re-wrapped in a list) in case a relation has the same name as a concept (concepts are appended last).

    sequence = torch.tensor(list(chain.from_iterable(all_tokens)), dtype=torch.long)
    sequence = sequence.unsqueeze(0)  # add batch dimension
    is_concept = is_concept.unsqueeze(0)  # add batch dimension

    return sequence, indices, is_concept, concept_indices
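
# Example (sketch): for a graph with the triplets (dog, IsA, animal) and (cat, IsA, animal), the
# Levi-graph sequence is laid out as [IsA][IsA][animal][cat][dog]([eos]). With a tokenizer that
# splits 'IsA' into two tokens and each concept into one (actual counts depend on the tokenizer),
# `indices` would look like
#   {'IsA': [(0, 2), (2, 4)], 'animal': [(4, 5)], 'cat': [(5, 6)], 'dog': [(6, 7)], '': [(7, 8)]}
# i.e. relations get one (start, end) span per occurrence, concepts exactly one span.
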
def _get_graphT5_relativeposition_sparsitymask(g: Graph, indices: dict, sequence_length: int, use_eos: bool, eos: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ### get the relative position of each node in the sequence, as well as the sparsity mask ###
    # initialize relative position matrix
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    # initialize sparsity mask
    sparsity_mask = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
    # initialize use_additional_bucket
    use_additional_bucket = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)

    # relative positions / sparsity within each node
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end - start)
        sparsity_mask[start:end, start:end] = True

    # relative position between nodes of the same triplet
    relation_counter = {relation: 0 for relation in g.relations}  # dictionary mapping each relation to the number of times it has already appeared in the graph
    for triplet in g.g:
        pos_h = indices[triplet[0]][0]  # position of head; tuple (start_index, end_index)
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]  # position of relation; tuple (start_index, end_index)
        pos_t = indices[triplet[2]][0]  # position of tail; tuple (start_index, end_index)
        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]  # length (i.e. number of tokens) of head and relation

        # iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it is sufficiently fast.
        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):  # iterate over all head tokens
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # iterate over all relation tokens
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = -(l_h - ih + ir)
                sparsity_mask[ph, pr] = True
                sparsity_mask[pr, ph] = True
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # iterate over all tail tokens
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = -(l_h - ih + l_r + it)
                sparsity_mask[ph, pt] = True
                sparsity_mask[pt, ph] = True
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # iterate over all relation tokens
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # iterate over all tail tokens
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = -(l_r - ir + it)
                sparsity_mask[pr, pt] = True
                sparsity_mask[pt, pr] = True
        relation_counter[triplet[1]] += 1  # the next time this relation occurs, the next set of token positions will be used

    if use_eos:
        assert len(indices['']) == 1, f"{indices[''] = } should have length 1"
        pos_eos = indices[''][0]  # position of eos; tuple (start_index, end_index)
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]  # position of eos token

        if eos == 'bidirectional':
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False  # no messages from eos to other tokens
            sparsity_mask[:, pos_eos] = True
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)  # add batch dimension
    sparsity_mask = sparsity_mask.unsqueeze(0)  # add batch dimension
    use_additional_bucket = use_additional_bucket.unsqueeze(0)  # add batch dimension
    return relative_position, sparsity_mask, use_additional_bucket
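
# Worked example (sketch): for a triplet whose head has 2 tokens, relation 1 token and tail 1 token,
# the loops above produce relative positions as if the tokens were written consecutively as
# "head_0 head_1 rel_0 tail_0":
#   relative_position[head_0, rel_0]  = l_h - 0 + 0       = 2
#   relative_position[head_1, rel_0]  = l_h - 1 + 0       = 1
#   relative_position[head_0, tail_0] = l_h - 0 + l_r + 0 = 3
#   relative_position[rel_0,  tail_0] = l_r - 0 + 0       = 1
# with the negated values in the transposed entries. Token pairs of unconnected nodes keep
# sparsity_mask == False, which is what makes the local GLM local.
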
def _get_global_graphT5_relativeposition_sparsitymask(g: Graph, indices: dict, sequence_length: int, use_eos: bool, eos: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ### get the relative position of each node in the sequence, as well as the sparsity mask ###
    # initialize relative position matrix
    # relative_position = torch.ones(size=(sequence_length, sequence_length), dtype=torch.long) * 1e6  # technically should be float('inf'), but it does not matter
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    # initialize sparsity mask
    sparsity_mask = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)  # could switch to None, but then the code has to be updated accordingly (in particular get_batch)
    # initialize use_additional_bucket
    use_additional_bucket = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)

    # relative positions / sparsity within each node
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end - start)
        use_additional_bucket[start:end, start:end] = False

    # relative position between nodes of the same triplet
    relation_counter = {relation: 0 for relation in g.relations}  # dictionary mapping each relation to the number of times it has already appeared in the graph
    for triplet in g.g:
        pos_h = indices[triplet[0]][0]  # position of head; tuple (start_index, end_index)
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]  # position of relation; tuple (start_index, end_index)
        pos_t = indices[triplet[2]][0]  # position of tail; tuple (start_index, end_index)
        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]  # length (i.e. number of tokens) of head and relation

        # iterate over all combinations of tokens in each triplet. This implementation is not very elegant, but it works.
        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):  # iterate over all head tokens
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # iterate over all relation tokens
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = -(l_h - ih + ir)
                use_additional_bucket[ph, pr] = False
                use_additional_bucket[pr, ph] = False
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # iterate over all tail tokens
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = -(l_h - ih + l_r + it)
                use_additional_bucket[ph, pt] = False
                use_additional_bucket[pt, ph] = False
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # iterate over all relation tokens
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # iterate over all tail tokens
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = -(l_r - ir + it)
                use_additional_bucket[pr, pt] = False
                use_additional_bucket[pt, pr] = False
        relation_counter[triplet[1]] += 1  # the next time this relation occurs, the next set of token positions will be used

    if use_eos:
        assert len(indices['']) == 1, f"{indices[''] = } should have length 1"
        pos_eos = indices[''][0]  # position of eos; tuple (start_index, end_index)
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]  # position of eos token

        if eos == 'bidirectional':
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False  # no messages from eos to other tokens
            sparsity_mask[:, pos_eos] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)  # add batch dimension
    sparsity_mask = sparsity_mask.unsqueeze(0)  # add batch dimension
    use_additional_bucket = use_additional_bucket.unsqueeze(0)  # add batch dimension
    return relative_position, sparsity_mask, use_additional_bucket
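
# Note (sketch): the global variant differs from the local one only in its initialization. All
# positions may attend to each other (sparsity_mask starts as all True), and token pairs that are
# not connected in the graph keep use_additional_bucket == True, so their attention bias comes from
# the extra "long-range" bucket rather than from the regular relative-position buckets.
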
def graph_to_graphT5(g: Graph, tokenizer: T5Tokenizer, how: str, eos: str) -> Data:
    """
    Convert a graph to a graphT5 input.
    :param g: graph
    :param tokenizer: tokenizer
    :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
    :param eos: end-of-sequence token. Can be `False` for not using an eos token. When using an eos token,
        there are two ways to use it: `bidirectional` means that the eos token is connected to every other
        node in the graph, with a relative position of positive infinity (from node to eos) or negative
        infinity (from eos to node). `unidirectional` means that the eos token is connected to every node
        in the graph with a relative position of positive infinity (from node to eos), but not the other
        way around (i.e. no connection from eos to other nodes). This means that nodes do not get messages
        from the eos token, which preserves locality when using the local GLM.
    """
    if not isinstance(g, Graph):
        g = Graph(g)
    eos = str(eos)
    assert eos in ['False', 'bidirectional', 'unidirectional'], f"{eos = } must be either 'False', 'bidirectional', or 'unidirectional'"
    use_eos: bool = eos != 'False'

    str2tok = _get_str2tok(g, tokenizer)  # get a dictionary mapping concepts and relations to their tokenized forms

    sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, str2tok, use_eos)  # get input sequence (i.e. the sequence that will be fed into the model for this graph)
    sequence_length = sequence.shape[1]

    if how == 'local':
        relative_position, sparsity_mask, use_additional_bucket = _get_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
        num_additional_buckets = 0  # lGLM does not use additional buckets
    elif how == 'global':
        relative_position, sparsity_mask, use_additional_bucket = _get_global_graphT5_relativeposition_sparsitymask(g, indices, sequence_length, use_eos, eos)
        num_additional_buckets = 1  # gGLM uses 1 additional bucket for long-ranged G2G connections
    else:
        raise ValueError(f"how must be either 'local' or 'global', but is {how}")

    input_ids = sequence

    data = Data(input_ids=input_ids, relative_position=relative_position, sparsity_mask=sparsity_mask, use_additional_bucket=use_additional_bucket, indices=indices, is_concept=is_concept, concept_indices=concept_indices, num_additional_buckets=num_additional_buckets)
    return data


@cache
def _get_relative_position(size):
    return torch.tensor([[i - j for i in range(size)] for j in range(size)], dtype=torch.long)


def get_embedding(
    sequence_embedding: torch.Tensor,
    indices: Dict[str, List[Tuple[int, int]]],
    concept: str,
    embedding_aggregation: str = "mean",
):
    """
    Returns the embedding of a concept.
    :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
    :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts.
    :param concept: the concept for which the embedding should be returned
    :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
    :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
    """
    assert concept in indices.keys(), f"{concept = } is not a node in the graph. {indices = }"
    assert len(indices[concept]) == 1, f"{concept = } is not a concept, as concepts occur only once in the graph. {indices = }"
    start, end = indices[concept][0]
    sequence_embedding = sequence_embedding[start:end, :]
    if embedding_aggregation == "mean":
        return torch.mean(sequence_embedding, dim=0, keepdim=True)
    elif embedding_aggregation == "seq":
        return sequence_embedding
    else:
        raise NotImplementedError(f"{embedding_aggregation = } is not supported. Use either 'mean' or 'seq'.")
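
# Example (sketch): turning a graph into model inputs and pooling a concept embedding.
# `glm_encoder` stands for a GLM/T5 encoder that accepts the extra relative_position,
# sparsity_mask and use_additional_bucket arguments; it is a placeholder, not defined here.
#
#   tokenizer = T5Tokenizer.from_pretrained('t5-small')
#   data = graph_to_graphT5(get_dummy_graph(), tokenizer, how='global', eos='False')
#   # hidden = glm_encoder(input_ids=data.input_ids,
#   #                      relative_position=data.relative_position,
#   #                      sparsity_mask=data.sparsity_mask,
#   #                      use_additional_bucket=data.use_additional_bucket).last_hidden_state
#   # dog_embedding = get_embedding(hidden[0], data.indices, 'dog', embedding_aggregation='mean')
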
def add_text_to_graph_data(data, text, tokenizer, use_text):
    if use_text in {'False', '', False, None}:
        return None

    text_seq = torch.tensor(tokenizer(text, padding=False)['input_ids']).unsqueeze(0)
    new_input_ids = torch.cat([data.input_ids, text_seq], dim=1)
    old_seq_len = data.input_ids.shape[1]
    text_seq_len = text_seq.shape[1]
    new_seq_len = new_input_ids.shape[1]

    new_is_graph = torch.zeros(size=(1, new_seq_len), dtype=torch.bool)
    new_is_graph[:, :old_seq_len] = True

    if data.relative_position is None:  # sequence transformer
        assert data.sparsity_mask is None
        assert data.use_additional_bucket is None
        data.input_ids = new_input_ids
        data.is_graph = new_is_graph
        return None

    new_relative_position = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.relative_position.dtype)
    new_relative_position[:, :old_seq_len, :old_seq_len] = data.relative_position
    new_relative_position[:, old_seq_len:, old_seq_len:] = _get_relative_position(text_seq_len)

    new_sparsity_mask = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.sparsity_mask.dtype)
    new_sparsity_mask[:, :old_seq_len, :old_seq_len] = data.sparsity_mask
    new_sparsity_mask[:, old_seq_len:, old_seq_len:] = True

    new_use_additional_bucket = torch.zeros(size=(1, new_seq_len, new_seq_len), dtype=data.use_additional_bucket.dtype)
    new_use_additional_bucket[:, :old_seq_len, :old_seq_len] = data.use_additional_bucket
    new_use_additional_bucket[:, old_seq_len:, old_seq_len:] = False  # could change that if we want T2T and local G2G relations to be learned separately

    if use_text in {'FullyConnected', True}:
        new_sparsity_mask[:, old_seq_len:, :old_seq_len] = True
        new_sparsity_mask[:, :old_seq_len, old_seq_len:] = True

        new_use_additional_bucket[:, old_seq_len:, :old_seq_len] = True
        new_use_additional_bucket[:, :old_seq_len, old_seq_len:] = True

        new_relative_position[:, old_seq_len:, :old_seq_len] = data.num_additional_buckets
        new_relative_position[:, :old_seq_len, old_seq_len:] = data.num_additional_buckets + 1

        new_num_additional_buckets = data.num_additional_buckets + 2
    else:
        raise ValueError(f"unknown use_text {use_text} (type {type(use_text)})")

    data.input_ids = new_input_ids
    data.relative_position = new_relative_position
    data.sparsity_mask = new_sparsity_mask
    data.use_additional_bucket = new_use_additional_bucket
    data.num_additional_buckets = new_num_additional_buckets
    data.is_graph = new_is_graph
    return None


class DataProcessor():
    @staticmethod
    def encode_graph(tokenizer, g: Union[Graph, list[tuple[str, str, str]]], text: Optional[str] = None, how: Literal['global', 'local'] = 'global', eos: str = "False") -> Data:
        """
        Convert a graph to a suitable input for the model.
        :param tokenizer: tokenizer
        :param g: graph
        :param text: text to add to the graph. Can be None if no text should be added.
        :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
        :param eos: end-of-sequence token. Can be `False` for not using an eos token. This is the method used
            in the paper. When using an eos token, there are two ways to use it: `bidirectional` means that
            the eos token is connected to every other node in the graph. `unidirectional` means that the eos
            token is connected to every node in the graph (from node to eos), but not the other way around
            (i.e. no connection from eos to other nodes). This means that nodes do not get messages from the
            eos token, which preserves locality when using the local GLM.
        :return: Data object
        """
        if not isinstance(g, Graph):
            g = Graph(g)
        data = graph_to_graphT5(g, tokenizer, how, eos)
        if text is not None:
            add_text_to_graph_data(data, text, tokenizer, use_text=True)
        return data
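
    # Example (sketch): encoding a graph together with a text passage. The text tokens are appended
    # after the graph tokens and, because use_text=True, connected to them through two extra
    # relative-position buckets (graph-to-text and text-to-graph).
    #
    #   data = DataProcessor.encode_graph(tokenizer,
    #                                     g=[("dog", "IsA", "animal")],
    #                                     text="Dogs are animals.",
    #                                     how='global')
    #   data.is_graph  # True for graph positions, False for text positions
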
    @staticmethod
    def to_batch(data_instances: list[Data], tokenizer, max_seq_len: Optional[int] = None, device: str = 'cpu', return_attention_mask: bool = False, **kwargs) -> dict:
        """
        Convert a list of data instances to batched inputs for the GLM forward call.
        :param data_instances: list of Data instances
        :param tokenizer: tokenizer
        :param max_seq_len: maximum sequence length
        :param device: device
        :param return_attention_mask: whether to return the attention mask. The attention mask is not used by the GLM encoder, but the decoder needs it to mask out padding tokens in cross attention.
        :return: dictionary with keys 'input_ids', 'relative_position', 'sparsity_mask', and 'use_additional_bucket' (and, if return_attention_mask is True, the attention mask as a second return value)
        """
        current_max_seq_len = max([data.input_ids.shape[1] for data in data_instances])
        if max_seq_len is None:
            max_seq_len = current_max_seq_len
        else:
            max_seq_len = min(max_seq_len, current_max_seq_len)

        if data_instances[0].relative_position is None:
            assert data_instances[0].sparsity_mask is None
            assert data_instances[0].use_additional_bucket is None
            is_sequence_transformer = True
        else:
            assert data_instances[0].sparsity_mask is not None
            assert data_instances[0].use_additional_bucket is not None
            is_sequence_transformer = False

        # initialize tensors
        input_ids = torch.ones((len(data_instances), max_seq_len), dtype=torch.long, device=device) * tokenizer.pad_token_id
        if is_sequence_transformer:
            relative_position = None
            sparsity_mask = None
            use_additional_bucket = None
        else:
            relative_position = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.long, device=device)
            sparsity_mask = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
            use_additional_bucket = torch.zeros((len(data_instances), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
        if return_attention_mask:
            attention_mask = torch.zeros((len(data_instances), max_seq_len), dtype=torch.bool, device=device)

        # fill tensors
        for i, data in enumerate(data_instances):
            instance_len = min(data.input_ids.shape[1], max_seq_len)
            input_ids[i, :instance_len] = data.input_ids[:, :instance_len]
            if not is_sequence_transformer:
                relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
                sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
                use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]
            if return_attention_mask:
                attention_mask[i, :instance_len] = 1

        model_input = {
            'input_ids': input_ids,
            'relative_position': relative_position,
            'sparsity_mask': sparsity_mask,
            'use_additional_bucket': use_additional_bucket,
            **kwargs
        }
        if return_attention_mask:
            return model_input, attention_mask
        return model_input
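
    # Example (sketch): batching several encoded graphs before a forward pass. The returned
    # dictionary can be unpacked into a GLM encoder that accepts these keyword arguments.
    #
    #   instances = [DataProcessor.encode_graph(tokenizer, g) for g in list_of_graphs]
    #   model_input = DataProcessor.to_batch(instances, tokenizer, device='cpu')
    #   # outputs = glm_encoder(**model_input)
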
    @staticmethod
    def get_embedding(sequence_embedding: torch.Tensor, indices: Dict[str, List[Tuple[int, int]]], concept: str, embedding_aggregation: str = "mean"):
        """
        Returns the embedding of a concept.
        :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
        :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts. indices is part of the Data object.
        :param concept: the concept for which the embedding should be returned.
        :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
        :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
        """
        return get_embedding(sequence_embedding, indices, concept, embedding_aggregation)
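
# Minimal end-to-end sketch (assumes the 't5-small' tokenizer can be loaded; the GLM encoder itself
# lives outside this module, so the forward pass is only indicated in comments).
if __name__ == '__main__':
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    graph = get_dummy_graph(num_triplets=3)
    instance = DataProcessor.encode_graph(tokenizer, graph, how='global', eos='False')
    batch = DataProcessor.to_batch([instance], tokenizer, device='cpu')
    print(f"{batch['input_ids'].shape = }")
    print(f"{batch['relative_position'].shape = }")
    # outputs = glm_encoder(**batch)
    # dog_embedding = DataProcessor.get_embedding(outputs.last_hidden_state[0], instance.indices, 'dog')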