|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Add flexibility scores to a phrase table half. |
|
|
|
You usually don't have to call this script directly; to add flexibility |
|
scores to your model, run train-model.perl with the option |
|
"--flexibility-score" (will only affect steps 5 and 6). |
|
|
|
Usage: |
|
python flexibility_score.py extract.context(.inv).sorted \ |
|
[--Inverse] [--Hierarchical] < phrasetable > output_file |
|
""" |
|
|
|
from __future__ import division |
|
from __future__ import unicode_literals |
|
|
|
import sys |
|
import gzip |
|
from collections import defaultdict |
|
|
|
|
|
class FlexScore: |
|
|
|
def __init__(self, inverted, hierarchical): |
|
self.inverted = inverted |
|
self.hierarchical = hierarchical |
|
|
|
def store_pt(self, obj): |
|
"""Store line in dictionary. |
|
|
|
If we work with inverted phrase table, swap the two phrases. |
|
""" |
|
src, target = obj[0], obj[1] |
|
|
|
if self.inverted: |
|
src, target = target, src |
|
|
|
self.phrase_pairs[src][target] = obj |
|
|
|
def update_contextcounts(self, obj): |
|
"""count the number of contexts a phrase pair occurs in""" |
|
src, target = obj[0], obj[1] |
|
self.context_counts[src][target] += 1 |
|
if obj[-1].startswith(b'<'): |
|
self.context_counts_l[src][target] += 1 |
|
elif obj[-1].startswith(b'>'): |
|
self.context_counts_r[src][target] += 1 |
|
elif obj[-1].startswith(b'v'): |
|
self.context_counts_d[src][target] += 1 |
|
else: |
|
sys.stderr.write( |
|
b"\nERROR in line: {0}\n".format(b' ||| '.join(obj))) |
|
sys.stderr.write( |
|
b"ERROR: expecting one of '<, >, v' as context marker " |
|
"in context extract file.\n") |
|
raise ValueError |
|
|
|
def traverse_incrementally(self, phrasetable, flexfile): |
|
"""Traverse phrase table and phrase extract file (with context |
|
information) incrementally without storing all in memory. |
|
""" |
|
|
|
increment = b'' |
|
old_increment = 1 |
|
stack = [''] * 2 |
|
|
|
|
|
sort_pt = 0 |
|
if self.inverted: |
|
sort_pt = 1 |
|
|
|
while old_increment != increment: |
|
|
|
old_increment = increment |
|
|
|
self.phrase_pairs = defaultdict(dict) |
|
self.context_counts = defaultdict(lambda: defaultdict(int)) |
|
self.context_counts_l = defaultdict(lambda: defaultdict(int)) |
|
self.context_counts_r = defaultdict(lambda: defaultdict(int)) |
|
self.context_counts_d = defaultdict(lambda: defaultdict(int)) |
|
|
|
if stack[0]: |
|
self.store_pt(stack[0]) |
|
stack[0] = b'' |
|
|
|
if stack[1]: |
|
self.update_contextcounts(stack[1]) |
|
stack[1] = b'' |
|
|
|
for line in phrasetable: |
|
line = line.rstrip().split(b' ||| ') |
|
if line[sort_pt] != increment: |
|
increment = line[sort_pt] |
|
stack[0] = line |
|
break |
|
else: |
|
self.store_pt(line) |
|
|
|
for line in flexfile: |
|
line = line.rstrip().split(b' ||| ') |
|
if line[0] + b' |' <= old_increment + b' |': |
|
self.update_contextcounts(line) |
|
|
|
else: |
|
stack[1] = line |
|
break |
|
|
|
yield 1 |
|
|
|
def main(self, phrasetable, flexfile, output_object): |
|
|
|
i = 0 |
|
sys.stderr.write( |
|
"Incrementally loading phrase table " |
|
"and adding flexibility score...") |
|
for block in self.traverse_incrementally(phrasetable, flexfile): |
|
|
|
self.flexprob_l = normalize(self.context_counts_l) |
|
self.flexprob_r = normalize(self.context_counts_r) |
|
self.flexprob_d = normalize(self.context_counts_d) |
|
|
|
|
|
sortkey = lambda x: x + b' |' |
|
for src in sorted(self.phrase_pairs, key=sortkey): |
|
for target in sorted(self.phrase_pairs[src], key=sortkey): |
|
|
|
if i % 1000000 == 0: |
|
sys.stderr.write('.') |
|
i += 1 |
|
|
|
outline = self.write_phrase_table(src, target) |
|
output_object.write(outline) |
|
sys.stderr.write('done\n') |
|
|
|
def write_phrase_table(self, src, target): |
|
|
|
line = self.phrase_pairs[src][target] |
|
flexscore_l = b"{0:.6g}".format(self.flexprob_l[src][target]) |
|
flexscore_r = b"{0:.6g}".format(self.flexprob_r[src][target]) |
|
line[3] += b' ' + flexscore_l + b' ' + flexscore_r |
|
|
|
if self.hierarchical: |
|
try: |
|
flexscore_d = b"{0:.6g}".format(self.flexprob_d[src][target]) |
|
except KeyError: |
|
flexscore_d = b"1" |
|
line[3] += b' ' + flexscore_d |
|
|
|
return b' ||| '.join(line) + b'\n' |
|
|
|
|
|
def normalize(d): |
|
|
|
out_dict = defaultdict(dict) |
|
|
|
for src in d: |
|
total = sum(d[src].values()) |
|
|
|
for target in d[src]: |
|
out_dict[src][target] = d[src][target] / total |
|
|
|
return out_dict |
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
if len(sys.argv) < 1: |
|
sys.stderr.write( |
|
"Usage: " |
|
"python flexibility_score.py extract.context(.inv).sorted " |
|
"[--Inverse] [--Hierarchical] < phrasetable > output_file\n") |
|
exit() |
|
|
|
flexfile = sys.argv[1] |
|
if '--Inverse' in sys.argv: |
|
inverted = True |
|
else: |
|
inverted = False |
|
|
|
if '--Hierarchical' in sys.argv: |
|
hierarchical = True |
|
else: |
|
hierarchical = False |
|
|
|
FS = FlexScore(inverted, hierarchical) |
|
FS.main(sys.stdin, gzip.open(flexfile, 'r'), sys.stdout) |
|
|