|
#include "search/vertex.hh" |
|
|
|
#include "search/context.hh" |
|
|
|
#include <boost/unordered_map.hpp> |
|
|
|
#include <algorithm> |
|
#include <functional> |
|
#include <cassert> |
|
|
|
namespace search { |
|
|
|
namespace { |
|
|
|
// Largest representable uint64_t. When a divider's index runs past a state's
// shared context it returns kCompleteAdd - state.left.full instead of a real
// word/pointer, so states that end early are grouped together (and separated
// by their full flag) rather than colliding with ordinary keys.
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
|
|
|
class DivideLeft { |
|
public: |
|
explicit DivideLeft(unsigned char index) |
|
: index_(index) {} |
|
|
|
uint64_t operator()(const lm::ngram::ChartState &state) const { |
|
return (index_ < state.left.length) ? |
|
state.left.pointers[index_] : |
|
(kCompleteAdd - state.left.full); |
|
} |
|
|
|
private: |
|
unsigned char index_; |
|
}; |
|
|
|
class DivideRight { |
|
public: |
|
explicit DivideRight(unsigned char index) |
|
: index_(index) {} |
|
|
|
uint64_t operator()(const lm::ngram::ChartState &state) const { |
|
return (index_ < state.right.length) ? |
|
static_cast<uint64_t>(state.right.words[index_]) : |
|
(kCompleteAdd - state.left.full); |
|
} |
|
|
|
private: |
|
unsigned char index_; |
|
}; |
|
|
|
template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { |
|
|
|
typedef boost::unordered_map<uint64_t, std::size_t> Lookup; |
|
Lookup lookup; |
|
for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { |
|
uint64_t key = divider(i->state); |
|
std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); |
|
if (res.second) { |
|
extend.resize(extend.size() + 1); |
|
extend.back().AppendHypothesis(*i); |
|
} else { |
|
extend[res.first->second].AppendHypothesis(*i); |
|
} |
|
} |
|
|
|
} |
|
|
|
// Value compared by DetermineSame for a right-side state: the word id stored
// at position `index`.
lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
  const lm::WordIndex word = right.words[index];
  return word;
}
|
|
|
// Value compared by DetermineSame for a left-side state: the pointer stored
// at position `index`.
uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
  const uint64_t pointer = left.pointers[index];
  return pointer;
}
|
|
|
// Measures how much of a reference state side (lm::ngram::Left or Right) is
// shared by every other state passed to Consider().
//
// Positions below `guaranteed` are assumed to match already and are never
// compared. Shared() is the length of the common prefix found so far.
// Complete() stays true only while every considered state has exactly the
// reference length and agrees with it at every compared position.
//
// Keeps a reference to `side`; the referenced object must outlive this one.
template <class Side> class DetermineSame {
  public:
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    // Narrows the shared prefix to what `other` also contains.
    void Consider(const Side &other) {
      if (other.length != shared_) {
        complete_ = false;
        if (other.length < shared_) shared_ = other.length;
      }
      for (unsigned char pos = guaranteed_; pos < shared_; ++pos) {
        if (Identify(side_, pos) == Identify(other, pos)) continue;
        // First disagreement ends the common prefix.
        shared_ = pos;
        complete_ = false;
        return;
      }
    }

    unsigned char Shared() const { return shared_; }

    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
|
|
|
|
|
|
|
// Branching policies stored in VertexNode::policy_ and consulted by
// VertexNode::BuildExtend to pick which side of the state to split on.

// Split on whichever side has the shorter shared context (left on ties).
const unsigned char kPolicyAlternate = 0;
// Always split on the left state.
const unsigned char kPolicyOneLeft = 1;
// Always split on the right state.
const unsigned char kPolicyOneRight = 2;
|
|
|
|
|
|
|
} |
|
|
|
namespace { |
|
struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { |
|
bool operator()(const HypoState &first, const HypoState &second) const { |
|
return first.score > second.score; |
|
} |
|
}; |
|
} |
|
|
|
void VertexNode::FinishRoot() { |
|
std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); |
|
extend_.clear(); |
|
|
|
state_.left.full = false; |
|
state_.left.length = 0; |
|
state_.right.length = 0; |
|
right_full_ = false; |
|
niceness_ = 0; |
|
policy_ = kPolicyAlternate; |
|
if (hypos_.size() == 1) { |
|
extend_.resize(1); |
|
extend_.front().AppendHypothesis(hypos_.front()); |
|
extend_.front().FinishedAppending(0, 0); |
|
} |
|
if (hypos_.empty()) { |
|
bound_ = -INFINITY; |
|
} else { |
|
bound_ = hypos_.front().score; |
|
} |
|
} |
|
|
|
void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { |
|
assert(!hypos_.empty()); |
|
assert(extend_.empty()); |
|
bound_ = hypos_.front().score; |
|
state_ = hypos_.front().state; |
|
bool all_full = state_.left.full; |
|
bool all_non_full = !state_.left.full; |
|
DetermineSame<lm::ngram::Left> left(state_.left, common_left); |
|
DetermineSame<lm::ngram::Right> right(state_.right, common_right); |
|
for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { |
|
all_full &= i->state.left.full; |
|
all_non_full &= !i->state.left.full; |
|
left.Consider(i->state.left); |
|
right.Consider(i->state.right); |
|
} |
|
state_.left.full = all_full && left.Complete(); |
|
right_full_ = all_full && right.Complete(); |
|
state_.left.length = left.Shared(); |
|
state_.right.length = right.Shared(); |
|
|
|
if (!all_full && !all_non_full) { |
|
policy_ = kPolicyAlternate; |
|
} else if (left.Complete()) { |
|
policy_ = kPolicyOneRight; |
|
} else if (right.Complete()) { |
|
policy_ = kPolicyOneLeft; |
|
} else { |
|
policy_ = kPolicyAlternate; |
|
} |
|
niceness_ = state_.left.length + state_.right.length; |
|
} |
|
|
|
void VertexNode::BuildExtend() { |
|
|
|
if (!extend_.empty()) return; |
|
|
|
if (hypos_.size() <= 1) return; |
|
bool left_branch = true; |
|
switch (policy_) { |
|
case kPolicyAlternate: |
|
left_branch = (state_.left.length <= state_.right.length); |
|
break; |
|
case kPolicyOneLeft: |
|
left_branch = true; |
|
break; |
|
case kPolicyOneRight: |
|
left_branch = false; |
|
break; |
|
} |
|
if (left_branch) { |
|
Split(DivideLeft(state_.left.length), hypos_, extend_); |
|
} else { |
|
Split(DivideRight(state_.right.length), hypos_, extend_); |
|
} |
|
for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { |
|
|
|
i->FinishedAppending(state_.left.length, state_.right.length); |
|
} |
|
} |
|
|
|
} |
|
|