|
|
|
#include "ProbingPT.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/FactorCollection.h" |
|
#include "moses/TargetPhraseCollection.h" |
|
#include "moses/InputFileStream.h" |
|
#include "probingpt/querying.h" |
|
#include "probingpt/probing_hash_utils.h" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
ProbingPT::ProbingPT(const std::string &line) |
|
: PhraseDictionary(line,true) |
|
,m_engine(NULL) |
|
,load_method(util::POPULATE_OR_READ) |
|
{ |
|
ReadParameters(); |
|
|
|
assert(m_input.size() == 1); |
|
assert(m_output.size() == 1); |
|
} |
|
|
|
ProbingPT::~ProbingPT() |
|
{ |
|
delete m_engine; |
|
} |
|
|
|
void ProbingPT::Load(AllOptions::ptr const& opts) |
|
{ |
|
m_options = opts; |
|
SetFeaturesToApply(); |
|
|
|
m_engine = new probingpt::QueryEngine(m_filePath.c_str(), load_method); |
|
|
|
m_unkId = 456456546456; |
|
|
|
FactorCollection &vocab = FactorCollection::Instance(); |
|
|
|
|
|
const std::map<uint64_t, std::string> &sourceVocab = |
|
m_engine->getSourceVocab(); |
|
std::map<uint64_t, std::string>::const_iterator iterSource; |
|
for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); |
|
++iterSource) { |
|
string wordStr = iterSource->second; |
|
|
|
|
|
const Factor *factor = vocab.AddFactor(wordStr); |
|
|
|
uint64_t probingId = iterSource->first; |
|
size_t factorId = factor->GetId(); |
|
|
|
if (factorId >= m_sourceVocab.size()) { |
|
m_sourceVocab.resize(factorId + 1, m_unkId); |
|
} |
|
m_sourceVocab[factorId] = probingId; |
|
} |
|
|
|
|
|
InputFileStream targetVocabStrme(m_filePath + "/TargetVocab.dat"); |
|
string line; |
|
while (getline(targetVocabStrme, line)) { |
|
vector<string> toks = Tokenize(line, "\t"); |
|
UTIL_THROW_IF2(toks.size() != 2, string("Incorrect format:") + line + "\n"); |
|
|
|
|
|
|
|
const Factor *factor = vocab.AddFactor(toks[0]); |
|
uint32_t probingId = Scan<uint32_t>(toks[1]); |
|
|
|
if (probingId >= m_targetVocab.size()) { |
|
m_targetVocab.resize(probingId + 1); |
|
} |
|
|
|
m_targetVocab[probingId] = factor; |
|
} |
|
|
|
|
|
CreateAlignmentMap(m_filePath + "/Alignments.dat"); |
|
|
|
|
|
string filePath = m_filePath + "/TargetColl.dat"; |
|
file.open(filePath.c_str()); |
|
if (!file.is_open()) { |
|
throw "Couldn't open file "; |
|
} |
|
|
|
data = file.data(); |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
void ProbingPT::CreateAlignmentMap(const std::string path) |
|
{ |
|
const std::vector< std::vector<unsigned char> > &probingAlignColl = m_engine->getAlignments(); |
|
m_aligns.resize(probingAlignColl.size(), NULL); |
|
|
|
for (size_t i = 0; i < probingAlignColl.size(); ++i) { |
|
AlignmentInfo::CollType aligns; |
|
|
|
const std::vector<unsigned char> &probingAligns = probingAlignColl[i]; |
|
for (size_t j = 0; j < probingAligns.size(); j += 2) { |
|
size_t startPos = probingAligns[j]; |
|
size_t endPos = probingAligns[j+1]; |
|
|
|
aligns.insert(std::pair<size_t,size_t>(startPos, endPos)); |
|
} |
|
|
|
const AlignmentInfo *align = AlignmentInfoCollection::Instance().Add(aligns); |
|
m_aligns[i] = align; |
|
|
|
} |
|
} |
|
|
|
void ProbingPT::SetParameter(const std::string& key, const std::string& value) |
|
{ |
|
if (key == "load") { |
|
if (value == "lazy") { |
|
load_method = util::LAZY; |
|
} else if (value == "populate_or_lazy") { |
|
load_method = util::POPULATE_OR_LAZY; |
|
} else if (value == "populate_or_read" || value == "populate") { |
|
load_method = util::POPULATE_OR_READ; |
|
} else if (value == "read") { |
|
load_method = util::READ; |
|
} else if (value == "parallel_read") { |
|
load_method = util::PARALLEL_READ; |
|
} else { |
|
UTIL_THROW2("load method not supported" << value); |
|
} |
|
} else { |
|
PhraseDictionary::SetParameter(key, value); |
|
} |
|
|
|
} |
|
|
|
void ProbingPT::InitializeForInput(ttasksptr const& ttask) |
|
{ |
|
|
|
} |
|
|
|
void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const |
|
{ |
|
InputPathList::const_iterator iter; |
|
for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter) { |
|
InputPath &inputPath = **iter; |
|
const Phrase &sourcePhrase = inputPath.GetPhrase(); |
|
|
|
if (sourcePhrase.GetSize() > m_options->search.max_phrase_length) { |
|
continue; |
|
} |
|
|
|
TargetPhraseCollection::shared_ptr tpColl = CreateTargetPhrase(sourcePhrase); |
|
inputPath.SetTargetPhrases(*this, tpColl, NULL); |
|
} |
|
} |
|
|
|
TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const |
|
{ |
|
|
|
assert(sourcePhrase.GetSize()); |
|
|
|
std::pair<bool, uint64_t> keyStruct = GetKey(sourcePhrase); |
|
if (!keyStruct.first) { |
|
return TargetPhraseCollection::shared_ptr(); |
|
} |
|
|
|
|
|
CachePb::const_iterator iter = m_cachePb.find(keyStruct.second); |
|
if (iter != m_cachePb.end()) { |
|
|
|
TargetPhraseCollection *tps = iter->second; |
|
return TargetPhraseCollection::shared_ptr(tps); |
|
} |
|
|
|
|
|
TargetPhraseCollection *tps = CreateTargetPhrases(sourcePhrase, |
|
keyStruct.second); |
|
return TargetPhraseCollection::shared_ptr(tps); |
|
} |
|
|
|
std::pair<bool, uint64_t> ProbingPT::GetKey(const Phrase &sourcePhrase) const |
|
{ |
|
std::pair<bool, uint64_t> ret; |
|
|
|
|
|
size_t sourceSize = sourcePhrase.GetSize(); |
|
assert(sourceSize); |
|
|
|
uint64_t probingSource[sourceSize]; |
|
GetSourceProbingIds(sourcePhrase, ret.first, probingSource); |
|
if (!ret.first) { |
|
|
|
|
|
} else { |
|
ret.second = m_engine->getKey(probingSource, sourceSize); |
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
void ProbingPT::GetSourceProbingIds(const Phrase &sourcePhrase, |
|
bool &ok, uint64_t probingSource[]) const |
|
{ |
|
|
|
size_t size = sourcePhrase.GetSize(); |
|
for (size_t i = 0; i < size; ++i) { |
|
const Word &word = sourcePhrase.GetWord(i); |
|
uint64_t probingId = GetSourceProbingId(word); |
|
if (probingId == m_unkId) { |
|
ok = false; |
|
return; |
|
} else { |
|
probingSource[i] = probingId; |
|
} |
|
} |
|
|
|
ok = true; |
|
} |
|
|
|
uint64_t ProbingPT::GetSourceProbingId(const Word &word) const |
|
{ |
|
uint64_t ret = 0; |
|
|
|
for (size_t i = 0; i < m_input.size(); ++i) { |
|
FactorType factorType = m_input[i]; |
|
const Factor *factor = word[factorType]; |
|
|
|
size_t factorId = factor->GetId(); |
|
if (factorId >= m_sourceVocab.size()) { |
|
return m_unkId; |
|
} |
|
ret += m_sourceVocab[factorId]; |
|
} |
|
|
|
return ret; |
|
} |
|
|
|
TargetPhraseCollection *ProbingPT::CreateTargetPhrases( |
|
const Phrase &sourcePhrase, uint64_t key) const |
|
{ |
|
TargetPhraseCollection *tps = NULL; |
|
|
|
|
|
std::pair<bool, uint64_t> query_result; |
|
query_result = m_engine->query(key); |
|
|
|
|
|
if (query_result.first) { |
|
const char *offset = data + query_result.second; |
|
uint64_t *numTP = (uint64_t*) offset; |
|
|
|
tps = new TargetPhraseCollection(); |
|
|
|
offset += sizeof(uint64_t); |
|
for (size_t i = 0; i < *numTP; ++i) { |
|
TargetPhrase *tp = CreateTargetPhrase(offset); |
|
assert(tp); |
|
tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply()); |
|
|
|
tps->Add(tp); |
|
|
|
} |
|
|
|
tps->Prune(true, m_tableLimit); |
|
|
|
} |
|
|
|
return tps; |
|
|
|
} |
|
|
|
TargetPhrase *ProbingPT::CreateTargetPhrase( |
|
const char *&offset) const |
|
{ |
|
probingpt::TargetPhraseInfo *tpInfo = (probingpt::TargetPhraseInfo*) offset; |
|
size_t numRealWords = tpInfo->numWords / m_output.size(); |
|
|
|
TargetPhrase *tp = new TargetPhrase(this); |
|
|
|
offset += sizeof(probingpt::TargetPhraseInfo); |
|
|
|
|
|
float *scores = (float*) offset; |
|
|
|
size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores; |
|
|
|
if (m_engine->logProb) { |
|
|
|
tp->GetScoreBreakdown().PlusEquals(this, scores); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} else { |
|
|
|
float logScores[totalNumScores]; |
|
for (size_t i = 0; i < totalNumScores; ++i) { |
|
logScores[i] = FloorScore(TransformScore(scores[i])); |
|
} |
|
|
|
|
|
tp->GetScoreBreakdown().PlusEquals(this, logScores); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
offset += sizeof(float) * totalNumScores; |
|
|
|
|
|
for (size_t targetPos = 0; targetPos < numRealWords; ++targetPos) { |
|
Word &word = tp->AddWord(); |
|
for (size_t i = 0; i < m_output.size(); ++i) { |
|
FactorType factorType = m_output[i]; |
|
|
|
uint32_t *probingId = (uint32_t*) offset; |
|
|
|
const Factor *factor = GetTargetFactor(*probingId); |
|
assert(factor); |
|
|
|
word[factorType] = factor; |
|
|
|
offset += sizeof(uint32_t); |
|
} |
|
} |
|
|
|
|
|
uint32_t alignTerm = tpInfo->alignTerm; |
|
|
|
UTIL_THROW_IF2(alignTerm >= m_aligns.size(), "Unknown alignInd"); |
|
tp->SetAlignTerm(m_aligns[alignTerm]); |
|
|
|
|
|
|
|
return tp; |
|
} |
|
|
|
|
|
|
|
|
|
ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager( |
|
const ChartParser &, |
|
const ChartCellCollectionBase &, |
|
std::size_t) |
|
{ |
|
abort(); |
|
return NULL; |
|
} |
|
|
|
TO_STRING_BODY(ProbingPT); |
|
|
|
|
|
ostream& operator<<(ostream& out, const ProbingPT& phraseDict) |
|
{ |
|
return out; |
|
} |
|
|
|
} |
|
|