File size: 7,087 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python

import gzip
import os
import re
import numpy as np
import sys

from bleu import BleuScorer
from coll import OrderedDict
# Edit to set moses python path
sys.path.append(os.path.dirname(__file__) + "/../python")
import moses.dictree as binpt

class DataFormatException(Exception):
  """Raised when an input file (nbest list, score data, ...) is malformed."""
  pass

class Hypothesis:
  """A single translation hypothesis from an nbest list.

  Stores the hypothesis text, its feature vector (as a numpy array) and,
  when segmentation markup is present, the phrase alignment.
  """

  # Matches decoder segmentation markers such as |0-3|.
  # Raw string avoids the invalid "\d" escape warning; compiled once per class,
  # not once per hypothesis.
  _align_re = re.compile(r"\|(\d+)-(\d+)\|")

  def __init__(self, text, fv, segments=False):
    """text: hypothesis string, possibly containing |start-end| markers.
    fv: iterable of feature values.
    segments: if True, parse segmentation markers out of text.
    Raises DataFormatException when segments is True but no markers are found.
    """
    self.alignment = [] # only stored for segmented hypos
    self.tokens = []    # only stored for segmented hypos
    if not segments:
      self.text = text
    else:
      # Recover segmentation: triples of (source-start, source-end, target-end)
      # where segment end positions are 1 beyond the last token.
      for token in text.split():
        match = self._align_re.match(token)
        if match:
          self.alignment.append(
            (int(match.group(1)), 1 + int(match.group(2)), len(self.tokens)))
        else:
          self.tokens.append(token)
      self.text = " ".join(self.tokens)
      if not self.alignment:
        raise DataFormatException("Expected segmentation information not found in nbest")

    self.fv = np.array(fv)
    self.score = 0

  def __str__(self):
    return "{text=%s fv=%s score=%5.4f}" % (self.text, str(self.fv), self.score)

class NBestList:
  """The list of hypotheses produced for a single source sentence."""
  def __init__(self,id):
    # id: sentence identifier (the first ||| field of the nbest file)
    self.id = id
    self.hyps = []  # Hypothesis objects, in nbest-file order

# Registry mapping feature names to [start, end) spans within the feature vector
_feature_index = {}

def set_feature_start(name, index):
  """Record where the feature block *name* begins in the feature vector."""
  if name in _feature_index:
    _feature_index[name][0] = index
  else:
    _feature_index[name] = [index, 0]

def set_feature_end(name, index):
  """Record the position one past the last value of feature block *name*."""
  if name in _feature_index:
    _feature_index[name][1] = index
  else:
    _feature_index[name] = [0, index]

def get_feature_index(name):
  """Return the [start, end) span recorded for *name*, or [0, 0] if unseen."""
  try:
    return _feature_index[name]
  except KeyError:
    return [0, 0]

def get_nbests(nbest_file, segments=False):
  """Iterate through an nbest file, yielding one NBestList per sentence.

  nbest_file: moses nbest file (plain or gzipped); fields separated by ' ||| '.
  segments: passed through to Hypothesis to recover phrase segmentation.
  Raises DataFormatException on malformed lines.
  Side effect: records feature name spans via set_feature_start/set_feature_end.
  """
  if nbest_file.endswith("gz"):
    fh = gzip.GzipFile(nbest_file)
  else:
    fh = open(nbest_file)
  try:
    lineno = 0
    nbest = None
    for line in fh:
      lineno += 1  # was never incremented, so errors always reported line 0
      fields = line.split(" ||| ")
      if len(fields) != 4:
        raise DataFormatException("nbest(%d): %s" % (lineno,line))
      (sent_id, text, scores, total) = fields
      if nbest and nbest.id != sent_id:
        # New sentence starts: emit the previous one
        yield nbest
        nbest = None
      if not nbest:
        nbest = NBestList(sent_id)
      fv = []
      score_name = None
      for score in scores.split():
        if score.endswith(":"):
          # A token like "tm:" names the feature block that follows
          score = score[:-1]
          if score_name:
            set_feature_end(score_name,len(fv))
          score_name = score
          set_feature_start(score_name,len(fv))
        else:
          fv.append(float(score))
      if score_name: set_feature_end(score_name,len(fv))
      # text[:-1] strips the trailing space left by the ||| split
      hyp = Hypothesis(text[:-1],fv,segments)
      nbest.hyps.append(hyp)
    if nbest:
      yield nbest
  finally:
    fh.close()  # close even if the consumer abandons the generator

def get_scores(score_data_file):
  """Iterate through the score data, yielding a list of scores per sentence.

  The file contains blocks delimited by SCORES_TXT_BEGIN / SCORES_TXT_END
  lines; each line inside a block is a whitespace-separated score vector,
  converted to a single score by BleuScorer.score().
  """
  scorer = BleuScorer()
  fh = open(score_data_file)
  try:
    score_vectors = None
    for line in fh:
      if line.startswith("SCORES_TXT_BEGIN"):
        score_vectors = []
      elif line.startswith("SCORES_TXT_END"):
        scores = [scorer.score(score_vector) for score_vector in score_vectors]
        yield scores
      else:
        # line[:-1] strips the trailing newline
        score_vectors.append([float(i) for i in line[:-1].split()])
  finally:
    fh.close()  # was leaked; close even if the consumer abandons the generator
  

