|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <algorithm> |
|
#include "StaticData.h" |
|
#include "ChartHypothesisCollection.h" |
|
#include "ChartHypothesis.h" |
|
#include "ChartManager.h" |
|
#include "HypergraphOutput.h" |
|
#include "util/exception.hh" |
|
#include "parameters/AllOptions.h" |
|
|
|
using namespace std; |
|
using namespace Moses; |
|
|
|
namespace Moses |
|
{ |
|
|
|
/** Create an empty collection configured from the decoder options.
 *  \param opts option bundle; only search.beam_width, search.stack_size and
 *              nbest.enabled are read here
 */
ChartHypothesisCollection::ChartHypothesisCollection(AllOptions const& opts)
{
  // Nothing has been added yet, so the best score seen so far is -infinity.
  m_bestScore = -std::numeric_limits<float>::infinity();

  // Settings copied from the options (the assignments are independent).
  m_nBestIsEnabled = opts.nbest.enabled;
  m_maxHypoStackSize = opts.search.stack_size;
  m_beamWidth = opts.search.beam_width;
}
|
|
|
/** Delete every hypothesis whose pointer is still stored in m_hypos. */
ChartHypothesisCollection::~ChartHypothesisCollection()
{
  HCType::iterator cur = m_hypos.begin();
  while (cur != m_hypos.end()) {
    delete *cur;
    ++cur;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** Try to add a hypothesis to the collection, discarding or recombining it as needed.
 *
 *  The hypothesis is deleted and not added when its future score is -infinity
 *  or lies below m_bestScore + m_beamWidth. Otherwise it is inserted via Add().
 *  If an equivalent entry is already stored, the two are recombined: the one
 *  with the higher future score stays in the collection. When n-best is
 *  enabled the other one is attached to the survivor through AddArc();
 *  when it is disabled the other one is deleted.
 *
 *  \param hypo    hypothesis to add; on the discard paths it is deleted here
 *  \param manager used for sentence statistics and passed on to Add()
 *  \return true only if hypo was inserted as a new entry with no equivalent
 *          present; false if it was discarded or recombined (including the
 *          case where it replaced a worse equivalent entry)
 */
bool ChartHypothesisCollection::AddHypothesis(ChartHypothesis *hypo, ChartManager &manager)
{
  if (hypo->GetFutureScore() == - std::numeric_limits<float>::infinity()) {
    manager.GetSentenceStats().AddDiscarded();
    VERBOSE(3,"discarded, -inf score" << std::endl);
    delete hypo;
    return false;
  }

  if (hypo->GetFutureScore() < m_bestScore + m_beamWidth) {
    // Really bad score: outside the beam relative to the best entry so far.
    manager.GetSentenceStats().AddDiscarded();
    VERBOSE(3,"discarded, too bad for stack" << std::endl);
    delete hypo;
    return false;
  }

  // Passed the beam check; try to insert.
  std::pair<HCType::iterator, bool> addRet = Add(hypo, manager);

  if (addRet.second) {
    // Nothing equivalent was stored: plain insertion.
    return true;
  }

  // An equivalent hypothesis exists; recombine with it.
  HCType::iterator &iterExisting = addRet.first;
  // Validate the iterator BEFORE dereferencing it; otherwise the check is useless.
  UTIL_THROW_IF2(iterExisting == m_hypos.end(),
                 "Adding a hypothesis should have returned a valid iterator");
  ChartHypothesis *hypoExisting = *iterExisting;

  if (hypo->GetFutureScore() > hypoExisting->GetFutureScore()) {
    // The incoming hypothesis is better than the stored one.
    VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
    if (m_nBestIsEnabled) {
      hypo->AddArc(hypoExisting);
      Detach(iterExisting);
    } else {
      Remove(iterExisting);
    }

    bool added = Add(hypo, manager).second;
    if (!added) {
      // Should not happen: the equivalent entry was just taken out.
      // Report the clashing entry if it can be located; never dereference end().
      HCType::iterator iterClash = m_hypos.find(hypo);
      const ChartHypothesis &offender = (iterClash != m_hypos.end()) ? **iterClash : *hypo;
      UTIL_THROW2("Offending hypo = " << offender);
    }
    return false;
  } else {
    // The stored hypothesis is at least as good; keep it.
    VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl);
    if (m_nBestIsEnabled) {
      hypoExisting->AddArc(hypo);
    } else {
      delete hypo;
    }
    return false;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/** Insert a hypothesis into m_hypos and update the bookkeeping.
 *
 *  On a successful insertion m_bestScore is raised if the new hypothesis
 *  beats it, and PruneToSize() is invoked once the collection holds more than
 *  2*m_maxHypoStackSize-1 entries.
 *
 *  \param hypo    hypothesis to insert
 *  \param manager passed on to PruneToSize()
 *  \return the result of m_hypos.insert(): iterator plus a flag that is false
 *          when an equivalent entry was already present (nothing else is done
 *          in that case)
 *  NOTE(review): the returned iterator is captured before PruneToSize() runs;
 *  confirm callers do not dereference it after a prune.
 */
pair<ChartHypothesisCollection::HCType::iterator, bool> ChartHypothesisCollection::Add(ChartHypothesis *hypo, ChartManager &manager)
{
  std::pair<HCType::iterator, bool> inserted = m_hypos.insert(hypo);

  // An equivalent entry already exists: leave everything untouched.
  if (!inserted.second) {
    return inserted;
  }

  VERBOSE(3,"added hyp to stack");

  // Track the best score seen in this collection.
  if (hypo->GetFutureScore() > m_bestScore) {
    VERBOSE(3,", best on stack");
    m_bestScore = hypo->GetFutureScore();
  }

  VERBOSE(3,", now size " << m_hypos.size());
  if (m_hypos.size() <= 2*m_maxHypoStackSize-1) {
    VERBOSE(3,std::endl);
  } else {
    // Prune only once the collection has grown to twice the limit.
    PruneToSize(manager);
  }

  return inserted;
}
|
|
|
|
|
|
|
|
|
/** Take the entry at iter out of m_hypos without deleting the hypothesis it points to.
 *  \param iter position of the entry to erase
 */
void ChartHypothesisCollection::Detach(const HCType::iterator &iter)
{
  // Only the container entry goes away; the pointee stays alive.
  this->m_hypos.erase(iter);
}
|
|
|
|
|
|
|
/** Erase the entry at iter from the collection and delete the hypothesis it pointed to.
 *  \param iter position of the entry to remove
 */
void ChartHypothesisCollection::Remove(const HCType::iterator &iter)
{
  // Grab the pointer first: iter is no longer usable once Detach() has erased it.
  ChartHypothesis *doomed = *iter;
  Detach(iter);
  delete doomed;
}
|
|
|
|
|
|
|
|
|
|
|
/** Prune the collection down towards m_maxHypoStackSize entries.
 *
 *  Does nothing when m_maxHypoStackSize is 0 or the collection is not larger
 *  than m_maxHypoStackSize. Otherwise:
 *   1. a score threshold is derived from the scores lying strictly above
 *      m_bestScore + m_beamWidth (the m_maxHypoStackSize-th best of them, or
 *      the lowest if there are fewer), and every hypothesis scoring below the
 *      threshold is removed and deleted;
 *   2. if more than 2*m_maxHypoStackSize entries still remain, the entries
 *      are sorted with ChartHypothesisScoreOrderer and everything past the
 *      first 2*m_maxHypoStackSize is removed and deleted.
 *
 *  \param manager receives one AddPruning() statistic per hypothesis removed in step 1
 */
void ChartHypothesisCollection::PruneToSize(ChartManager &manager)
{
  // A stack size of 0 disables pruning here.
  if (m_maxHypoStackSize == 0) return;

  if (GetSize() > m_maxHypoStackSize) {
    // Max-heap of the scores that lie strictly inside the beam.
    priority_queue<float> bestScores;

    // Lowest score still tolerated relative to the best entry (loop invariant).
    const float beamFloor = m_bestScore + m_beamWidth;

    HCType::iterator iter = m_hypos.begin();
    while (iter != m_hypos.end()) {
      ChartHypothesis *hypo = *iter;
      float score = hypo->GetFutureScore();
      if (score > beamFloor) {
        bestScores.push(score);
      }
      ++iter;
    }

    // If no score is strictly inside the beam the heap is empty and top()
    // would be undefined behaviour; fall back to the beam floor so that only
    // hypotheses below the beam are pruned.
    float scoreThreshold = beamFloor;
    if (!bestScores.empty()) {
      // Pop the top scores until the heap's top is the last score to keep.
      size_t minNewSizeHeapSize = m_maxHypoStackSize > bestScores.size() ? bestScores.size() : m_maxHypoStackSize;
      for (size_t i = 1 ; i < minNewSizeHeapSize ; i++)
        bestScores.pop();

      scoreThreshold = bestScores.top();
    }

    // Delete every hypothesis that scores below the threshold.
    iter = m_hypos.begin();
    while (iter != m_hypos.end()) {
      ChartHypothesis *hypo = *iter;
      float score = hypo->GetFutureScore();
      if (score < scoreThreshold) {
        // Advance before erasing so the loop iterator stays valid.
        HCType::iterator iterRemove = iter++;
        Remove(iterRemove);
        manager.GetSentenceStats().AddPruning();
      } else {
        ++iter;
      }
    }
    VERBOSE(3,", pruned to size " << m_hypos.size() << endl);

    IFVERBOSE(3) {
      TRACE_ERR("stack now contains: ");
      for(iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
        ChartHypothesis *hypo = *iter;
        TRACE_ERR( hypo->GetId() << " (" << hypo->GetFutureScore() << ") ");
      }
      TRACE_ERR( endl);
    }

    // Threshold pruning can leave too many entries (e.g. many equal scores):
    // fall back to a hard cut at twice the stack size.
    if (m_hypos.size() > m_maxHypoStackSize * 2) {
      // Sort a copy of the pointers with ChartHypothesisScoreOrderer.
      std::vector<ChartHypothesis*> hyposOrdered(m_hypos.begin(), m_hypos.end());
      std::sort(hyposOrdered.begin(), hyposOrdered.end(), ChartHypothesisScoreOrderer());

      // Delete everything beyond the first 2*m_maxHypoStackSize entries.
      std::vector<ChartHypothesis*>::iterator iterOrdered;
      for (iterOrdered = hyposOrdered.begin() + (m_maxHypoStackSize * 2); iterOrdered != hyposOrdered.end(); ++iterOrdered) {
        ChartHypothesis *hypo = *iterOrdered;
        HCType::iterator iterFindHypo = m_hypos.find(hypo);
        UTIL_THROW_IF2(iterFindHypo == m_hypos.end(),
                       "Adding a hypothesis should have returned a valid iterator");

        Remove(iterFindHypo);
      }
    }
  }
}
|
|
|
|
|
void ChartHypothesisCollection::SortHypotheses() |
|
{ |
|
UTIL_THROW_IF2(!m_hyposOrdered.empty(), "Hypotheses already sorted"); |
|
if (!m_hypos.empty()) { |
|
|
|
|
|
|
|
m_hyposOrdered.reserve(m_hypos.size()); |
|
std::copy(m_hypos.begin(), m_hypos.end(), back_inserter(m_hyposOrdered)); |
|
std::sort(m_hyposOrdered.begin(), m_hyposOrdered.end(), ChartHypothesisScoreOrderer()); |
|
} |
|
} |
|
|
|
|
|
void ChartHypothesisCollection::CleanupArcList() |
|
{ |
|
HCType::iterator iter; |
|
for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) { |
|
ChartHypothesis *mainHypo = *iter; |
|
mainHypo->CleanupArcList(); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ChartHypothesisCollection::WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned, bool> &reachable) const |
|
{ |
|
writer.WriteHypos(*this,reachable); |
|
} |
|
|
|
std::ostream& operator<<(std::ostream &out, const ChartHypothesisCollection &coll) |
|
{ |
|
HypoList::const_iterator iterInside; |
|
for (iterInside = coll.m_hyposOrdered.begin(); iterInside != coll.m_hyposOrdered.end(); ++iterInside) { |
|
const ChartHypothesis &hypo = **iterInside; |
|
out << hypo << endl; |
|
} |
|
|
|
return out; |
|
} |
|
|
|
|
|
} |
|
|