|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <deque> |
|
|
|
#include "PhraseDecoder.h" |
|
#include "moses/StaticData.h" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
|
|
/**
 * Construct a decoder bound to a compact phrase dictionary.
 *
 * @param phraseDictionary  dictionary holding the source-phrase hash and the
 *                          encoded target-phrase collections
 * @param input             factors used to build source-phrase strings
 * @param output            factors used when creating target words
 * @param numScoreComponent number of scores per target phrase (re-read from
 *                          the model file by Load())
 *
 * The symbol/score/alignment Huffman trees stay unallocated until Load().
 */
PhraseDecoder::PhraseDecoder(
  PhraseDictionaryCompact &phraseDictionary,
  const std::vector<FactorType>* input,
  const std::vector<FactorType>* output,
  size_t numScoreComponent
)
  : m_coding(None), m_numScoreComponent(numScoreComponent),
    m_containsAlignmentInfo(true), m_maxRank(0),
    m_symbolTree(0), m_multipleScoreTrees(false),
    m_scoreTrees(1), m_alignTree(0),
    m_phraseDictionary(phraseDictionary), m_input(input), m_output(output),
    m_separator(" ||| ")
{
  // m_maxPhraseLength is otherwise only written by Load(); give it a defined
  // value so GetMaxSourcePhraseLength() is well-defined before loading.
  // Assigned here rather than in the init list so it does not depend on the
  // member declaration order in the header.
  m_maxPhraseLength = 0;
}
|
|
|
/**
 * Release the Huffman trees allocated by Load().
 *
 * Deleting a null pointer is a no-op, so trees that were never loaded
 * (still 0 from the constructor) need no special casing.
 */
PhraseDecoder::~PhraseDecoder()
{
  delete m_symbolTree;

  const size_t numScoreTrees = m_scoreTrees.size();
  for(size_t t = 0; t != numScoreTrees; ++t)
    delete m_scoreTrees[t];

  delete m_alignTree;
}
|
|
|
inline unsigned PhraseDecoder::GetSourceSymbolId(std::string& symbol) |
|
{ |
|
boost::unordered_map<std::string, unsigned>::iterator it |
|
= m_sourceSymbolsMap.find(symbol); |
|
if(it != m_sourceSymbolsMap.end()) |
|
return it->second; |
|
|
|
size_t idx = m_sourceSymbols.find(symbol); |
|
m_sourceSymbolsMap[symbol] = idx; |
|
return idx; |
|
} |
|
|
|
inline std::string PhraseDecoder::GetTargetSymbol(unsigned idx) const |
|
{ |
|
if(idx < m_targetSymbols.size()) |
|
return m_targetSymbols[idx]; |
|
return std::string("##ERROR##"); |
|
} |
|
|
|
/**
 * Type of an REnc-encoded symbol, taken from its two most significant bits
 * and shifted to the range 1..4.
 */
inline size_t PhraseDecoder::GetREncType(unsigned encodedSymbol)
{
  const unsigned typeBits = encodedSymbol >> 30;
  return static_cast<size_t>(typeBits) + 1;
}
|
|
|
/**
 * Type of a PREnc-encoded symbol, taken from its most significant bit and
 * shifted to the range 1..2.
 */
inline size_t PhraseDecoder::GetPREncType(unsigned encodedSymbol)
{
  const unsigned typeBit = encodedSymbol >> 31;
  return static_cast<size_t>(typeBit) + 1;
}
|
|
|
inline unsigned PhraseDecoder::GetTranslation(unsigned srcIdx, size_t rank) |
|
{ |
|
size_t srcTrgIdx = m_lexicalTableIndex[srcIdx]; |
|
return m_lexicalTable[srcTrgIdx + rank].second; |
|
} |
|
|
|
size_t PhraseDecoder::GetMaxSourcePhraseLength() |
|
{ |
|
return m_maxPhraseLength; |
|
} |
|
|
|
/**
 * Payload of a type-1 REnc symbol: the symbol with its two type bits
 * (bits 30..31) cleared, used as a target-symbol id.
 *
 * The mask is built from an unsigned literal; shifting the signed int 3 into
 * the sign bit (`3 << 30`) is undefined behaviour before C++14.
 */
inline unsigned PhraseDecoder::DecodeREncSymbol1(unsigned encodedSymbol)
{
  return encodedSymbol & ~(3u << 30);
}
|
|
|
/**
 * Rank field of a type-2 REnc symbol: the low 24 bits (bits 24..31 cleared).
 *
 * The mask is built from an unsigned literal; `255 << 24` overflows a
 * signed int, which is undefined behaviour before C++14.
 */
