#include "tfg2p.h"

#include <stdexcept>

/// Creates a wrapper with no model loaded.
/// DoInference() throws until Initialize() has succeeded.
TFG2P::TFG2P()
{
    G2P = nullptr;
}

/// Creates a wrapper and tries to load the model stored in SavedModelFolder.
/// The result of Initialize() is deliberately not checked here: if loading
/// fails, G2P stays null and DoInference() reports it by throwing.
TFG2P::TFG2P(const std::string &SavedModelFolder)
{
    // Must be null before Initialize(), which deletes whatever G2P points to.
    G2P = nullptr;
    Initialize(SavedModelFolder);
}

/// Loads the model stored in SavedModelFolder, replacing any model that was
/// loaded before.
///
/// @param SavedModelFolder Folder passed to the Model constructor.
/// @return true if the model was constructed; false if construction threw.
///         After a failure no model is loaded (G2P is null).
bool TFG2P::Initialize(const std::string &SavedModelFolder)
{
    // Release the previous model first; overwriting the pointer without
    // deleting it leaked the old model on every re-initialization.
    delete G2P;
    G2P = nullptr;

    try {
        G2P = new Model(SavedModelFolder);
    }
    catch (...) {
        // The assignment above never happened, so G2P is still null.
        return false;
    }
    return true;
}

/// Runs the loaded model on InputIDs and returns a copy of its output tensor.
///
/// @param InputIDs    Values fed to "serving_default_input_ids".
/// @param Temperature Value fed to "serving_default_input_temperature".
/// @return Result of VoxUtil::CopyTensor on the "StatefulPartitionedCall" output.
/// @throws std::invalid_argument if no model is loaded.
///
/// NOTE(review): the template arguments of TFTensor, of the InputIDs vector and
/// of VoxUtil::CopyTensor are missing from the reviewed text (angle-bracket
/// content appears to have been stripped). They are left untouched here;
/// restore them from the declaration in tfg2p.h.
TFTensor TFG2P::DoInference(const std::vector &InputIDs, float Temperature)
{
    if (!G2P)
        throw std::invalid_argument("Tried to do inference on unloaded or invalid model!");

    // Convenience reference so that we don't have to constantly dereference the pointer.
    Model& Mdl = *G2P;

    // Tensors bound to the model under the names of its serving inputs.
    Tensor input_ids{Mdl, "serving_default_input_ids"};
    Tensor input_len{Mdl, "serving_default_input_len"};
    Tensor input_temp{Mdl, "serving_default_input_temperature"};

    const auto NumIDs = InputIDs.size();

    // The second argument is presumably the tensor shape (one dimension holding
    // all IDs) - TODO confirm against Tensor::set_data.
    input_ids.set_data(InputIDs, std::vector<int64_t>{static_cast<int64_t>(NumIDs)});
    input_len.set_data(std::vector<int32_t>{static_cast<int32_t>(NumIDs)});
    input_temp.set_data(std::vector<float>{Temperature});

    std::vector<Tensor*> Inputs{&input_ids, &input_len, &input_temp};
    Tensor out_ids{Mdl, "StatefulPartitionedCall"};

    Mdl.run(Inputs, out_ids);

    return VoxUtil::CopyTensor(out_ids);
}

/// Releases the loaded model, if any.
TFG2P::~TFG2P()
{
    // delete on a null pointer is a no-op, so no guard is needed.
    delete G2P;
}