# Last commit: 6ed21b9 "init" by nielklug (file size 3.44 kB).
import sys
from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
"""
This functions is used to replace the leaves of parse trees in a file of text string form by the leaves from
a file of conll format.
It takes the file for the parse trees of text string form and the file of conll file, which have to correspond
with each other.
It creates a new file containing the replaced parse tree in text string format.
"""
def _read_conll_column(conll_path, column=1):
    """Read one whitespace-separated column of a CoNLL-style file, grouped by sentence.

    Sentences are separated by blank (or whitespace-only) lines. The last
    sentence is kept even when the file does not end with a blank line, and
    runs of consecutive blank lines do not produce empty sentences.

    Args:
        conll_path: path of the UTF-8 encoded CoNLL-style file.
        column: zero-based index of the whitespace-separated field to extract.

    Returns:
        A list with one list of strings per sentence.

    Raises:
        ValueError: if a non-blank line has no field at index ``column``.
    """
    sentences = []
    current = []
    with open(conll_path, 'r', encoding='utf-8') as f_conll:
        for line_no, line in enumerate(f_conll, start=1):
            fields = line.split()
            if not fields:
                # Sentence boundary; repeated blank lines are ignored.
                if current:
                    sentences.append(current)
                    current = []
                continue
            try:
                current.append(fields[column])
            except IndexError:
                raise ValueError(
                    f"{conll_path}:{line_no}: no column {column} in line {line.rstrip()!r}"
                ) from None
    # The file need not end with a blank line; keep the final sentence.
    if current:
        sentences.append(current)
    return sentences


def _rewrite_trees(parse_path, conll_path, output_path, column, update):
    """Call ``update(tree, leaf_position, token)`` for every leaf, then write the trees.

    Shared driver for replace_leaves() and replace_labels(). The trees of
    ``parse_path`` and the sentences of ``conll_path`` are paired in order, and
    within each pair the tree's leaves are paired with the sentence's tokens.
    Each modified tree is written to ``output_path`` followed by a newline.

    Raises:
        ValueError: if the inputs contain a different number of sentences, or a
            tree has a different number of leaves than its CoNLL sentence.
    """
    token_sents = _read_conll_column(conll_path, column)
    trees = BracketParseCorpusReader('', [parse_path]).parsed_sents()
    # Checked before the output file is opened, so this mismatch does not
    # create or truncate the output file.
    if len(trees) != len(token_sents):
        raise ValueError(
            "The number of trees and leaves is not matched: "
            f"{len(trees)} trees vs {len(token_sents)} CoNLL sentences."
        )
    with open(output_path, 'w', encoding='utf-8') as f_output:
        for i, (tree, tokens) in enumerate(zip(trees, token_sents)):
            leaf_positions = tree.treepositions('leaves')
            if len(leaf_positions) != len(tokens):
                raise ValueError(
                    f"The number of leaves is not matched at position {i}: "
                    f"{len(leaf_positions)} vs {len(tokens)}\n"
                    f"{tree.leaves()}\n{tokens}"
                )
            for pos, token in zip(leaf_positions, tokens):
                update(tree, pos, token)
            # A huge right margin keeps pformat() from wrapping the tree across lines.
            f_output.write('{}\n'.format(tree.pformat(margin=1e100)))


def _set_leaf(tree, pos, token):
    """Replace the leaf at tree position ``pos`` with ``token``."""
    tree[pos] = token


def _set_preterminal_label(tree, pos, token):
    """Relabel the parent node of the leaf at ``pos`` with ``token``."""
    tree[pos[:-1]].set_label(token)


def replace_leaves(parse_path, conll_path, output_path, column=1):
    """Replace the leaves of each parse tree with tokens from a CoNLL file.

    Args:
        parse_path: bracketed parse-tree file readable by BracketParseCorpusReader.
        conll_path: CoNLL-style file with one token per line and blank lines
            between sentences, aligned with ``parse_path``.
        output_path: file the modified trees are written to, one per line.
        column: zero-based CoNLL column holding the new leaves (default 1,
            the column this function always used).

    Raises:
        ValueError: if the inputs are not aligned sentence by sentence and leaf
            by leaf, or a CoNLL line lacks the requested column.
    """
    _rewrite_trees(parse_path, conll_path, output_path, column, _set_leaf)


def replace_labels(parse_path, conll_path, output_path, column=1):
    """Replace the label of each leaf's parent node with values from a CoNLL file.

    The leaves themselves are kept; only the node directly above each leaf is
    relabelled (e.g. to re-tag preterminals). NOTE(review): this assumes the
    chosen CoNLL column holds the desired labels; confirm against the input format.

    Args:
        parse_path: bracketed parse-tree file readable by BracketParseCorpusReader.
        conll_path: CoNLL-style file with one entry per line and blank lines
            between sentences, aligned with ``parse_path``.
        output_path: file the modified trees are written to, one per line.
        column: zero-based CoNLL column holding the new labels (default 1,
            the column this function always used).

    Raises:
        ValueError: if the inputs are not aligned sentence by sentence and leaf
            by leaf, or a CoNLL line lacks the requested column.
    """
    _rewrite_trees(parse_path, conll_path, output_path, column, _set_preterminal_label)
"""
For example:
cd schmid/MHG-Parser/self-attentive-parser-master
python src/utils.py data/mhg/MHG.parses data/mhg/MHG.mapped data/mhg/MHG_retag.parses
"""
if __name__=='__main__':
assert len(sys.argv) == 4, "Wrong number of input file paths"
parse_path, conll_path, output_path = sys.argv[1:]
# replace_leaves(parse_path, conll_path, output_path)
replace_labels(parse_path, conll_path, output_path)