// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #ifndef MODEL_INTERFACE_H_ #define MODEL_INTERFACE_H_ #include #include #include #include #include #include "common.h" #include "normalizer.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/strings/string_view.h" #include "third_party/darts_clone/darts.h" #include "util.h" namespace sentencepiece { // "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"] std::vector SplitIntoWords( absl::string_view text, bool treat_ws_as_suffix = false, bool allow_ws_only_pieces = false); // Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>"). std::string ByteToPiece(unsigned char c); // Converts piece to byte (e.g., "<0x3A>" -> 58). Returns -1 if `piece` is not // a valid byte piece. int PieceToByte(absl::string_view piece); using EncodeResult = std::vector>; using NBestEncodeResult = std::vector>; class ModelProto; // Underlying model interface. // Given a normalized string, returns a sequence of sentence pieces with ids. class ModelInterface { public: using PieceToIdMap = absl::flat_hash_map; // string_util::string_view_hash>; absl::string_view unk_piece() const; absl::string_view bos_piece() const; absl::string_view eos_piece() const; absl::string_view pad_piece() const; // `model_proto` should not be deleted until ModelInterface is destroyed. explicit ModelInterface(const ModelProto &model_proto); ModelInterface() {} virtual ~ModelInterface(); // Returns Status. // Encode/Decode functions are valid only when status is OK. virtual util::Status status() const { return status_; } virtual const ModelProto &model_proto() const { return *model_proto_; } virtual const normalizer::PrefixMatcher *prefix_matcher() const { return matcher_.get(); } // Given a normalized string, returns a sequence of sentence pieces with ids. // The concatenation of pieces must be the same as `normalized`. virtual EncodeResult Encode(absl::string_view normalized) const = 0; // The same as above, but returns nbest result with score. virtual NBestEncodeResult NBestEncode(absl::string_view normalized, int nbest_size) const { LOG(ERROR) << "Not implemented."; return NBestEncodeResult(); } virtual EncodeResult SampleEncode(absl::string_view normalized, float alpha) const { LOG(ERROR) << "Not implemented."; return EncodeResult(); } // Sample `samples` many tokenisations from the segmentation lattice // If `wor` is true, the samples are taken without replacement, and the scores // are the inclusion probabilities of the elements in the sample; otherwise // the samples are taken with replacement and the scores are the log-probs of // sample elements // If `include_best` is true, the best tokenisation is always included in the // sample, and the remaining elements are sampled excluding the best. virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, float alpha, int samples, bool wor, bool include_best) const { LOG(ERROR) << "Not implemented."; return {{EncodeResult(), 0.0}}; } // Calculates the entropy of the segmentation lattice with inverse temperature // `alpha`. Uses a novel dynamic program to calculate the entropy. virtual float CalculateEntropy(absl::string_view normalized, float alpha) const { LOG(ERROR) << "Not implemented."; return 0.0; } // Return true if SampleEncode returns a valid result. virtual bool IsSampleEncodeAvailable() const { return false; } // Return true if NBestEncode returns a valid result. virtual bool IsNBestEncodeAvailable() const { return false; } // Return true if SampleEncodeAndScore returns a valid result. virtual bool IsSampleEncodeAndScoreAvailable() const { return false; } // Return true if CalculateEntropy returns a valid result. virtual bool IsCalculateEntropyAvailable() const { return false; } // Returns the vocab id of `piece`. // Returns UNK(0) if `piece` is unknown virtual int PieceToId(absl::string_view piece) const; // Returns the string representation of vocab with `id`. // id must be 0 <= id < GetPieceSize(). virtual const std::string &IdToPiece(int id) const { return model_proto_->pieces(id).piece(); } // Returns the size of sentence pieces, which is the same // as the size of vocabulary for NMT. virtual int GetPieceSize() const { if (!model_proto_) return 0; return model_proto_->pieces_size(); } // Returns the score of `id`. // Score represents a log probability of the piece. // We can roughly estimate the unigram frequency of the piece. virtual float GetScore(int id) const { return model_proto_->pieces(id).score(); } // Returns true if `id` is unknown symbol. virtual bool IsUnknown(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNKNOWN); } // Returns true if `id` is control symbol. virtual bool IsControl(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::CONTROL); } // Returns true if `id` is unused symbol. virtual bool IsUnused(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNUSED); } // Returns true if `id` is user defined symbol. virtual bool IsUserDefined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::USER_DEFINED); } // Returns true if `id` is byte symbol. virtual bool IsByte(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE); } virtual bool ByteFallbackEnabled() const { return model_proto_ && model_proto_->trainer_spec().byte_fallback(); } // Verifies if the `expected` and `actual` outputs are equivalent. `expected` // and `actual` are sentence pieces joined by space (` `). Normally it means // that the two strings are identical. In some model, due to float rounding // errors, the strings may not be identical, but they may be still equivalent // provided their scores are close enough (by some espilon). virtual bool VerifyOutputsEquivalent(absl::string_view expected, absl::string_view actual) const { return expected == actual; } protected: void InitializePieces(); // Non-virtual (inlined) implementation for faster execution. inline float GetScoreInlined(int id) const { return model_proto_->pieces(id).score(); } inline bool IsUnknownInlined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNKNOWN); } inline bool IsControlInlined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::CONTROL); } inline bool IsUnusedInlined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::UNUSED); } inline bool IsUserDefinedInlined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::USER_DEFINED); } inline bool IsByteInlined(int id) const { return (model_proto_->pieces(id).type() == ModelProto::SentencePiece::BYTE); } const ModelProto *model_proto_ = nullptr; // PrefixMatcher for user defined symbols. std::unique_ptr matcher_; // piece -> id map for normal pieces PieceToIdMap pieces_; // piece -> id map for control, unknown, and byte pieces PieceToIdMap reserved_id_map_; // unknown id. int unk_id_ = 0; // status. util::Status status_; }; } // namespace sentencepiece #endif // MODEL_INTERFACE_H_