inline unsigned PhraseDecoder::DecodeREncSymbol2Rank(unsigned encodedSymbol)
{
  return encodedSymbol & ~(255u << 24);
}
|
|
|
/**
 * Source-position field of a type-2 REnc symbol: bits 24..29 (type bits
 * 30..31 cleared, then shifted down), i.e. a value in 0..63.
 *
 * The mask is built from an unsigned literal; `3 << 30` shifts into the
 * sign bit of a signed int, which is undefined behaviour before C++14.
 */
inline unsigned PhraseDecoder::DecodeREncSymbol2Position(unsigned encodedSymbol)
{
  return (encodedSymbol & ~(3u << 30)) >> 24;
}
|
|
|
/**
 * Rank field of a type-3 REnc symbol: the symbol with its two type bits
 * (bits 30..31) cleared.
 *
 * The mask is built from an unsigned literal; `3 << 30` shifts into the
 * sign bit of a signed int, which is undefined behaviour before C++14.
 */
inline unsigned PhraseDecoder::DecodeREncSymbol3(unsigned encodedSymbol)
{
  return encodedSymbol & ~(3u << 30);
}
|
|
|
/**
 * Payload of a type-1 PREnc symbol: the symbol with its type bit (bit 31)
 * cleared, used as a target-symbol id.
 *
 * The mask is built from an unsigned literal; `1 << 31` shifts into the
 * sign bit of a signed int, which is undefined behaviour before C++14.
 */
inline unsigned PhraseDecoder::DecodePREncSymbol1(unsigned encodedSymbol)
{
  return encodedSymbol & ~(1u << 31);
}
|
|
|
/**
 * Left offset of a type-2 PREnc symbol: the 6-bit field at bits 25..30,
 * stored with a bias of 32, so the result lies in -32..31.
 *
 * The field is converted to int before the bias is removed, so negative
 * offsets come from signed arithmetic rather than from narrowing a
 * wrapped-around unsigned value (implementation-defined before C++20).
 */
inline int PhraseDecoder::DecodePREncSymbol2Left(unsigned encodedSymbol)
{
  return static_cast<int>((encodedSymbol >> 25) & 63u) - 32;
}
|
|
|
/**
 * Right offset of a type-2 PREnc symbol: the 6-bit field at bits 19..24,
 * stored with a bias of 32, so the result lies in -32..31.
 *
 * The field is converted to int before the bias is removed, so negative
 * offsets come from signed arithmetic rather than from narrowing a
 * wrapped-around unsigned value (implementation-defined before C++20).
 */
inline int PhraseDecoder::DecodePREncSymbol2Right(unsigned encodedSymbol)
{
  return static_cast<int>((encodedSymbol >> 19) & 63u) - 32;
}
|
|
|
/** Rank field of a type-2 PREnc symbol: its low 19 bits. */
inline unsigned PhraseDecoder::DecodePREncSymbol2Rank(unsigned encodedSymbol)
{
  const unsigned rankMask = (1u << 19) - 1;  // 0x7FFFF == 524287
  return encodedSymbol & rankMask;
}
|
|
|
size_t PhraseDecoder::Load(std::FILE* in) |
|
{ |
|
size_t start = std::ftell(in); |
|
size_t read = 0; |
|
|
|
read += std::fread(&m_coding, sizeof(m_coding), 1, in); |
|
read += std::fread(&m_numScoreComponent, sizeof(m_numScoreComponent), 1, in); |
|
read += std::fread(&m_containsAlignmentInfo, sizeof(m_containsAlignmentInfo), 1, in); |
|
read += std::fread(&m_maxRank, sizeof(m_maxRank), 1, in); |
|
read += std::fread(&m_maxPhraseLength, sizeof(m_maxPhraseLength), 1, in); |
|
|
|
if(m_coding == REnc) { |
|
m_sourceSymbols.load(in); |
|
|
|
size_t size; |
|
read += std::fread(&size, sizeof(size_t), 1, in); |
|
m_lexicalTableIndex.resize(size); |
|
read += std::fread(&m_lexicalTableIndex[0], sizeof(size_t), size, in); |
|
|
|
read += std::fread(&size, sizeof(size_t), 1, in); |
|
m_lexicalTable.resize(size); |
|
read += std::fread(&m_lexicalTable[0], sizeof(SrcTrg), size, in); |
|
} |
|
|
|
m_targetSymbols.load(in); |
|
|
|
m_symbolTree = new CanonicalHuffman<unsigned>(in); |
|
|
|
read += std::fread(&m_multipleScoreTrees, sizeof(m_multipleScoreTrees), 1, in); |
|
if(m_multipleScoreTrees) { |
|
m_scoreTrees.resize(m_numScoreComponent); |
|
for(size_t i = 0; i < m_numScoreComponent; i++) |
|
m_scoreTrees[i] = new CanonicalHuffman<float>(in); |
|
} else { |
|
m_scoreTrees.resize(1); |
|
m_scoreTrees[0] = new CanonicalHuffman<float>(in); |
|
} |
|
|
|
if(m_containsAlignmentInfo) |
|
m_alignTree = new CanonicalHuffman<AlignPoint>(in); |
|
|
|
size_t end = std::ftell(in); |
|
return end - start; |
|
} |
|
|
|
std::string PhraseDecoder::MakeSourceKey(std::string &source) |
|
{ |
|
return source + m_separator; |
|
} |
|
|
|
/**
 * Return the decoded target phrases for @p sourcePhrase, or a null pointer
 * if the phrase is not present in the dictionary's source hash.
 *
 * For PREnc, m_decodingCache is checked first:
 *   - a cached collection is returned as-is unless this is a top-level
 *     request and the entry has a non-zero "bits left" count;
 *   - in that case, the cached phrases seed a fresh collection, and decoding
 *     resumes that many bits from the end of the encoded stream.
 *
 * @param sourcePhrase phrase to look up
 * @param topLevel     false for recursive sub-phrase lookups
 * @param eval         forwarded to DecodeCollection()
 */
