|
from copy import deepcopy |
|
import torch |
|
import dgl |
|
import stanza |
|
import networkx as nx |
|
|
|
class Sentence2GraphParser:
    """Build a DGL graph for a sentence from stanza's dependency parse.

    Node numbering in the returned graphs: node 0 is ``<BOS>``, nodes
    ``1..N`` are the words (English) or the characters (Chinese), and node
    ``N + 1`` is ``<EOS>``.

    Edge-type ids stored in the returned type tensor:
        0: dependency edge, from a word to its head
        1: reversed dependency edge, from the head back to the word
        2: self loop
        3: edge between the roots of two different parsed sentences
        4: forward sequential edge (also ``<BOS>`` -> first, ``<EOS>`` -> last)
        5: backward sequential edge (also first -> ``<BOS>``, last -> ``<EOS>``)
    """

    _EDGE_DEPENDENCY = 0
    _EDGE_DEPENDENCY_REVERSED = 1
    _EDGE_SELF_LOOP = 2
    _EDGE_INTER_SENTENCE = 3
    _EDGE_SEQ_FORWARD = 4
    _EDGE_SEQ_BACKWARD = 5

    def __init__(self, language='zh', use_gpu=False, download=False):
        """Create the stanza pipeline used for parsing.

        Args:
            language: stanza language code; ``parse`` supports 'zh' and 'en'.
            use_gpu: forwarded to ``stanza.Pipeline``.
            download: when False, ``download_method=None`` is passed to the
                pipeline (per the stanza docs this disables resource
                downloading, so the models must already be available locally).
        """
        self.language = language
        pipeline_kwargs = {'lang': language, 'use_gpu': use_gpu}
        if not download:
            pipeline_kwargs['download_method'] = None
        self.stanza_parser = stanza.Pipeline(**pipeline_kwargs)

    def parse(self, clean_sentence=None, words=None, ph_words=None):
        """Dispatch to the language-specific parser.

        Args:
            clean_sentence: required for 'en'; see ``_parse_en``.
            words: required for 'zh'; see ``_parse_zh``.
            ph_words: required for 'zh'; see ``_parse_zh``.

        Returns:
            ``(dgl_graph, edge_types)`` where ``edge_types`` is a LongTensor
            holding one type id per edge of the graph.

        Raises:
            NotImplementedError: if ``self.language`` is neither 'zh' nor 'en'.
            AssertionError: if the arguments required by the language are
                missing.
        """
        if self.language == 'zh':
            assert words is not None and ph_words is not None, \
                "'zh' parsing needs both `words` and `ph_words`"
            return self._parse_zh(words, ph_words)
        if self.language == 'en':
            assert clean_sentence is not None, "'en' parsing needs `clean_sentence`"
            return self._parse_en(clean_sentence)
        raise NotImplementedError

    @staticmethod
    def _dependency_edges(parser_out):
        """Collect dependency edges from every sentence of a parsed document.

        Word ids restart at 1 in each sentence, so they are shifted by the
        number of words seen in the previous sentences to obtain one global,
        1-based numbering.

        Returns:
            ``(sources, dests, heads, num_nodes)``: one ``word -> head`` edge
            per non-root word, the global ids of the root words (those whose
            ``head`` is 0), and the total number of words.
        """
        sources, dests, heads = [], [], []
        offset = 0
        for sentence in parser_out.sentences:
            for word in sentence.words:
                node = word.id + offset
                if word.head == 0:
                    # head == 0 marks the root of the sentence: no outgoing edge.
                    heads.append(node)
                    continue
                sources.append(node)
                dests.append(word.head + offset)
            offset += len(sentence.words)
        return sources, dests, heads, offset

    @staticmethod
    def _fully_connected_edges(nodes):
        """Return ``(sources, dests)`` joining every ordered pair of distinct positions in ``nodes``."""
        pairs = [(node_i, node_j)
                 for i, node_i in enumerate(nodes)
                 for j, node_j in enumerate(nodes)
                 if i != j]
        return [src for src, _ in pairs], [dst for _, dst in pairs]

    @staticmethod
    def _sequential_edges(num_nodes):
        """Return ``(sources, dests, types)`` chaining nodes ``1..num_nodes`` in both directions.

        Forward edges ``k -> k + 1`` get type 4 and backward edges
        ``k -> k - 1`` get type 5. The backward chain stops at node 1: it must
        not reach node 0 (``<BOS>``), so that there are exactly as many
        backward edges as backward labels.
        """
        forward_src = list(range(1, num_nodes))
        forward_dst = list(range(2, num_nodes + 1))
        backward_src = list(range(num_nodes, 1, -1))
        backward_dst = list(range(num_nodes - 1, 0, -1))
        types = ([Sentence2GraphParser._EDGE_SEQ_FORWARD] * len(forward_src)
                 + [Sentence2GraphParser._EDGE_SEQ_BACKWARD] * len(backward_src))
        return forward_src + backward_src, forward_dst + backward_dst, types

    def _parse_zh(self, words, ph_words, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False):
        """Build a character-level graph for a Chinese sentence.

        words: <List of str>, each Chinese character is one item
        ph_words: <List of str>, one item per character, holding its phonemes;
            an item ending with '#' marks the last character of a word.
        Both lists start with '<BOS>' and end with '<EOS>'.

        Example:
            text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
            words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ','
                     , '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
            ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#',
                        'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                        'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']

        Args:
            enable_backward_edge: also add the reversed dependency edges (type 1).
            enable_recur_edge: add a self loop on every character (type 2).
            enable_inter_sentence_edge: fully connect the roots of the parsed
                sentences (type 3).
            sequential_edge: chain the parsed tokens in both directions
                (types 4 and 5).

        Returns:
            ``(dgl_graph, edge_types)``; ``edge_types`` is a LongTensor with
            one entry per edge.

        Raises:
            ValueError: if the characters returned by the parser do not line
                up with ``words``.
        """
        # Drop <BOS>/<EOS>. Slicing copies, so the caller's lists are untouched.
        words, ph_words = words[1:-1], ph_words[1:-1]
        for i, p_w in enumerate(ph_words):
            if p_w == ',':
                # Force the character entry to ',' wherever the phoneme entry is ','.
                # NOTE(review): as written the phoneme assignment is a no-op; this
                # may once have mapped between ASCII and fullwidth commas - confirm.
                words[i], ph_words[i] = ',', ','

        # Insert blanks so the parser sees the word boundaries encoded in ph_words.
        tmp_words = deepcopy(words)
        num_added_space = 0
        for i, p_w in enumerate(ph_words):
            if p_w.endswith("#"):
                # '#' closes a word: put one blank after it.
                tmp_words.insert(num_added_space + i + 1, " ")
                num_added_space += 1
            if p_w == ',':
                # One blank on each side of a comma; insert the trailing one
                # first so the leading index stays valid.
                tmp_words.insert(num_added_space + i + 1, " ")
                tmp_words.insert(num_added_space + i, " ")
                num_added_space += 2
        clean_text = ''.join(tmp_words).strip()
        parser_out = self.stanza_parser(clean_text)

        # 1-based character index -> character.
        idx_to_word = {i + 1: w for i, w in enumerate(words)}

        # Global (cross-sentence) 1-based token index -> token text without blanks.
        vocab_nodes = {}
        vocab_idx_offset = 0
        for sentence in parser_out.sentences:
            for vocab_node in sentence.words:
                vocab_nodes[vocab_node.id + vocab_idx_offset] = vocab_node.text.replace(" ", "")
            vocab_idx_offset += len(sentence.words)

        # Token index -> indices of the characters it is made of.
        vocab_to_word = {}
        current_word_idx = 1
        for vocab_i, vocab_text in vocab_nodes.items():
            vocab_to_word[vocab_i] = []
            for w_in_vocab_i in vocab_text:
                # .get(): a parser output longer than the input is reported as
                # a mismatch as well, instead of surfacing as a bare KeyError.
                if w_in_vocab_i != idx_to_word.get(current_word_idx):
                    raise ValueError("Word Mismatch!")
                vocab_to_word[vocab_i].append(current_word_idx)
                current_word_idx += 1

        if len(parser_out.sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")

        # ---- token-level edges ----
        dep_source, dep_dest, sentences_heads, num_vocab = self._dependency_edges(parser_out)
        vocab_level_source_id, vocab_level_dest_id = list(dep_source), list(dep_dest)
        vocab_level_edge_types = [self._EDGE_DEPENDENCY] * len(dep_source)

        if enable_backward_edge:
            vocab_level_source_id += dep_dest
            vocab_level_dest_id += dep_source
            vocab_level_edge_types += [self._EDGE_DEPENDENCY_REVERSED] * len(dep_source)

        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            inter_source, inter_dest = self._fully_connected_edges(sentences_heads)
            vocab_level_source_id += inter_source
            vocab_level_dest_id += inter_dest
            vocab_level_edge_types += [self._EDGE_INTER_SENTENCE] * len(inter_source)

        if sequential_edge:
            seq_source, seq_dest, seq_types = self._sequential_edges(num_vocab)
            vocab_level_source_id += seq_source
            vocab_level_dest_id += seq_dest
            vocab_level_edge_types += seq_types

        # ---- character-level edges ----
        num_word = len(words)
        source_id, dest_id, edge_types = [], [], []
        for vocab_start, vocab_end, vocab_edge_type in zip(vocab_level_source_id, vocab_level_dest_id,
                                                           vocab_level_edge_types):
            # A token-level edge is attached to the first character of each token.
            source_id.append(min(vocab_to_word[vocab_start]))
            dest_id.append(min(vocab_to_word[vocab_end]))
            edge_types.append(vocab_edge_type)

        # Chain the characters inside each token in both directions.
        for word_indices_in_v in vocab_to_word.values():
            last_pos = len(word_indices_in_v) - 1
            for pos, word_idx in enumerate(word_indices_in_v):
                if pos < last_pos:
                    source_id.append(word_idx)
                    dest_id.append(word_idx + 1)
                    edge_types.append(self._EDGE_SEQ_FORWARD)
                if pos > 0:
                    source_id.append(word_idx)
                    dest_id.append(word_idx - 1)
                    edge_types.append(self._EDGE_SEQ_BACKWARD)

        if enable_recur_edge:
            self_loops = list(range(1, num_word + 1))
            source_id += self_loops
            dest_id += self_loops
            edge_types += [self._EDGE_SELF_LOOP] * len(self_loops)

        # Attach <BOS> (node 0) and <EOS> (node num_word + 1) to the first and last character.
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [self._EDGE_SEQ_FORWARD, self._EDGE_SEQ_FORWARD,
                       self._EDGE_SEQ_BACKWARD, self._EDGE_SEQ_BACKWARD]

        edges = (torch.LongTensor(source_id), torch.LongTensor(dest_id))
        dgl_graph = dgl.graph(edges)
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)

    def _parse_en(self, clean_sentence, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False, consider_bos_for_index=True):
        """Build a word-level graph for an English sentence.

        clean_sentence: <str>, each word or punctuation should be separated by one blank.
            One trailing punctuation token (one of ``. , ; : ? !``) and a
            leading ". " are removed before parsing.

        Args:
            enable_backward_edge: also add the reversed dependency edges (type 1).
            enable_recur_edge: add a self loop on every word (type 2).
            enable_inter_sentence_edge: fully connect the roots of the parsed
                sentences (type 3).
            sequential_edge: chain the words in both directions (types 4 and 5).
            consider_bos_for_index: when False every node id is shifted down by one.

        Returns:
            ``(dgl_graph, edge_types)``; ``edge_types`` is a LongTensor with
            one entry per edge.
        """
        clean_sentence = clean_sentence.strip()
        if clean_sentence.endswith((" .", " ,", " ;", " :", " ?", " !")):
            clean_sentence = clean_sentence[:-2]
        if clean_sentence.startswith(". "):
            clean_sentence = clean_sentence[2:]
        parser_out = self.stanza_parser(clean_sentence)
        if len(parser_out.sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")
            print(clean_sentence)

        dep_source, dep_dest, sentences_heads, num_word = self._dependency_edges(parser_out)
        source_id, dest_id = list(dep_source), list(dep_dest)
        edge_types = [self._EDGE_DEPENDENCY] * len(dep_source)

        if enable_backward_edge:
            source_id += dep_dest
            dest_id += dep_source
            edge_types += [self._EDGE_DEPENDENCY_REVERSED] * len(dep_source)

        if enable_recur_edge:
            self_loops = list(range(1, num_word + 1))
            source_id += self_loops
            dest_id += self_loops
            edge_types += [self._EDGE_SELF_LOOP] * len(self_loops)

        if enable_inter_sentence_edge and len(sentences_heads) > 1:
            inter_source, inter_dest = self._fully_connected_edges(sentences_heads)
            source_id += inter_source
            dest_id += inter_dest
            edge_types += [self._EDGE_INTER_SENTENCE] * len(inter_source)

        # Attach <BOS> (node 0) and <EOS> (node num_word + 1) to the first and last word.
        source_id += [0, num_word + 1, 1, num_word]
        dest_id += [1, num_word, 0, num_word + 1]
        edge_types += [self._EDGE_SEQ_FORWARD, self._EDGE_SEQ_FORWARD,
                       self._EDGE_SEQ_BACKWARD, self._EDGE_SEQ_BACKWARD]

        if sequential_edge:
            seq_source, seq_dest, seq_types = self._sequential_edges(num_word)
            source_id += seq_source
            dest_id += seq_dest
            edge_types += seq_types

        source_tensor, dest_tensor = torch.LongTensor(source_id), torch.LongTensor(dest_id)
        if not consider_bos_for_index:
            # NOTE(review): the <BOS> edges above use node id 0, which becomes -1
            # after this shift - confirm this mode is usable with dgl.graph.
            source_tensor, dest_tensor = source_tensor - 1, dest_tensor - 1
        dgl_graph = dgl.graph((source_tensor, dest_tensor))
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)
|
|
|
|
|
def plot_dgl_sentence_graph(dgl_graph, labels):
    """Draw a sentence graph with networkx and show it in a matplotlib window.

    The node positions come from ``nx.random_layout``.

    Args:
        dgl_graph: graph to draw; converted with its ``to_networkx()`` method.
        labels: mapping from node index to the text drawn on that node, e.g.
            labels = {idx: word for idx, word in enumerate(sentence.split(" "))}
    """
    # Imported lazily so matplotlib is only needed when a plot is requested.
    from matplotlib import pyplot

    graph = dgl_graph.to_networkx()
    layout = nx.random_layout(graph)
    # Draw nodes and edges without the default labels, then add our own.
    nx.draw(graph, pos=layout, with_labels=False)
    nx.draw_networkx_labels(graph, pos=layout, labels=labels)
    pyplot.show()
|
|
|
if __name__ == '__main__':
    # Demo 1: Chinese input, given as per-character words plus their phonemes.
    zh_parser = Sentence2GraphParser("zh")
    text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
    zh_words = [
        '<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ',',
        '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>',
    ]
    zh_ph_words = [
        '<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#', 'b_o3_#', 'l_uo2_|', 'an1', ',',
        'd_iao1_|', 'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>',
    ]
    graph1, etypes1 = zh_parser.parse(text1, zh_words, zh_ph_words)
    # Label each node with the phoneme string at the same index.
    plot_dgl_sentence_graph(graph1, dict(enumerate(zh_ph_words)))

    # Demo 2: English input, given as one blank-separated string.
    en_parser = Sentence2GraphParser("en")
    text2 = "I love you . You love me . Mixue ice-scream and tea ."
    graph2, etypes2 = en_parser.parse(text2)
    en_tokens = ("<BOS> " + text2 + " <EOS>").split(" ")
    plot_dgl_sentence_graph(graph2, dict(enumerate(en_tokens)))
|
|
|
|