#include "lm/common/model_buffer.hh"
#include "util/exception.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/stream/io.hh"
#include "util/stream/multi_stream.hh"

#include <boost/lexical_cast.hpp>

namespace lm {

namespace {
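// First line of the metadata file; identifies KenLM intermediate output.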
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace

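// Writing constructor: when keep_buffer is set, the vocabulary goes to
// <file_base>.vocab on disk; otherwise it lives in an unlinked temporary.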
ModelBuffer::ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer), output_q_(output_q),
    vocab_file_(keep_buffer ? util::CreateOrThrow((file_base_ + ".vocab").c_str()) : util::MakeTemp(file_base_)) {}
  
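// Reading constructor: parses <file_base>.kenlm_intermediate and reopens the
// vocabulary and per-order files that a previous Sink left behind.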
ModelBuffer::ModelBuffer(StringPiece file_base)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(false) {
  const std::string full_name = file_base_ + ".kenlm_intermediate";
  util::FilePiece in(full_name.c_str());
  StringPiece token = in.ReadLine();
  UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);

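  // The counts line has the form "Counts <1-gram> <2-gram> ... <N-gram>".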
  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Counts", "Expected Counts, got \"" << token << "\" in " << full_name);
  char got;
  while ((got = in.get()) == ' ') {
    counts_.push_back(in.ReadULong());
  }
  UTIL_THROW_IF2(got != '\n', "Expected newline at end of counts.");

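  // The payload line declares what each n-gram record carries: quantized
  // values ("q") or probability/backoff ("pb").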
  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
  token = in.ReadDelimited();
  if (token == "q") {
    output_q_ = true;
  } else if (token == "pb") {
    output_q_ = false;
  } else {
    UTIL_THROW(util::Exception, "Unknown payload " << token);
  }

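  // Reopen the vocabulary, then one file per order: <file_base>.1 ... <file_base>.N.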
  vocab_file_.reset(util::OpenReadOrThrow((file_base_ + ".vocab").c_str()));

  files_.Init(counts_.size());
  for (std::size_t i = 0; i < counts_.size(); ++i) {
    files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
  }
}

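// Attaches a Write to every chain so the pipeline drains each order to its
// own file, then records metadata if the buffer is being kept.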
void ModelBuffer::Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts) {
  counts_ = counts;
  // Open files.
  files_.Init(chains.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    if (keep_buffer_) {
      files_.push_back(util::CreateOrThrow(
            (file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()
            ));
    } else {
      files_.push_back(util::MakeTemp(file_base_));
    }
    chains[i] >> util::stream::Write(files_.back().get());
  }
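  // Persist the metadata the reading constructor expects: header, counts, payload.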
  if (keep_buffer_) {
    util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
    util::FileStream meta(metadata.get(), 200);
    meta << kMetadataHeader << "\nCounts";
    for (std::vector<uint64_t>::const_iterator i = counts_.begin(); i != counts_.end(); ++i) {
      meta << ' ' << *i;
    }
    meta << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
  }
}

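// Feeds the buffered files back into chains, one order per chain.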
void ModelBuffer::Source(util::stream::Chains &chains) {
  assert(chains.size() <= files_.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    chains[i] >> util::stream::PRead(files_[i].get());
  }
}

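// Feeds a single order back into one chain; orders are zero-indexed here.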
void ModelBuffer::Source(std::size_t order_minus_1, util::stream::Chain &chain) {
  chain >> util::stream::PRead(files_[order_minus_1].get());
}
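
// Illustrative round trip (a sketch only, kept as a comment; chain
// configuration is elided and "base" is a hypothetical file base):
//
//   std::vector<uint64_t> counts;                 // filled by the counting pass
//   util::stream::Chains chains(counts.size());
//   // ... push_back one util::stream::ChainConfig per order ...
//   lm::ModelBuffer writer("base", true, false);  // keep buffer, pb payload
//   writer.Sink(chains, counts);                  // writes base.1 ... base.N
//   chains.Wait(true);                            // drain the pipeline to disk
//   lm::ModelBuffer reader("base");               // parses base.kenlm_intermediate
//   reader.Source(chains);                        // replays each order via PRead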

} // namespace lm