|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "PhraseNode.h"

#include <cstring>
#include <vector>

#include "OnDiskWrapper.h"

#include "TargetPhraseCollection.h"

#include "SourcePhrase.h"

#include "moses/Util.h"

#include "util/exception.hh"
|
|
|
using namespace std; |
|
|
|
namespace OnDiskPt |
|
{ |
|
|
|
/// Number of bytes occupied by one serialized phrase node.
///
/// @param numChildren number of child entries stored in the node
/// @param wordSize    bytes used by one serialized Word
/// @param countSize   number of float counts stored in the node
/// @return size of the fixed header (two uint64_t), the counts, and all child
///         entries (each a word followed by a uint64_t file position)
size_t PhraseNode::GetNodeSize(size_t numChildren, size_t wordSize, size_t countSize)
{
  const size_t headerBytes = 2 * sizeof(uint64_t);
  const size_t perChildBytes = wordSize + sizeof(uint64_t);
  const size_t childBytes = numChildren * perChildBytes;
  const size_t countBytes = countSize * sizeof(float);

  return headerBytes + childBytes + countBytes;
}
|
|
|
/// Creates an empty in-memory node, used while building the prefix tree of
/// source words (see AddTargetPhrase / Save).
///
/// Every scalar member gets a defined value. GetChild() reads
/// m_numChildrenLoad and operator<< reads m_filePos / m_pos, so leaving them
/// uninitialized would be undefined behaviour for a node that was never
/// loaded from or saved to disk.
PhraseNode::PhraseNode()
  : m_value(0)
  ,m_currChild(NULL)
  ,m_saved(false)
  ,m_memLoad(NULL)
{
  // Assigned here rather than in the init-list so their declaration order
  // in the header does not matter.
  m_filePos = 0;
  m_memLoadLast = NULL;
  m_numChildrenLoad = 0;
  m_pos = 0;
}
|
|
|
/// Loads a phrase node record from the source file of @p onDiskWrapper.
///
/// The record is read in two steps. First the leading uint64_t (the number
/// of children) is read to size the buffer. Then the whole record, header
/// included, is read into m_memLoad, which this node owns and frees in its
/// destructor.
///
/// Record layout (as written by Save):
///   [uint64_t numChildren][uint64_t value][float counts x GetNumCounts()]
///   [child entries: Word + uint64_t file position] x numChildren
///
/// @param filePos       byte offset of the record in the source file
/// @param onDiskWrapper provides the source file stream and record geometry
/// @throws util::Exception if the record cannot be read or the buffer cannot
///         be allocated
PhraseNode::PhraseNode(uint64_t filePos, OnDiskWrapper &onDiskWrapper)
  :m_counts(onDiskWrapper.GetNumCounts())
{
  // Members that are not loaded from disk get defined values, so that later
  // reads (operator<<, Save, AddTargetPhrase) are well defined.
  m_filePos = filePos;
  m_currChild = NULL;
  m_saved = false;
  m_pos = 0;
  m_memLoad = NULL;
  m_memLoadLast = NULL;

  const size_t countSize = onDiskWrapper.GetNumCounts();

  std::fstream &file = onDiskWrapper.GetFileSource();
  file.seekg(filePos);
  assert(filePos == (uint64_t)file.tellg());

  // The child count determines the size of the whole record.
  file.read((char*) &m_numChildrenLoad, sizeof(uint64_t));
  UTIL_THROW_IF2(file.fail(), "Could not read phrase node at file position " << filePos);

  const size_t memAlloc = GetNodeSize(m_numChildrenLoad, onDiskWrapper.GetSourceWordSize(), countSize);
  m_memLoad = (char*) malloc(memAlloc);
  UTIL_THROW_IF2(m_memLoad == NULL, "Could not allocate " << memAlloc << " bytes for phrase node");

  // Re-read the entire record, including the child count, into memory.
  file.seekg(filePos);
  assert(filePos == (uint64_t)file.tellg());
  file.read(m_memLoad, memAlloc);

  const bool readFailed = file.fail();
  if (readFailed) {
    // The destructor does not run when a constructor throws, so release the
    // buffer before throwing.
    free(m_memLoad);
    m_memLoad = NULL;
  }
  UTIL_THROW_IF2(readFailed, "Could not read " << memAlloc
                 << " bytes of phrase node at file position " << filePos);
  assert(filePos + memAlloc == (uint64_t)file.tellg());

  // Second header field: the node's value (offset 8, aligned in the malloc'd buffer).
  m_value = ((uint64_t*)m_memLoad)[1];

  // Copy every stored count. The record always reserves countSize floats
  // (see GetNodeSize), so this stays in bounds for any countSize, including 0.
  const float *memFloat = (const float*) (m_memLoad + sizeof(uint64_t) * 2);
  for (size_t i = 0; i < countSize; ++i) {
    m_counts[i] = memFloat[i];
  }

  m_memLoadLast = m_memLoad + memAlloc;
}
|
|
|
/// Releases the raw record buffer allocated by the loading constructor.
/// Nodes built in memory keep m_memLoad == NULL, so there is nothing to free
/// for them.
PhraseNode::~PhraseNode()
{
  if (m_memLoad != NULL) {
    free(m_memLoad);
  }
}
|
|
|
float PhraseNode::GetCount(size_t ind) const |
|
{ |
|
return m_counts[ind]; |
|
} |
|
|
|
/// Writes this node, and recursively any unsaved children, to the source file.
///
/// The target phrase collection is sorted, truncated to @p tableLimit and saved
/// first; its file position becomes this node's value. The node record is then
/// appended at the end of the source file:
///   [uint64_t numChildren][uint64_t value][float count]
///   [child entries: Word + uint64_t child file position] x numChildren
/// Afterwards the in-memory children are discarded and the node is marked saved.
///
/// @param onDiskWrapper provides the output streams and record geometry
/// @param pos           depth of this node; children are saved with pos + 1
/// @param tableLimit    passed to TargetPhraseCollection::Sort
/// @throws util::Exception if the node was already saved, if more than one
///         count is configured, or if the write fails
void PhraseNode::Save(OnDiskWrapper &onDiskWrapper, size_t pos, size_t tableLimit)
{
  UTIL_THROW_IF2(m_saved, "Already saved");

  // Save the target phrases first; their location is this node's value.
  m_targetPhraseColl.Sort(tableLimit);
  m_targetPhraseColl.Save(onDiskWrapper);
  m_value = m_targetPhraseColl.GetFilePos();

  const size_t numCounts = onDiskWrapper.GetNumCounts();
  UTIL_THROW_IF2(numCounts != 1, "Can only store 1 phrase count");

  const uint64_t numChildren = GetSize();
  const size_t memAlloc = GetNodeSize(numChildren, onDiskWrapper.GetSourceWordSize(), numCounts);

  // RAII buffer: nothing leaks if a child's Save() throws below.
  std::vector<char> buffer(memAlloc);
  char *mem = &buffer[0];
  size_t memUsed = 0;

  // Header: child count and value.
  std::memcpy(mem + memUsed, &numChildren, sizeof(uint64_t));
  memUsed += sizeof(uint64_t);
  const uint64_t value = m_value;
  std::memcpy(mem + memUsed, &value, sizeof(uint64_t));
  memUsed += sizeof(uint64_t);

  // Single phrase count (numCounts == 1 was checked above).
  const float count = (m_counts.size() == 0) ? DEFAULT_COUNT : m_counts[0];
  std::memcpy(mem + memUsed, &count, sizeof(float));
  memUsed += sizeof(float) * numCounts;

  // Child entries. Each child must be saved first so its file position is known.
  for (ChildColl::iterator iter = m_children.begin(); iter != m_children.end(); ++iter) {
    const Word &childWord = iter->first;
    PhraseNode &childNode = iter->second;

    if (!childNode.Saved())
      childNode.Save(onDiskWrapper, pos + 1, tableLimit);

    memUsed += childWord.WriteToMemory(mem + memUsed);

    // Entries are packed (word + 8 bytes), so this slot is generally not
    // 8-byte aligned; memcpy avoids a misaligned uint64_t store.
    const uint64_t childFilePos = childNode.GetFilePos();
    std::memcpy(mem + memUsed, &childFilePos, sizeof(uint64_t));
    memUsed += sizeof(uint64_t);
  }

  assert(memUsed == memAlloc);

  std::fstream &file = onDiskWrapper.GetFileSource();
  // Seek to the end *before* recording the position: the record is appended
  // there, so that is where it will be read back from.
  file.seekp(0, ios::end);
  m_filePos = file.tellp();
  file.write(mem, memUsed);
  UTIL_THROW_IF2(file.fail(), "Failed to write phrase node to source file");

  const uint64_t endPos = file.tellp();
  assert(m_filePos + memUsed == endPos);
  (void) endPos;  // only used by the assertion

  m_children.clear();
  // m_currChild pointed into m_children; do not leave it dangling.
  m_currChild = NULL;
  m_saved = true;
}
|
|
|
void PhraseNode::AddTargetPhrase(const SourcePhrase &sourcePhrase, TargetPhrase *targetPhrase |
|
, OnDiskWrapper &onDiskWrapper, size_t tableLimit |
|
, const std::vector<float> &counts, OnDiskPt::PhrasePtr spShort) |
|
{ |
|
AddTargetPhrase(0, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts, spShort); |
|
} |
|
|
|
void PhraseNode::AddTargetPhrase(size_t pos, const SourcePhrase &sourcePhrase |
|
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper |
|
, size_t tableLimit, const std::vector<float> &counts, OnDiskPt::PhrasePtr spShort) |
|
{ |
|
size_t phraseSize = sourcePhrase.GetSize(); |
|
if (pos < phraseSize) { |
|
const Word &word = sourcePhrase.GetWord(pos); |
|
|
|
PhraseNode &node = m_children[word]; |
|
if (m_currChild != &node) { |
|
|
|
node.SetPos(pos); |
|
|
|
if (m_currChild) { |
|
m_currChild->Save(onDiskWrapper, pos, tableLimit); |
|
} |
|
|
|
m_currChild = &node; |
|
} |
|
|
|
|
|
node.AddTargetPhrase(pos + 1, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts, spShort); |
|
} else { |
|
|
|
m_counts = counts; |
|
targetPhrase->SetSourcePhrase(spShort); |
|
m_targetPhraseColl.AddTargetPhrase(targetPhrase); |
|
} |
|
} |
|
|
|
/// Binary-searches this loaded node's child entries for @p wordSought.
///
/// Assumes the on-disk child entries are sorted by Word's operator<. Save()
/// writes them in ChildColl iteration order, presumably an ordered map keyed
/// by Word — TODO confirm in PhraseNode.h.
///
/// @return a node loaded from disk with `new` for the matching child, or NULL
///         if there is no such child. Nothing here keeps the pointer, so the
///         caller presumably owns it and must delete it.
const PhraseNode *PhraseNode::GetChild(const Word &wordSought, OnDiskWrapper &onDiskWrapper) const
{
  // Half-open interval [lo, hi) over child indices. Unsigned arithmetic avoids
  // narrowing the 64-bit child count to int and overflowing in the midpoint.
  size_t lo = 0;
  size_t hi = static_cast<size_t>(m_numChildrenLoad);

  while (lo < hi) {
    const size_t mid = lo + (hi - lo) / 2;

    Word wordFound;
    uint64_t childFilePos;
    GetChild(wordFound, childFilePos, mid, onDiskWrapper);

    if (wordSought == wordFound) {
      return new PhraseNode(childFilePos, onDiskWrapper);
    }

    if (wordSought < wordFound)
      hi = mid;
    else
      lo = mid + 1;
  }

  return NULL;
}
|
|
|
/// Decodes the child entry at index @p ind of the loaded record.
///
/// Child entries start after the fixed header (two uint64_t plus the float
/// counts). Each entry is a serialized Word followed by a uint64_t file
/// position.
///
/// @param wordFound    receives the child's word
/// @param childFilePos receives the child's record offset in the source file
/// @param ind          child index; not range-checked here
void PhraseNode::GetChild(Word &wordFound, uint64_t &childFilePos, size_t ind, OnDiskWrapper &onDiskWrapper) const
{
  const size_t entryBytes = onDiskWrapper.GetSourceWordSize() + sizeof(uint64_t);
  const size_t headerBytes = 2 * sizeof(uint64_t) + onDiskWrapper.GetNumCounts() * sizeof(float);

  char *entry = m_memLoad + headerBytes + ind * entryBytes;

  const size_t bytesConsumed = ReadChild(wordFound, childFilePos, entry);
  assert(bytesConsumed == entryBytes);
  (void) bytesConsumed;  // only used by the assertion
}
|
|
|
/// Decodes one child entry starting at @p mem: a serialized Word followed by a
/// uint64_t child file position.
///
/// @param wordFound    receives the decoded word
/// @param childFilePos receives the child's record offset in the source file
/// @param mem          start of the entry inside the loaded record buffer
/// @return number of bytes consumed (word bytes + sizeof(uint64_t))
size_t PhraseNode::ReadChild(Word &wordFound, uint64_t &childFilePos, const char *mem) const
{
  size_t memRead = wordFound.ReadFromMemory(mem);

  // Entries are packed, so the file position generally sits at an address
  // that is not 8-byte aligned. memcpy avoids an undefined misaligned load
  // (and the previous cast that dropped const).
  std::memcpy(&childFilePos, mem + memRead, sizeof(uint64_t));

  memRead += sizeof(uint64_t);
  return memRead;
}
|
|
|
/// Builds the target phrase collection for this node.
///
/// m_value holds the collection's file position (it is set from
/// TargetPhraseCollection::GetFilePos() in Save and read back from the record
/// when loading). A value of 0 is treated as "nothing stored", and an empty
/// collection is returned.
///
/// @param tableLimit passed through to TargetPhraseCollection::ReadFromFile
/// @return a newly allocated collection; never null
TargetPhraseCollection::shared_ptr
PhraseNode::
GetTargetPhraseCollection(size_t tableLimit, OnDiskWrapper &onDiskWrapper) const
{
  TargetPhraseCollection::shared_ptr coll(new TargetPhraseCollection);

  const bool hasStoredCollection = (m_value > 0);
  if (hasStoredCollection) {
    coll->ReadFromFile(tableLimit, m_value, onDiskWrapper);
  }

  return coll;
}
|
|
|
std::ostream& operator<<(std::ostream &out, const PhraseNode &node) |
|
{ |
|
out << "node (" << node.GetFilePos() << "," << node.GetValue() << "," << node.m_pos << ")"; |
|
return out; |
|
} |
|
|
|
} |
|
|
|
|