def get_scored_nbests(nbest_file, score_data_file, input_file, segments=False):
  """Iterate nbest lists, attaching a score (and optionally the source input
  line) to each hypothesis.

  nbest_file, score_data_file: parallel nbest and score-data files.
  input_file: optional source-sentence file; may be None/empty.
  Raises DataFormatException if the score file runs out before the nbest file,
  or if a score list's length does not match its nbest list.
  """
  score_gen = get_scores(score_data_file)
  input_gen = None
  if input_file: input_gen = open(input_file)
  try:
    try:
      for nbest in get_nbests(nbest_file, segments=segments):
        # builtin next() works on python 2.6+ and 3 (was py2-only .next())
        scores = next(score_gen)
        if len(scores) != len(nbest.hyps):
          raise DataFormatException("Length of nbest %s does not match score list (%d != %d)" %
            (nbest.id,len(nbest.hyps), len(scores)))
        input_line = None
        if input_gen:
          input_line = next(input_gen)[:-1]
        for hyp,score in zip(nbest.hyps, scores):
          hyp.score = score
          hyp.input_line = input_line
        yield nbest
    except StopIteration:
      raise DataFormatException("Score file shorter than nbest list file")
  finally:
    if input_gen: input_gen.close()  # was leaked

class PhraseCache:
  """An LRU cache for ttable lookups, keyed on (source, target) phrase pairs.

  Backed by an OrderedDict whose insertion order doubles as access recency:
  oldest entries sit at the front and are evicted first.
  """
  def __init__(self, max_size):
    # max_size: maximum number of phrase pairs to retain
    self.max_size = max_size
    self.pairs_to_scores = OrderedDict()

  def get(self, source, target):
    """Return the cached scores for (source, target), or None on a miss."""
    key = (source,target)
    scores = self.pairs_to_scores.get(key,None)
    # `is not None` (was a truthiness test): a falsy cached value such as []
    # is still a hit and must refresh its recency
    if scores is not None:
      # cache hit - update access time by reinserting at the back
      del self.pairs_to_scores[key]
      self.pairs_to_scores[key] = scores
    return scores

  def add(self,source,target,scores):
    """Insert (or refresh) an entry, evicting least-recently-used overflow."""
    key = (source,target)
    self.pairs_to_scores[key] = scores
    while len(self.pairs_to_scores) > self.max_size:
      self.pairs_to_scores.popitem(last=False)

# 
# Should I store full lists of options, or just phrase pairs?
# Should probably store phrase-pairs, but may want to add
# high scoring pairs (say, 20?) when I load the translations
# of a given phrase
#

class CachedPhraseTable:
  """A moses binary phrase table fronted by an LRU cache of phrase-pair scores."""
  def __init__(self,ttable_file,nscores=5,cache_size=20000):
    # ttable_file: path to a moses binary phrase table
    # nscores: number of scores per ttable entry
    # cache_size: maximum number of phrase pairs held in the LRU cache
    wa = False
    if binpt.PhraseDictionaryTree.canLoad(ttable_file,True):
      # assume word alignment is included
      wa = True
    self.ttable = binpt.PhraseDictionaryTree(ttable_file,nscores = nscores,wa = wa, tableLimit=0)
    self.cache = PhraseCache(cache_size)
    self.nscores = nscores

  def get_scores(self,phrase):
    """Look up the ttable scores for one phrase pair.

    phrase: a (source_tokens, target_tokens) pair of token sequences.
    Returns the matching entry's scores with the last one dropped (the
    original comment calls it the penalty), or a zero vector of length
    nscores-1 when the pair is not in the ttable. Results are cached.
    NOTE(review): a cached all-zero result is falsy, so `if not scores:`
    re-queries the ttable for it on every call — confirm if intended.
    """
    source = " ".join(phrase[0])
    target_tuple = tuple(phrase[1])
    target = " ".join(target_tuple)
    scores = self.cache.get(source,target)
    if not scores:
      # cache miss
      scores = [0] * (self.nscores-1) # ignore penalty
      entries = self.ttable.query(source, converter=None)
      # find correct target
      for entry in entries:
        if entry.rhs  == target_tuple:
          scores = entry.scores[:-1]
          break
      #print "QUERY",source,"|||",target,"|||",scores
      self.cache.add(source,target,scores)
    #else:
    #  print "CACHE",source,"|||",target,"|||",scores
    return scores
 

class MosesPhraseScorer:
  """Scores a hypothesis's phrase pairs against one or more phrase tables."""
  def __init__(self,ttable_files, cache_size=20000):
    # ttable_files: iterable of binary ttable paths; one CachedPhraseTable each
    self.ttables = []
    for ttable_file in ttable_files:
      self.ttables.append(CachedPhraseTable(ttable_file, cache_size=cache_size))

  def add_scores(self, hyp):
    """Add the phrase scores to a hypothesis.

    Reads hyp.input_line (source sentence), hyp.alignment and hyp.tokens
    (from segmented nbest parsing). Sets hyp.phrase_scores to a numpy array
    of per-ttable, per-phrase score lists, floored at exp(-100).
    Raises DataFormatException if the hypothesis has no alignment.
    """
    # Collect up the phrase pairs
    phrases = []
    source_tokens = hyp.input_line.split()
    tgt_st = 0
    if not hyp.alignment:
      raise DataFormatException("Alignments missing from: " + str(hyp))
    for src_st,src_end,tgt_end in hyp.alignment:
      # Target phrases are contiguous, so each starts where the previous ended
      phrases.append((source_tokens[src_st:src_end], hyp.tokens[tgt_st:tgt_end]))
      tgt_st = tgt_end
    # Look up the scores in every ttable
    phrase_scores = [[ttable.get_scores(phrase) for phrase in phrases]
                     for ttable in self.ttables]
    # Clip zeros/tiny values up to exp(-100) (replaces the earlier
    # element-masking approach, which survives only as dead code upstream) —
    # presumably to protect a later log; confirm against callers.
    floor = np.exp(-100)
    phrase_scores = np.clip(np.array(phrase_scores), floor, np.inf)
    hyp.phrase_scores = phrase_scores