TargetPhraseVectorPtr PhraseDecoder::CreateTargetPhraseCollection(const Phrase &sourcePhrase, bool topLevel, bool eval)
{
  TargetPhraseVectorPtr collection(new TargetPhraseVector());
  size_t resumeBits = 0;

  if(m_coding == PREnc) {
    std::pair<TargetPhraseVectorPtr, size_t> cacheEntry
      = m_decodingCache.Retrieve(sourcePhrase);
    const TargetPhraseVectorPtr &cachedColl = cacheEntry.first;

    if(cachedColl != NULL) {
      // Fully decoded, or good enough for a non-top-level lookup.
      if(!topLevel || cacheEntry.second == 0)
        return cachedColl;

      // Partially decoded: copy what we have and resume later in the stream.
      resumeBits = cacheEntry.second;
      collection->resize(cachedColl->size());
      std::copy(cachedColl->begin(), cachedColl->end(), collection->begin());
    }
  }

  std::string sourceString = sourcePhrase.GetStringRep(*m_input);
  size_t phraseId = m_phraseDictionary.m_hash[MakeSourceKey(sourceString)];

  // An id equal to the hash size is treated as "phrase not found".
  if(phraseId == m_phraseDictionary.m_hash.GetSize())
    return TargetPhraseVectorPtr();

  std::string encoded;
  if(m_phraseDictionary.m_inMemory)
    encoded = m_phraseDictionary.m_targetPhrasesMemory[phraseId].str();
  else
    encoded = m_phraseDictionary.m_targetPhrasesMapped[phraseId].str();

  BitWrapper<> bitStream(encoded);
  if(m_coding == PREnc && resumeBits)
    bitStream.SeekFromEnd(resumeBits);

  return DecodeCollection(collection, bitStream, sourcePhrase, topLevel, eval);
}
|
|
|
/**
 * Decode target phrases from @p encodedBitStream and append them to @p tpv.
 *
 * Each encoded phrase is read by a small state machine:
 *   New       - append an empty TargetPhrase to tpv and make it current
 *   Symbol    - read symbols from m_symbolTree until the stop symbol 0
 *   Score     - read m_numScoreComponent scores (one tree per component if
 *               m_multipleScoreTrees, else a single shared tree)
 *   Alignment - only if m_containsAlignmentInfo: read alignment points until
 *               the (-1, -1) stop symbol
 *   Add       - validate and attach the alignment, optionally evaluate, then
 *               stop or continue with the next phrase
 *
 * How a symbol becomes target words depends on m_coding:
 *   REnc  type 1: plain target-symbol id
 *         type 2: rank-th lexical translation of the source word at an
 *                 explicit source position
 *         type 3: rank-th lexical translation of the source word at the
 *                 current target position
 *   PREnc type 1: plain target-symbol id
 *         type 2: copy of the rank-th target phrase of a source sub-span
 *                 (decoded recursively, or an earlier phrase of this very
 *                 collection when the span is the whole source phrase)
 *   other: every symbol is a plain target-symbol id
 *
 * For PREnc, a collection that was not pre-seeded is stored in
 * m_decodingCache together with the number of bits left after the last
 * phrase within m_maxRank (0 if 8 bits or fewer remained).
 *
 * @param tpv              collection to append to; non-empty when resuming
 *                         a partially cached collection
 * @param encodedBitStream encoded stream, positioned at the first phrase
 * @param sourcePhrase     source phrase the collection belongs to
 * @param topLevel         if false (PREnc), decoding stops once tpv holds
 *                         m_maxRank phrases
 * @param eval             run EvaluateInIsolation() on each finished phrase
 * @return tpv, or a null pointer if a position, span, rank, alignment point
 *         or referenced sub-collection is invalid.
 */
TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(
  TargetPhraseVectorPtr tpv, BitWrapper<> &encodedBitStream,
  const Phrase &sourcePhrase, bool topLevel, bool eval)
{
  // True when tpv was pre-seeded; such collections are not re-cached.
  bool extending = tpv->size();
  // Bits remaining after the last phrase eligible for caching (PREnc).
  size_t bitsLeft = encodedBitStream.TellFromEnd();

  typedef std::pair<size_t, size_t> AlignPointSizeT;

  // REnc: ids of the source words, needed to resolve lexical ranks. Kept as
  // unsigned, the type GetSourceSymbolId() returns and GetTranslation() takes.
  std::vector<unsigned> sourceWords;
  if(m_coding == REnc) {
    sourceWords.reserve(sourcePhrase.GetSize());
    for(size_t i = 0; i < sourcePhrase.GetSize(); i++) {
      std::string sourceWord
        = sourcePhrase.GetWord(i).GetString(*m_input, false);
      unsigned idx = GetSourceSymbolId(sourceWord);
      sourceWords.push_back(idx);
    }
  }

  unsigned phraseStopSymbol = 0;
  AlignPoint alignStopSymbol(-1, -1);

  std::vector<float> scores;
  std::set<AlignPointSizeT> alignment;

  enum DecodeState { New, Symbol, Score, Alignment, Add } state = New;

  size_t srcSize = sourcePhrase.GetSize();

  // Always points at the phrase currently being decoded (tpv->back()).
  TargetPhrase* targetPhrase = NULL;
  while(encodedBitStream.TellFromEnd()) {

    if(state == New) {
      tpv->push_back(TargetPhrase());
      targetPhrase = &tpv->back();

      alignment.clear();
      scores.clear();

      state = Symbol;
    }

    if(state == Symbol) {
      unsigned symbol = m_symbolTree->Read(encodedBitStream);
      if(symbol == phraseStopSymbol) {
        state = Score;
      } else {
        if(m_coding == REnc) {
          std::string wordString;
          size_t type = GetREncType(symbol);

          if(type == 1) {
            // Plain target symbol.
            unsigned decodedSymbol = DecodeREncSymbol1(symbol);
            wordString = GetTargetSymbol(decodedSymbol);
          } else if (type == 2) {
            // rank-th translation of the source word at an explicit position.
            size_t rank = DecodeREncSymbol2Rank(symbol);
            size_t srcPos = DecodeREncSymbol2Position(symbol);

            if(srcPos >= sourceWords.size())
              return TargetPhraseVectorPtr();

            wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
            if(m_phraseDictionary.m_useAlignmentInfo) {
              size_t trgPos = targetPhrase->GetSize();
              alignment.insert(AlignPoint(srcPos, trgPos));
            }
          } else if(type == 3) {
            // rank-th translation of the source word whose position equals
            // the current target length.
            size_t rank = DecodeREncSymbol3(symbol);
            size_t srcPos = targetPhrase->GetSize();

            if(srcPos >= sourceWords.size())
              return TargetPhraseVectorPtr();

            wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
            if(m_phraseDictionary.m_useAlignmentInfo) {
              size_t trgPos = srcPos;
              alignment.insert(AlignPoint(srcPos, trgPos));
            }
          }

          Word word;
          word.CreateFromString(Output, *m_output, wordString, false);
          targetPhrase->AddWord(word);
        } else if(m_coding == PREnc) {
          if(GetPREncType(symbol) == 1) {
            // Plain target symbol.
            unsigned decodedSymbol = DecodePREncSymbol1(symbol);

            Word word;
            word.CreateFromString(Output, *m_output,
                                  GetTargetSymbol(decodedSymbol), false);
            targetPhrase->AddWord(word);
          }
          else {
            // Reference to the rank-th target phrase of the source span
            // [left + current target length, srcSize - right - 1].
            int left = DecodePREncSymbol2Left(symbol);
            int right = DecodePREncSymbol2Right(symbol);
            unsigned rank = DecodePREncSymbol2Rank(symbol);

            int srcStart = left + targetPhrase->GetSize();
            int srcEnd = srcSize - right - 1;

            // Reject spans that are empty or fall outside the source phrase.
            if(0 > srcStart || srcStart > srcEnd || unsigned(srcEnd) >= srcSize)
              return TargetPhraseVectorPtr();

            // Reject ranks above m_maxRank when a limit is set.
            if(m_maxRank && rank > m_maxRank)
              return TargetPhraseVectorPtr();

            // A span covering the whole source phrase refers back into this
            // collection; any shorter span is looked up recursively.
            TargetPhraseVectorPtr subTpv = tpv;

            if(unsigned(srcEnd - srcStart + 1) != srcSize) {
              Phrase subPhrase = sourcePhrase.GetSubString(Range(srcStart, srcEnd));
              subTpv = CreateTargetPhraseCollection(subPhrase, false);
            } else {
              // Only completed phrases (all but the current last one) may be
              // referenced.
              if(rank >= tpv->size()-1)
                return TargetPhraseVectorPtr();
            }

            if(subTpv != NULL && rank < subTpv->size()) {
              // Append the referenced phrase, shifting its alignment points
              // into this phrase's source/target coordinates.
              TargetPhrase& subTp = subTpv->at(rank);
              if(m_phraseDictionary.m_useAlignmentInfo) {
                for(AlignmentInfo::const_iterator it = subTp.GetAlignTerm().begin();
                    it != subTp.GetAlignTerm().end(); ++it) {
                  alignment.insert(AlignPointSizeT(srcStart + it->first,
                                                   targetPhrase->GetSize() + it->second));
                }
              }
              targetPhrase->Append(subTp);
            } else
              return TargetPhraseVectorPtr();
          }
        } else {
          // No special coding: the symbol is a target-symbol id.
          Word word;
          word.CreateFromString(Output, *m_output,
                                GetTargetSymbol(symbol), false);
          targetPhrase->AddWord(word);
        }
      }
    } else if(state == Score) {
      size_t idx = m_multipleScoreTrees ? scores.size() : 0;
      float score = m_scoreTrees[idx]->Read(encodedBitStream);
      scores.push_back(score);

      if(scores.size() == m_numScoreComponent) {
        targetPhrase->GetScoreBreakdown().Assign(&m_phraseDictionary, scores);

        if(m_containsAlignmentInfo)
          state = Alignment;
        else
          state = Add;
      }
    } else if(state == Alignment) {
      AlignPoint alignPoint = m_alignTree->Read(encodedBitStream);
      if(alignPoint == alignStopSymbol) {
        state = Add;
      } else {
        if(m_phraseDictionary.m_useAlignmentInfo)
          alignment.insert(AlignPointSizeT(alignPoint));
      }
    }

    if(state == Add) {
      if(m_phraseDictionary.m_useAlignmentInfo) {
        // Discard the whole collection if any alignment point lies outside
        // the source or target phrase.
        size_t sourceSize = sourcePhrase.GetSize();
        size_t targetSize = targetPhrase->GetSize();
        for(std::set<AlignPointSizeT>::iterator it = alignment.begin(); it != alignment.end(); ++it) {
          if(it->first >= sourceSize || it->second >= targetSize)
            return TargetPhraseVectorPtr();
        }
        targetPhrase->SetAlignTerm(alignment);
      }

      if(eval) {
        targetPhrase->EvaluateInIsolation(sourcePhrase, m_phraseDictionary.GetFeaturesToApply());
      }

      if(m_coding == PREnc) {
        // Record the resume point after the last phrase within m_maxRank;
        // CreateTargetPhraseCollection() seeks to it on a partial cache hit.
        if(!m_maxRank || tpv->size() <= m_maxRank)
          bitsLeft = encodedBitStream.TellFromEnd();

        // Non-top-level lookups stop after m_maxRank phrases.
        if(!topLevel && m_maxRank && tpv->size() >= m_maxRank)
          break;
      }

      // 8 bits or fewer remaining are treated as end of data (presumably
      // byte-alignment padding — TODO confirm against the encoder).
      if(encodedBitStream.TellFromEnd() <= 8)
        break;

      state = New;
    }
  }

  if(m_coding == PREnc && !extending) {
    // A remainder of 8 bits or fewer is stored as 0 ("fully decoded").
    bitsLeft = bitsLeft > 8 ? bitsLeft : 0;
    m_decodingCache.Cache(sourcePhrase, tpv, bitsLeft, m_maxRank);
  }

  return tpv;
}
|
|
|
void PhraseDecoder::PruneCache() |
|
{ |
|
m_decodingCache.Prune(); |
|
} |
|
|
|
} |
|
|