|
|
|
#pragma once |
|
#include "utils/protomessage.h" |
|
namespace pblczero { |
|
|
|
|
|
// Forward declarations of the message classes of this namespace, so that a
// class can be named before its full definition appears.
class EngineVersion;
class Weights;
class Weights_Layer;
class Weights_ConvBlock;
class Weights_SEunit;
class Weights_Residual;
class Weights_Smolgen;
class Weights_MHA;
class Weights_FFN;
class Weights_EncoderLayer;
class TrainingParams;
class NetworkFormat;
class Format;
class OnnxModel;
class Net;
|
// Enumerators of NetworkFormat::InputFormat with their explicit integer
// values. Note that the two *_ARMAGEDDON values (132, 133) are not contiguous
// with the rest.
enum NetworkFormat_InputFormat : int {
  NetworkFormat_InputFormat_INPUT_UNKNOWN = 0,
  NetworkFormat_InputFormat_INPUT_CLASSICAL_112_PLANE = 1,
  NetworkFormat_InputFormat_INPUT_112_WITH_CASTLING_PLANE = 2,
  NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION = 3,
  NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES = 4,
  NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON = 132,
  NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2 = 5,
  NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON = 133,
};

// Returns the name of `val` without the "NetworkFormat_InputFormat_" prefix.
// A value that matches no enumerator is rendered as
// "InputFormat(<integer value>)".
inline std::string NetworkFormat_InputFormat_Name(
    NetworkFormat_InputFormat val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_InputFormat_INPUT_UNKNOWN:
      return "INPUT_UNKNOWN";
    case NetworkFormat_InputFormat_INPUT_CLASSICAL_112_PLANE:
      return "INPUT_CLASSICAL_112_PLANE";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CASTLING_PLANE:
      return "INPUT_112_WITH_CASTLING_PLANE";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION:
      return "INPUT_112_WITH_CANONICALIZATION";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
      return "INPUT_112_WITH_CANONICALIZATION_HECTOPLIES";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON:
      return "INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2:
      return "INPUT_112_WITH_CANONICALIZATION_V2";
    case NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON:
      return "INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON";
  }
  // Reached only for integers that have no enumerator.
  return "InputFormat(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::OutputFormat with their explicit integer
// values.
enum NetworkFormat_OutputFormat : int {
  NetworkFormat_OutputFormat_OUTPUT_UNKNOWN = 0,
  NetworkFormat_OutputFormat_OUTPUT_CLASSICAL = 1,
  NetworkFormat_OutputFormat_OUTPUT_WDL = 2,
};

// Returns the name of `val` without the "NetworkFormat_OutputFormat_" prefix.
// A value that matches no enumerator is rendered as
// "OutputFormat(<integer value>)".
inline std::string NetworkFormat_OutputFormat_Name(
    NetworkFormat_OutputFormat val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_OutputFormat_OUTPUT_UNKNOWN:
      return "OUTPUT_UNKNOWN";
    case NetworkFormat_OutputFormat_OUTPUT_CLASSICAL:
      return "OUTPUT_CLASSICAL";
    case NetworkFormat_OutputFormat_OUTPUT_WDL:
      return "OUTPUT_WDL";
  }
  // Reached only for integers that have no enumerator.
  return "OutputFormat(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::NetworkStructure with their explicit integer
// values.
enum NetworkFormat_NetworkStructure : int {
  NetworkFormat_NetworkStructure_NETWORK_UNKNOWN = 0,
  NetworkFormat_NetworkStructure_NETWORK_CLASSICAL = 1,
  NetworkFormat_NetworkStructure_NETWORK_SE = 2,
  NetworkFormat_NetworkStructure_NETWORK_CLASSICAL_WITH_HEADFORMAT = 3,
  NetworkFormat_NetworkStructure_NETWORK_SE_WITH_HEADFORMAT = 4,
  NetworkFormat_NetworkStructure_NETWORK_ONNX = 5,
  NetworkFormat_NetworkStructure_NETWORK_ATTENTIONBODY_WITH_HEADFORMAT = 6,
};

// Returns the name of `val` without the "NetworkFormat_NetworkStructure_"
// prefix. A value that matches no enumerator is rendered as
// "NetworkStructure(<integer value>)".
inline std::string NetworkFormat_NetworkStructure_Name(
    NetworkFormat_NetworkStructure val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_NetworkStructure_NETWORK_UNKNOWN:
      return "NETWORK_UNKNOWN";
    case NetworkFormat_NetworkStructure_NETWORK_CLASSICAL:
      return "NETWORK_CLASSICAL";
    case NetworkFormat_NetworkStructure_NETWORK_SE:
      return "NETWORK_SE";
    case NetworkFormat_NetworkStructure_NETWORK_CLASSICAL_WITH_HEADFORMAT:
      return "NETWORK_CLASSICAL_WITH_HEADFORMAT";
    case NetworkFormat_NetworkStructure_NETWORK_SE_WITH_HEADFORMAT:
      return "NETWORK_SE_WITH_HEADFORMAT";
    case NetworkFormat_NetworkStructure_NETWORK_ONNX:
      return "NETWORK_ONNX";
    case NetworkFormat_NetworkStructure_NETWORK_ATTENTIONBODY_WITH_HEADFORMAT:
      return "NETWORK_ATTENTIONBODY_WITH_HEADFORMAT";
  }
  // Reached only for integers that have no enumerator.
  return "NetworkStructure(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::PolicyFormat with their explicit integer
// values.
enum NetworkFormat_PolicyFormat : int {
  NetworkFormat_PolicyFormat_POLICY_UNKNOWN = 0,
  NetworkFormat_PolicyFormat_POLICY_CLASSICAL = 1,
  NetworkFormat_PolicyFormat_POLICY_CONVOLUTION = 2,
  NetworkFormat_PolicyFormat_POLICY_ATTENTION = 3,
};

// Returns the name of `val` without the "NetworkFormat_PolicyFormat_" prefix.
// A value that matches no enumerator is rendered as
// "PolicyFormat(<integer value>)".
inline std::string NetworkFormat_PolicyFormat_Name(
    NetworkFormat_PolicyFormat val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_PolicyFormat_POLICY_UNKNOWN:
      return "POLICY_UNKNOWN";
    case NetworkFormat_PolicyFormat_POLICY_CLASSICAL:
      return "POLICY_CLASSICAL";
    case NetworkFormat_PolicyFormat_POLICY_CONVOLUTION:
      return "POLICY_CONVOLUTION";
    case NetworkFormat_PolicyFormat_POLICY_ATTENTION:
      return "POLICY_ATTENTION";
  }
  // Reached only for integers that have no enumerator.
  return "PolicyFormat(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::ValueFormat with their explicit integer
// values.
enum NetworkFormat_ValueFormat : int {
  NetworkFormat_ValueFormat_VALUE_UNKNOWN = 0,
  NetworkFormat_ValueFormat_VALUE_CLASSICAL = 1,
  NetworkFormat_ValueFormat_VALUE_WDL = 2,
  NetworkFormat_ValueFormat_VALUE_PARAM = 3,
};

// Returns the name of `val` without the "NetworkFormat_ValueFormat_" prefix.
// A value that matches no enumerator is rendered as
// "ValueFormat(<integer value>)".
inline std::string NetworkFormat_ValueFormat_Name(
    NetworkFormat_ValueFormat val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_ValueFormat_VALUE_UNKNOWN:
      return "VALUE_UNKNOWN";
    case NetworkFormat_ValueFormat_VALUE_CLASSICAL:
      return "VALUE_CLASSICAL";
    case NetworkFormat_ValueFormat_VALUE_WDL:
      return "VALUE_WDL";
    case NetworkFormat_ValueFormat_VALUE_PARAM:
      return "VALUE_PARAM";
  }
  // Reached only for integers that have no enumerator.
  return "ValueFormat(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::MovesLeftFormat with their explicit integer
// values.
enum NetworkFormat_MovesLeftFormat : int {
  NetworkFormat_MovesLeftFormat_MOVES_LEFT_NONE = 0,
  NetworkFormat_MovesLeftFormat_MOVES_LEFT_V1 = 1,
};

// Returns the name of `val` without the "NetworkFormat_MovesLeftFormat_"
// prefix. A value that matches no enumerator is rendered as
// "MovesLeftFormat(<integer value>)".
inline std::string NetworkFormat_MovesLeftFormat_Name(
    NetworkFormat_MovesLeftFormat val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_MovesLeftFormat_MOVES_LEFT_NONE:
      return "MOVES_LEFT_NONE";
    case NetworkFormat_MovesLeftFormat_MOVES_LEFT_V1:
      return "MOVES_LEFT_V1";
  }
  // Reached only for integers that have no enumerator.
  return "MovesLeftFormat(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::ActivationFunction with their explicit
// integer values.
enum NetworkFormat_ActivationFunction : int {
  NetworkFormat_ActivationFunction_ACTIVATION_DEFAULT = 0,
  NetworkFormat_ActivationFunction_ACTIVATION_MISH = 1,
  NetworkFormat_ActivationFunction_ACTIVATION_RELU = 2,
  NetworkFormat_ActivationFunction_ACTIVATION_NONE = 3,
  NetworkFormat_ActivationFunction_ACTIVATION_TANH = 4,
  NetworkFormat_ActivationFunction_ACTIVATION_SIGMOID = 5,
  NetworkFormat_ActivationFunction_ACTIVATION_SELU = 6,
  NetworkFormat_ActivationFunction_ACTIVATION_SWISH = 7,
  NetworkFormat_ActivationFunction_ACTIVATION_RELU_2 = 8,
  NetworkFormat_ActivationFunction_ACTIVATION_SOFTMAX = 9,
};

// Returns the name of `val` without the "NetworkFormat_ActivationFunction_"
// prefix. A value that matches no enumerator is rendered as
// "ActivationFunction(<integer value>)".
inline std::string NetworkFormat_ActivationFunction_Name(
    NetworkFormat_ActivationFunction val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_ActivationFunction_ACTIVATION_DEFAULT:
      return "ACTIVATION_DEFAULT";
    case NetworkFormat_ActivationFunction_ACTIVATION_MISH:
      return "ACTIVATION_MISH";
    case NetworkFormat_ActivationFunction_ACTIVATION_RELU:
      return "ACTIVATION_RELU";
    case NetworkFormat_ActivationFunction_ACTIVATION_NONE:
      return "ACTIVATION_NONE";
    case NetworkFormat_ActivationFunction_ACTIVATION_TANH:
      return "ACTIVATION_TANH";
    case NetworkFormat_ActivationFunction_ACTIVATION_SIGMOID:
      return "ACTIVATION_SIGMOID";
    case NetworkFormat_ActivationFunction_ACTIVATION_SELU:
      return "ACTIVATION_SELU";
    case NetworkFormat_ActivationFunction_ACTIVATION_SWISH:
      return "ACTIVATION_SWISH";
    case NetworkFormat_ActivationFunction_ACTIVATION_RELU_2:
      return "ACTIVATION_RELU_2";
    case NetworkFormat_ActivationFunction_ACTIVATION_SOFTMAX:
      return "ACTIVATION_SOFTMAX";
  }
  // Reached only for integers that have no enumerator.
  return "ActivationFunction(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of NetworkFormat::DefaultActivation with their explicit integer
// values.
enum NetworkFormat_DefaultActivation : int {
  NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_RELU = 0,
  NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_MISH = 1,
};

// Returns the name of `val` without the "NetworkFormat_DefaultActivation_"
// prefix. A value that matches no enumerator is rendered as
// "DefaultActivation(<integer value>)".
inline std::string NetworkFormat_DefaultActivation_Name(
    NetworkFormat_DefaultActivation val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_RELU:
      return "DEFAULT_ACTIVATION_RELU";
    case NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_MISH:
      return "DEFAULT_ACTIVATION_MISH";
  }
  // Reached only for integers that have no enumerator.
  return "DefaultActivation(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of Format::Encoding with their explicit integer values.
enum Format_Encoding : int {
  Format_Encoding_UNKNOWN = 0,
  Format_Encoding_LINEAR16 = 1,
};

// Returns the name of `val` without the "Format_Encoding_" prefix. A value
// that matches no enumerator is rendered as "Encoding(<integer value>)".
inline std::string Format_Encoding_Name(Format_Encoding val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case Format_Encoding_UNKNOWN:
      return "UNKNOWN";
    case Format_Encoding_LINEAR16:
      return "LINEAR16";
  }
  // Reached only for integers that have no enumerator.
  return "Encoding(" + std::to_string(static_cast<int>(val)) + ")";
}
|
// Enumerators of OnnxModel::DataType with their explicit integer values.
// The values are not contiguous (0, 1, 10, 16).
enum OnnxModel_DataType : int {
  OnnxModel_DataType_UNKNOWN_DATATYPE = 0,
  OnnxModel_DataType_FLOAT = 1,
  OnnxModel_DataType_FLOAT16 = 10,
  OnnxModel_DataType_BFLOAT16 = 16,
};

// Returns the name of `val` without the "OnnxModel_DataType_" prefix. A value
// that matches no enumerator is rendered as "DataType(<integer value>)".
inline std::string OnnxModel_DataType_Name(OnnxModel_DataType val) {
  // No `default:` label, so compilers can still warn when an enumerator is
  // missing from the switch.
  switch (val) {
    case OnnxModel_DataType_UNKNOWN_DATATYPE:
      return "UNKNOWN_DATATYPE";
    case OnnxModel_DataType_FLOAT:
      return "FLOAT";
    case OnnxModel_DataType_FLOAT16:
      return "FLOAT16";
    case OnnxModel_DataType_BFLOAT16:
      return "BFLOAT16";
  }
  // Reached only for integers that have no enumerator.
  return "DataType(" + std::to_string(static_cast<int>(val)) + ")";
}
|
|
|
|
|
class EngineVersion final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_major() const; |
|
std::uint32_t major() const; |
|
void set_major(std::uint32_t val); |
|
|
|
bool has_minor() const; |
|
std::uint32_t minor() const; |
|
void set_minor(std::uint32_t val); |
|
|
|
bool has_patch() const; |
|
std::uint32_t patch() const; |
|
void set_patch(std::uint32_t val); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
|
|
bool has_major_{}; |
|
std::uint32_t major_{}; |
|
bool has_minor_{}; |
|
std::uint32_t minor_{}; |
|
bool has_patch_{}; |
|
std::uint32_t patch_{}; |
|
}; |
|
|
|
class Weights_Layer final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_min_val() const; |
|
float min_val() const; |
|
void set_min_val(float val); |
|
|
|
bool has_max_val() const; |
|
float max_val() const; |
|
void set_max_val(float val); |
|
|
|
bool has_params() const; |
|
std::string_view params() const; |
|
void set_params(std::string_view val); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetInt32(int field_id, std::uint32_t val) final; |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_min_val_{}; |
|
float min_val_{}; |
|
bool has_max_val_{}; |
|
float max_val_{}; |
|
bool has_params_{}; |
|
std::string params_{}; |
|
}; |
|
|
|
class Weights_ConvBlock final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_weights() const; |
|
const Weights_Layer& weights() const; |
|
Weights_Layer* mutable_weights(); |
|
|
|
bool has_biases() const; |
|
const Weights_Layer& biases() const; |
|
Weights_Layer* mutable_biases(); |
|
|
|
bool has_bn_means() const; |
|
const Weights_Layer& bn_means() const; |
|
Weights_Layer* mutable_bn_means(); |
|
|
|
bool has_bn_stddivs() const; |
|
const Weights_Layer& bn_stddivs() const; |
|
Weights_Layer* mutable_bn_stddivs(); |
|
|
|
bool has_bn_gammas() const; |
|
const Weights_Layer& bn_gammas() const; |
|
Weights_Layer* mutable_bn_gammas(); |
|
|
|
bool has_bn_betas() const; |
|
const Weights_Layer& bn_betas() const; |
|
Weights_Layer* mutable_bn_betas(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_weights_{}; |
|
Weights_Layer weights_{}; |
|
bool has_biases_{}; |
|
Weights_Layer biases_{}; |
|
bool has_bn_means_{}; |
|
Weights_Layer bn_means_{}; |
|
bool has_bn_stddivs_{}; |
|
Weights_Layer bn_stddivs_{}; |
|
bool has_bn_gammas_{}; |
|
Weights_Layer bn_gammas_{}; |
|
bool has_bn_betas_{}; |
|
Weights_Layer bn_betas_{}; |
|
}; |
|
|
|
class Weights_SEunit final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_w1() const; |
|
const Weights_Layer& w1() const; |
|
Weights_Layer* mutable_w1(); |
|
|
|
bool has_b1() const; |
|
const Weights_Layer& b1() const; |
|
Weights_Layer* mutable_b1(); |
|
|
|
bool has_w2() const; |
|
const Weights_Layer& w2() const; |
|
Weights_Layer* mutable_w2(); |
|
|
|
bool has_b2() const; |
|
const Weights_Layer& b2() const; |
|
Weights_Layer* mutable_b2(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_w1_{}; |
|
Weights_Layer w1_{}; |
|
bool has_b1_{}; |
|
Weights_Layer b1_{}; |
|
bool has_w2_{}; |
|
Weights_Layer w2_{}; |
|
bool has_b2_{}; |
|
Weights_Layer b2_{}; |
|
}; |
|
|
|
class Weights_Residual final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_conv1() const; |
|
const Weights_ConvBlock& conv1() const; |
|
Weights_ConvBlock* mutable_conv1(); |
|
|
|
bool has_conv2() const; |
|
const Weights_ConvBlock& conv2() const; |
|
Weights_ConvBlock* mutable_conv2(); |
|
|
|
bool has_se() const; |
|
const Weights_SEunit& se() const; |
|
Weights_SEunit* mutable_se(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_conv1_{}; |
|
Weights_ConvBlock conv1_{}; |
|
bool has_conv2_{}; |
|
Weights_ConvBlock conv2_{}; |
|
bool has_se_{}; |
|
Weights_SEunit se_{}; |
|
}; |
|
|
|
class Weights_Smolgen final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_compress() const; |
|
const Weights_Layer& compress() const; |
|
Weights_Layer* mutable_compress(); |
|
|
|
bool has_dense1_w() const; |
|
const Weights_Layer& dense1_w() const; |
|
Weights_Layer* mutable_dense1_w(); |
|
|
|
bool has_dense1_b() const; |
|
const Weights_Layer& dense1_b() const; |
|
Weights_Layer* mutable_dense1_b(); |
|
|
|
bool has_ln1_gammas() const; |
|
const Weights_Layer& ln1_gammas() const; |
|
Weights_Layer* mutable_ln1_gammas(); |
|
|
|
bool has_ln1_betas() const; |
|
const Weights_Layer& ln1_betas() const; |
|
Weights_Layer* mutable_ln1_betas(); |
|
|
|
bool has_dense2_w() const; |
|
const Weights_Layer& dense2_w() const; |
|
Weights_Layer* mutable_dense2_w(); |
|
|
|
bool has_dense2_b() const; |
|
const Weights_Layer& dense2_b() const; |
|
Weights_Layer* mutable_dense2_b(); |
|
|
|
bool has_ln2_gammas() const; |
|
const Weights_Layer& ln2_gammas() const; |
|
Weights_Layer* mutable_ln2_gammas(); |
|
|
|
bool has_ln2_betas() const; |
|
const Weights_Layer& ln2_betas() const; |
|
Weights_Layer* mutable_ln2_betas(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_compress_{}; |
|
Weights_Layer compress_{}; |
|
bool has_dense1_w_{}; |
|
Weights_Layer dense1_w_{}; |
|
bool has_dense1_b_{}; |
|
Weights_Layer dense1_b_{}; |
|
bool has_ln1_gammas_{}; |
|
Weights_Layer ln1_gammas_{}; |
|
bool has_ln1_betas_{}; |
|
Weights_Layer ln1_betas_{}; |
|
bool has_dense2_w_{}; |
|
Weights_Layer dense2_w_{}; |
|
bool has_dense2_b_{}; |
|
Weights_Layer dense2_b_{}; |
|
bool has_ln2_gammas_{}; |
|
Weights_Layer ln2_gammas_{}; |
|
bool has_ln2_betas_{}; |
|
Weights_Layer ln2_betas_{}; |
|
}; |
|
|
|
class Weights_MHA final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_q_w() const; |
|
const Weights_Layer& q_w() const; |
|
Weights_Layer* mutable_q_w(); |
|
|
|
bool has_q_b() const; |
|
const Weights_Layer& q_b() const; |
|
Weights_Layer* mutable_q_b(); |
|
|
|
bool has_k_w() const; |
|
const Weights_Layer& k_w() const; |
|
Weights_Layer* mutable_k_w(); |
|
|
|
bool has_k_b() const; |
|
const Weights_Layer& k_b() const; |
|
Weights_Layer* mutable_k_b(); |
|
|
|
bool has_v_w() const; |
|
const Weights_Layer& v_w() const; |
|
Weights_Layer* mutable_v_w(); |
|
|
|
bool has_v_b() const; |
|
const Weights_Layer& v_b() const; |
|
Weights_Layer* mutable_v_b(); |
|
|
|
bool has_dense_w() const; |
|
const Weights_Layer& dense_w() const; |
|
Weights_Layer* mutable_dense_w(); |
|
|
|
bool has_dense_b() const; |
|
const Weights_Layer& dense_b() const; |
|
Weights_Layer* mutable_dense_b(); |
|
|
|
bool has_smolgen() const; |
|
const Weights_Smolgen& smolgen() const; |
|
Weights_Smolgen* mutable_smolgen(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_q_w_{}; |
|
Weights_Layer q_w_{}; |
|
bool has_q_b_{}; |
|
Weights_Layer q_b_{}; |
|
bool has_k_w_{}; |
|
Weights_Layer k_w_{}; |
|
bool has_k_b_{}; |
|
Weights_Layer k_b_{}; |
|
bool has_v_w_{}; |
|
Weights_Layer v_w_{}; |
|
bool has_v_b_{}; |
|
Weights_Layer v_b_{}; |
|
bool has_dense_w_{}; |
|
Weights_Layer dense_w_{}; |
|
bool has_dense_b_{}; |
|
Weights_Layer dense_b_{}; |
|
bool has_smolgen_{}; |
|
Weights_Smolgen smolgen_{}; |
|
}; |
|
|
|
class Weights_FFN final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_dense1_w() const; |
|
const Weights_Layer& dense1_w() const; |
|
Weights_Layer* mutable_dense1_w(); |
|
|
|
bool has_dense1_b() const; |
|
const Weights_Layer& dense1_b() const; |
|
Weights_Layer* mutable_dense1_b(); |
|
|
|
bool has_dense2_w() const; |
|
const Weights_Layer& dense2_w() const; |
|
Weights_Layer* mutable_dense2_w(); |
|
|
|
bool has_dense2_b() const; |
|
const Weights_Layer& dense2_b() const; |
|
Weights_Layer* mutable_dense2_b(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_dense1_w_{}; |
|
Weights_Layer dense1_w_{}; |
|
bool has_dense1_b_{}; |
|
Weights_Layer dense1_b_{}; |
|
bool has_dense2_w_{}; |
|
Weights_Layer dense2_w_{}; |
|
bool has_dense2_b_{}; |
|
Weights_Layer dense2_b_{}; |
|
}; |
|
|
|
class Weights_EncoderLayer final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_mha() const; |
|
const Weights_MHA& mha() const; |
|
Weights_MHA* mutable_mha(); |
|
|
|
bool has_ln1_gammas() const; |
|
const Weights_Layer& ln1_gammas() const; |
|
Weights_Layer* mutable_ln1_gammas(); |
|
|
|
bool has_ln1_betas() const; |
|
const Weights_Layer& ln1_betas() const; |
|
Weights_Layer* mutable_ln1_betas(); |
|
|
|
bool has_ffn() const; |
|
const Weights_FFN& ffn() const; |
|
Weights_FFN* mutable_ffn(); |
|
|
|
bool has_ln2_gammas() const; |
|
const Weights_Layer& ln2_gammas() const; |
|
Weights_Layer* mutable_ln2_gammas(); |
|
|
|
bool has_ln2_betas() const; |
|
const Weights_Layer& ln2_betas() const; |
|
Weights_Layer* mutable_ln2_betas(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_mha_{}; |
|
Weights_MHA mha_{}; |
|
bool has_ln1_gammas_{}; |
|
Weights_Layer ln1_gammas_{}; |
|
bool has_ln1_betas_{}; |
|
Weights_Layer ln1_betas_{}; |
|
bool has_ffn_{}; |
|
Weights_FFN ffn_{}; |
|
bool has_ln2_gammas_{}; |
|
Weights_Layer ln2_gammas_{}; |
|
bool has_ln2_betas_{}; |
|
Weights_Layer ln2_betas_{}; |
|
}; |
|
|
|
class Weights final : public lczero::ProtoMessage { |
|
public: |
|
using Layer = Weights_Layer; |
|
using ConvBlock = Weights_ConvBlock; |
|
using SEunit = Weights_SEunit; |
|
using Residual = Weights_Residual; |
|
using Smolgen = Weights_Smolgen; |
|
using MHA = Weights_MHA; |
|
using FFN = Weights_FFN; |
|
using EncoderLayer = Weights_EncoderLayer; |
|
|
|
bool has_input() const; |
|
const Weights_ConvBlock& input() const; |
|
Weights_ConvBlock* mutable_input(); |
|
|
|
Weights_Residual* add_residual(); |
|
const std::vector<Weights_Residual>& residual() const; |
|
std::vector<Weights_Residual>* mutable_residual(); |
|
const Weights_Residual& residual(size_t idx) const; |
|
Weights_Residual* mutable_residual(size_t idx); |
|
size_t residual_size() const; |
|
|
|
bool has_ip_emb_w() const; |
|
const Weights_Layer& ip_emb_w() const; |
|
Weights_Layer* mutable_ip_emb_w(); |
|
|
|
bool has_ip_emb_b() const; |
|
const Weights_Layer& ip_emb_b() const; |
|
Weights_Layer* mutable_ip_emb_b(); |
|
|
|
bool has_ip_mult_gate() const; |
|
const Weights_Layer& ip_mult_gate() const; |
|
Weights_Layer* mutable_ip_mult_gate(); |
|
|
|
bool has_ip_add_gate() const; |
|
const Weights_Layer& ip_add_gate() const; |
|
Weights_Layer* mutable_ip_add_gate(); |
|
|
|
Weights_EncoderLayer* add_encoder(); |
|
const std::vector<Weights_EncoderLayer>& encoder() const; |
|
std::vector<Weights_EncoderLayer>* mutable_encoder(); |
|
const Weights_EncoderLayer& encoder(size_t idx) const; |
|
Weights_EncoderLayer* mutable_encoder(size_t idx); |
|
size_t encoder_size() const; |
|
|
|
bool has_headcount() const; |
|
std::uint32_t headcount() const; |
|
void set_headcount(std::uint32_t val); |
|
|
|
Weights_EncoderLayer* add_pol_encoder(); |
|
const std::vector<Weights_EncoderLayer>& pol_encoder() const; |
|
std::vector<Weights_EncoderLayer>* mutable_pol_encoder(); |
|
const Weights_EncoderLayer& pol_encoder(size_t idx) const; |
|
Weights_EncoderLayer* mutable_pol_encoder(size_t idx); |
|
size_t pol_encoder_size() const; |
|
|
|
bool has_pol_headcount() const; |
|
std::uint32_t pol_headcount() const; |
|
void set_pol_headcount(std::uint32_t val); |
|
|
|
bool has_policy1() const; |
|
const Weights_ConvBlock& policy1() const; |
|
Weights_ConvBlock* mutable_policy1(); |
|
|
|
bool has_policy() const; |
|
const Weights_ConvBlock& policy() const; |
|
Weights_ConvBlock* mutable_policy(); |
|
|
|
bool has_ip_pol_w() const; |
|
const Weights_Layer& ip_pol_w() const; |
|
Weights_Layer* mutable_ip_pol_w(); |
|
|
|
bool has_ip_pol_b() const; |
|
const Weights_Layer& ip_pol_b() const; |
|
Weights_Layer* mutable_ip_pol_b(); |
|
|
|
bool has_ip2_pol_w() const; |
|
const Weights_Layer& ip2_pol_w() const; |
|
Weights_Layer* mutable_ip2_pol_w(); |
|
|
|
bool has_ip2_pol_b() const; |
|
const Weights_Layer& ip2_pol_b() const; |
|
Weights_Layer* mutable_ip2_pol_b(); |
|
|
|
bool has_ip3_pol_w() const; |
|
const Weights_Layer& ip3_pol_w() const; |
|
Weights_Layer* mutable_ip3_pol_w(); |
|
|
|
bool has_ip3_pol_b() const; |
|
const Weights_Layer& ip3_pol_b() const; |
|
Weights_Layer* mutable_ip3_pol_b(); |
|
|
|
bool has_ip4_pol_w() const; |
|
const Weights_Layer& ip4_pol_w() const; |
|
Weights_Layer* mutable_ip4_pol_w(); |
|
|
|
bool has_value() const; |
|
const Weights_ConvBlock& value() const; |
|
Weights_ConvBlock* mutable_value(); |
|
|
|
bool has_ip_val_w() const; |
|
const Weights_Layer& ip_val_w() const; |
|
Weights_Layer* mutable_ip_val_w(); |
|
|
|
bool has_ip_val_b() const; |
|
const Weights_Layer& ip_val_b() const; |
|
Weights_Layer* mutable_ip_val_b(); |
|
|
|
bool has_ip1_val_w() const; |
|
const Weights_Layer& ip1_val_w() const; |
|
Weights_Layer* mutable_ip1_val_w(); |
|
|
|
bool has_ip1_val_b() const; |
|
const Weights_Layer& ip1_val_b() const; |
|
Weights_Layer* mutable_ip1_val_b(); |
|
|
|
bool has_ip2_val_w() const; |
|
const Weights_Layer& ip2_val_w() const; |
|
Weights_Layer* mutable_ip2_val_w(); |
|
|
|
bool has_ip2_val_b() const; |
|
const Weights_Layer& ip2_val_b() const; |
|
Weights_Layer* mutable_ip2_val_b(); |
|
|
|
bool has_moves_left() const; |
|
const Weights_ConvBlock& moves_left() const; |
|
Weights_ConvBlock* mutable_moves_left(); |
|
|
|
bool has_ip_mov_w() const; |
|
const Weights_Layer& ip_mov_w() const; |
|
Weights_Layer* mutable_ip_mov_w(); |
|
|
|
bool has_ip_mov_b() const; |
|
const Weights_Layer& ip_mov_b() const; |
|
Weights_Layer* mutable_ip_mov_b(); |
|
|
|
bool has_ip1_mov_w() const; |
|
const Weights_Layer& ip1_mov_w() const; |
|
Weights_Layer* mutable_ip1_mov_w(); |
|
|
|
bool has_ip1_mov_b() const; |
|
const Weights_Layer& ip1_mov_b() const; |
|
Weights_Layer* mutable_ip1_mov_b(); |
|
|
|
bool has_ip2_mov_w() const; |
|
const Weights_Layer& ip2_mov_w() const; |
|
Weights_Layer* mutable_ip2_mov_w(); |
|
|
|
bool has_ip2_mov_b() const; |
|
const Weights_Layer& ip2_mov_b() const; |
|
Weights_Layer* mutable_ip2_mov_b(); |
|
|
|
bool has_smolgen_w() const; |
|
const Weights_Layer& smolgen_w() const; |
|
Weights_Layer* mutable_smolgen_w(); |
|
|
|
bool has_smolgen_b() const; |
|
const Weights_Layer& smolgen_b() const; |
|
Weights_Layer* mutable_smolgen_b(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
|
|
bool has_input_{}; |
|
Weights_ConvBlock input_{}; |
|
std::vector<Weights_Residual> residual_; |
|
bool has_ip_emb_w_{}; |
|
Weights_Layer ip_emb_w_{}; |
|
bool has_ip_emb_b_{}; |
|
Weights_Layer ip_emb_b_{}; |
|
bool has_ip_mult_gate_{}; |
|
Weights_Layer ip_mult_gate_{}; |
|
bool has_ip_add_gate_{}; |
|
Weights_Layer ip_add_gate_{}; |
|
std::vector<Weights_EncoderLayer> encoder_; |
|
bool has_headcount_{}; |
|
std::uint32_t headcount_{}; |
|
std::vector<Weights_EncoderLayer> pol_encoder_; |
|
bool has_pol_headcount_{}; |
|
std::uint32_t pol_headcount_{}; |
|
bool has_policy1_{}; |
|
Weights_ConvBlock policy1_{}; |
|
bool has_policy_{}; |
|
Weights_ConvBlock policy_{}; |
|
bool has_ip_pol_w_{}; |
|
Weights_Layer ip_pol_w_{}; |
|
bool has_ip_pol_b_{}; |
|
Weights_Layer ip_pol_b_{}; |
|
bool has_ip2_pol_w_{}; |
|
Weights_Layer ip2_pol_w_{}; |
|
bool has_ip2_pol_b_{}; |
|
Weights_Layer ip2_pol_b_{}; |
|
bool has_ip3_pol_w_{}; |
|
Weights_Layer ip3_pol_w_{}; |
|
bool has_ip3_pol_b_{}; |
|
Weights_Layer ip3_pol_b_{}; |
|
bool has_ip4_pol_w_{}; |
|
Weights_Layer ip4_pol_w_{}; |
|
bool has_value_{}; |
|
Weights_ConvBlock value_{}; |
|
bool has_ip_val_w_{}; |
|
Weights_Layer ip_val_w_{}; |
|
bool has_ip_val_b_{}; |
|
Weights_Layer ip_val_b_{}; |
|
bool has_ip1_val_w_{}; |
|
Weights_Layer ip1_val_w_{}; |
|
bool has_ip1_val_b_{}; |
|
Weights_Layer ip1_val_b_{}; |
|
bool has_ip2_val_w_{}; |
|
Weights_Layer ip2_val_w_{}; |
|
bool has_ip2_val_b_{}; |
|
Weights_Layer ip2_val_b_{}; |
|
bool has_moves_left_{}; |
|
Weights_ConvBlock moves_left_{}; |
|
bool has_ip_mov_w_{}; |
|
Weights_Layer ip_mov_w_{}; |
|
bool has_ip_mov_b_{}; |
|
Weights_Layer ip_mov_b_{}; |
|
bool has_ip1_mov_w_{}; |
|
Weights_Layer ip1_mov_w_{}; |
|
bool has_ip1_mov_b_{}; |
|
Weights_Layer ip1_mov_b_{}; |
|
bool has_ip2_mov_w_{}; |
|
Weights_Layer ip2_mov_w_{}; |
|
bool has_ip2_mov_b_{}; |
|
Weights_Layer ip2_mov_b_{}; |
|
bool has_smolgen_w_{}; |
|
Weights_Layer smolgen_w_{}; |
|
bool has_smolgen_b_{}; |
|
Weights_Layer smolgen_b_{}; |
|
}; |
|
|
|
class TrainingParams final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_training_steps() const; |
|
std::uint32_t training_steps() const; |
|
void set_training_steps(std::uint32_t val); |
|
|
|
bool has_learning_rate() const; |
|
float learning_rate() const; |
|
void set_learning_rate(float val); |
|
|
|
bool has_mse_loss() const; |
|
float mse_loss() const; |
|
void set_mse_loss(float val); |
|
|
|
bool has_policy_loss() const; |
|
float policy_loss() const; |
|
void set_policy_loss(float val); |
|
|
|
bool has_accuracy() const; |
|
float accuracy() const; |
|
void set_accuracy(float val); |
|
|
|
bool has_lc0_params() const; |
|
std::string_view lc0_params() const; |
|
void set_lc0_params(std::string_view val); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
void SetInt32(int field_id, std::uint32_t val) final; |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_training_steps_{}; |
|
std::uint32_t training_steps_{}; |
|
bool has_learning_rate_{}; |
|
float learning_rate_{}; |
|
bool has_mse_loss_{}; |
|
float mse_loss_{}; |
|
bool has_policy_loss_{}; |
|
float policy_loss_{}; |
|
bool has_accuracy_{}; |
|
float accuracy_{}; |
|
bool has_lc0_params_{}; |
|
std::string lc0_params_{}; |
|
}; |
|
|
|
class NetworkFormat final : public lczero::ProtoMessage { |
|
public: |
|
using InputFormat = NetworkFormat_InputFormat; |
|
static constexpr InputFormat INPUT_UNKNOWN = |
|
NetworkFormat_InputFormat_INPUT_UNKNOWN; |
|
static constexpr InputFormat INPUT_CLASSICAL_112_PLANE = |
|
NetworkFormat_InputFormat_INPUT_CLASSICAL_112_PLANE; |
|
static constexpr InputFormat INPUT_112_WITH_CASTLING_PLANE = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CASTLING_PLANE; |
|
static constexpr InputFormat INPUT_112_WITH_CANONICALIZATION = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION; |
|
static constexpr InputFormat INPUT_112_WITH_CANONICALIZATION_HECTOPLIES = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES; |
|
static constexpr InputFormat INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON; |
|
static constexpr InputFormat INPUT_112_WITH_CANONICALIZATION_V2 = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2; |
|
static constexpr InputFormat INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON = |
|
NetworkFormat_InputFormat_INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON; |
|
static constexpr std::array<InputFormat,8> InputFormat_AllValues = { |
|
INPUT_UNKNOWN, |
|
INPUT_CLASSICAL_112_PLANE, |
|
INPUT_112_WITH_CASTLING_PLANE, |
|
INPUT_112_WITH_CANONICALIZATION, |
|
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES, |
|
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON, |
|
INPUT_112_WITH_CANONICALIZATION_V2, |
|
INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON, |
|
}; |
|
static std::string InputFormat_Name(InputFormat val) { |
|
return NetworkFormat_InputFormat_Name(val); |
|
} |
|
using OutputFormat = NetworkFormat_OutputFormat; |
|
static constexpr OutputFormat OUTPUT_UNKNOWN = |
|
NetworkFormat_OutputFormat_OUTPUT_UNKNOWN; |
|
static constexpr OutputFormat OUTPUT_CLASSICAL = |
|
NetworkFormat_OutputFormat_OUTPUT_CLASSICAL; |
|
static constexpr OutputFormat OUTPUT_WDL = |
|
NetworkFormat_OutputFormat_OUTPUT_WDL; |
|
static constexpr std::array<OutputFormat,3> OutputFormat_AllValues = { |
|
OUTPUT_UNKNOWN, |
|
OUTPUT_CLASSICAL, |
|
OUTPUT_WDL, |
|
}; |
|
static std::string OutputFormat_Name(OutputFormat val) { |
|
return NetworkFormat_OutputFormat_Name(val); |
|
} |
|
using NetworkStructure = NetworkFormat_NetworkStructure; |
|
static constexpr NetworkStructure NETWORK_UNKNOWN = |
|
NetworkFormat_NetworkStructure_NETWORK_UNKNOWN; |
|
static constexpr NetworkStructure NETWORK_CLASSICAL = |
|
NetworkFormat_NetworkStructure_NETWORK_CLASSICAL; |
|
static constexpr NetworkStructure NETWORK_SE = |
|
NetworkFormat_NetworkStructure_NETWORK_SE; |
|
static constexpr NetworkStructure NETWORK_CLASSICAL_WITH_HEADFORMAT = |
|
NetworkFormat_NetworkStructure_NETWORK_CLASSICAL_WITH_HEADFORMAT; |
|
static constexpr NetworkStructure NETWORK_SE_WITH_HEADFORMAT = |
|
NetworkFormat_NetworkStructure_NETWORK_SE_WITH_HEADFORMAT; |
|
static constexpr NetworkStructure NETWORK_ONNX = |
|
NetworkFormat_NetworkStructure_NETWORK_ONNX; |
|
static constexpr NetworkStructure NETWORK_ATTENTIONBODY_WITH_HEADFORMAT = |
|
NetworkFormat_NetworkStructure_NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; |
|
static constexpr std::array<NetworkStructure,7> NetworkStructure_AllValues = { |
|
NETWORK_UNKNOWN, |
|
NETWORK_CLASSICAL, |
|
NETWORK_SE, |
|
NETWORK_CLASSICAL_WITH_HEADFORMAT, |
|
NETWORK_SE_WITH_HEADFORMAT, |
|
NETWORK_ONNX, |
|
NETWORK_ATTENTIONBODY_WITH_HEADFORMAT, |
|
}; |
|
static std::string NetworkStructure_Name(NetworkStructure val) { |
|
return NetworkFormat_NetworkStructure_Name(val); |
|
} |
|
using PolicyFormat = NetworkFormat_PolicyFormat; |
|
static constexpr PolicyFormat POLICY_UNKNOWN = |
|
NetworkFormat_PolicyFormat_POLICY_UNKNOWN; |
|
static constexpr PolicyFormat POLICY_CLASSICAL = |
|
NetworkFormat_PolicyFormat_POLICY_CLASSICAL; |
|
static constexpr PolicyFormat POLICY_CONVOLUTION = |
|
NetworkFormat_PolicyFormat_POLICY_CONVOLUTION; |
|
static constexpr PolicyFormat POLICY_ATTENTION = |
|
NetworkFormat_PolicyFormat_POLICY_ATTENTION; |
|
static constexpr std::array<PolicyFormat,4> PolicyFormat_AllValues = { |
|
POLICY_UNKNOWN, |
|
POLICY_CLASSICAL, |
|
POLICY_CONVOLUTION, |
|
POLICY_ATTENTION, |
|
}; |
|
static std::string PolicyFormat_Name(PolicyFormat val) { |
|
return NetworkFormat_PolicyFormat_Name(val); |
|
} |
|
using ValueFormat = NetworkFormat_ValueFormat; |
|
static constexpr ValueFormat VALUE_UNKNOWN = |
|
NetworkFormat_ValueFormat_VALUE_UNKNOWN; |
|
static constexpr ValueFormat VALUE_CLASSICAL = |
|
NetworkFormat_ValueFormat_VALUE_CLASSICAL; |
|
static constexpr ValueFormat VALUE_WDL = |
|
NetworkFormat_ValueFormat_VALUE_WDL; |
|
static constexpr ValueFormat VALUE_PARAM = |
|
NetworkFormat_ValueFormat_VALUE_PARAM; |
|
static constexpr std::array<ValueFormat,4> ValueFormat_AllValues = { |
|
VALUE_UNKNOWN, |
|
VALUE_CLASSICAL, |
|
VALUE_WDL, |
|
VALUE_PARAM, |
|
}; |
|
static std::string ValueFormat_Name(ValueFormat val) { |
|
return NetworkFormat_ValueFormat_Name(val); |
|
} |
|
using MovesLeftFormat = NetworkFormat_MovesLeftFormat; |
|
static constexpr MovesLeftFormat MOVES_LEFT_NONE = |
|
NetworkFormat_MovesLeftFormat_MOVES_LEFT_NONE; |
|
static constexpr MovesLeftFormat MOVES_LEFT_V1 = |
|
NetworkFormat_MovesLeftFormat_MOVES_LEFT_V1; |
|
static constexpr std::array<MovesLeftFormat,2> MovesLeftFormat_AllValues = { |
|
MOVES_LEFT_NONE, |
|
MOVES_LEFT_V1, |
|
}; |
|
static std::string MovesLeftFormat_Name(MovesLeftFormat val) { |
|
return NetworkFormat_MovesLeftFormat_Name(val); |
|
} |
|
using ActivationFunction = NetworkFormat_ActivationFunction; |
|
static constexpr ActivationFunction ACTIVATION_DEFAULT = |
|
NetworkFormat_ActivationFunction_ACTIVATION_DEFAULT; |
|
static constexpr ActivationFunction ACTIVATION_MISH = |
|
NetworkFormat_ActivationFunction_ACTIVATION_MISH; |
|
static constexpr ActivationFunction ACTIVATION_RELU = |
|
NetworkFormat_ActivationFunction_ACTIVATION_RELU; |
|
static constexpr ActivationFunction ACTIVATION_NONE = |
|
NetworkFormat_ActivationFunction_ACTIVATION_NONE; |
|
static constexpr ActivationFunction ACTIVATION_TANH = |
|
NetworkFormat_ActivationFunction_ACTIVATION_TANH; |
|
static constexpr ActivationFunction ACTIVATION_SIGMOID = |
|
NetworkFormat_ActivationFunction_ACTIVATION_SIGMOID; |
|
static constexpr ActivationFunction ACTIVATION_SELU = |
|
NetworkFormat_ActivationFunction_ACTIVATION_SELU; |
|
static constexpr ActivationFunction ACTIVATION_SWISH = |
|
NetworkFormat_ActivationFunction_ACTIVATION_SWISH; |
|
static constexpr ActivationFunction ACTIVATION_RELU_2 = |
|
NetworkFormat_ActivationFunction_ACTIVATION_RELU_2; |
|
static constexpr ActivationFunction ACTIVATION_SOFTMAX = |
|
NetworkFormat_ActivationFunction_ACTIVATION_SOFTMAX; |
|
static constexpr std::array<ActivationFunction,10> ActivationFunction_AllValues = { |
|
ACTIVATION_DEFAULT, |
|
ACTIVATION_MISH, |
|
ACTIVATION_RELU, |
|
ACTIVATION_NONE, |
|
ACTIVATION_TANH, |
|
ACTIVATION_SIGMOID, |
|
ACTIVATION_SELU, |
|
ACTIVATION_SWISH, |
|
ACTIVATION_RELU_2, |
|
ACTIVATION_SOFTMAX, |
|
}; |
|
static std::string ActivationFunction_Name(ActivationFunction val) { |
|
return NetworkFormat_ActivationFunction_Name(val); |
|
} |
|
using DefaultActivation = NetworkFormat_DefaultActivation; |
|
static constexpr DefaultActivation DEFAULT_ACTIVATION_RELU = |
|
NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_RELU; |
|
static constexpr DefaultActivation DEFAULT_ACTIVATION_MISH = |
|
NetworkFormat_DefaultActivation_DEFAULT_ACTIVATION_MISH; |
|
static constexpr std::array<DefaultActivation,2> DefaultActivation_AllValues = { |
|
DEFAULT_ACTIVATION_RELU, |
|
DEFAULT_ACTIVATION_MISH, |
|
}; |
|
static std::string DefaultActivation_Name(DefaultActivation val) { |
|
return NetworkFormat_DefaultActivation_Name(val); |
|
} |
|
|
|
bool has_input() const; |
|
NetworkFormat_InputFormat input() const; |
|
void set_input(NetworkFormat_InputFormat val); |
|
|
|
bool has_output() const; |
|
NetworkFormat_OutputFormat output() const; |
|
void set_output(NetworkFormat_OutputFormat val); |
|
|
|
bool has_network() const; |
|
NetworkFormat_NetworkStructure network() const; |
|
void set_network(NetworkFormat_NetworkStructure val); |
|
|
|
bool has_policy() const; |
|
NetworkFormat_PolicyFormat policy() const; |
|
void set_policy(NetworkFormat_PolicyFormat val); |
|
|
|
bool has_value() const; |
|
NetworkFormat_ValueFormat value() const; |
|
void set_value(NetworkFormat_ValueFormat val); |
|
|
|
bool has_moves_left() const; |
|
NetworkFormat_MovesLeftFormat moves_left() const; |
|
void set_moves_left(NetworkFormat_MovesLeftFormat val); |
|
|
|
bool has_default_activation() const; |
|
NetworkFormat_DefaultActivation default_activation() const; |
|
void set_default_activation(NetworkFormat_DefaultActivation val); |
|
|
|
bool has_smolgen_activation() const; |
|
NetworkFormat_ActivationFunction smolgen_activation() const; |
|
void set_smolgen_activation(NetworkFormat_ActivationFunction val); |
|
|
|
bool has_ffn_activation() const; |
|
NetworkFormat_ActivationFunction ffn_activation() const; |
|
void set_ffn_activation(NetworkFormat_ActivationFunction val); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
|
|
bool has_input_{}; |
|
NetworkFormat_InputFormat input_{}; |
|
bool has_output_{}; |
|
NetworkFormat_OutputFormat output_{}; |
|
bool has_network_{}; |
|
NetworkFormat_NetworkStructure network_{}; |
|
bool has_policy_{}; |
|
NetworkFormat_PolicyFormat policy_{}; |
|
bool has_value_{}; |
|
NetworkFormat_ValueFormat value_{}; |
|
bool has_moves_left_{}; |
|
NetworkFormat_MovesLeftFormat moves_left_{}; |
|
bool has_default_activation_{}; |
|
NetworkFormat_DefaultActivation default_activation_{}; |
|
bool has_smolgen_activation_{}; |
|
NetworkFormat_ActivationFunction smolgen_activation_{}; |
|
bool has_ffn_activation_{}; |
|
NetworkFormat_ActivationFunction ffn_activation_{}; |
|
}; |
|
|
|
class Format final : public lczero::ProtoMessage { |
|
public: |
|
using Encoding = Format_Encoding; |
|
static constexpr Encoding UNKNOWN = |
|
Format_Encoding_UNKNOWN; |
|
static constexpr Encoding LINEAR16 = |
|
Format_Encoding_LINEAR16; |
|
static constexpr std::array<Encoding,2> Encoding_AllValues = { |
|
UNKNOWN, |
|
LINEAR16, |
|
}; |
|
static std::string Encoding_Name(Encoding val) { |
|
return Format_Encoding_Name(val); |
|
} |
|
|
|
bool has_weights_encoding() const; |
|
Format_Encoding weights_encoding() const; |
|
void set_weights_encoding(Format_Encoding val); |
|
|
|
bool has_network_format() const; |
|
const NetworkFormat& network_format() const; |
|
NetworkFormat* mutable_network_format(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_weights_encoding_{}; |
|
Format_Encoding weights_encoding_{}; |
|
bool has_network_format_{}; |
|
NetworkFormat network_format_{}; |
|
}; |
|
|
|
class OnnxModel final : public lczero::ProtoMessage { |
|
public: |
|
using DataType = OnnxModel_DataType; |
|
static constexpr DataType UNKNOWN_DATATYPE = |
|
OnnxModel_DataType_UNKNOWN_DATATYPE; |
|
static constexpr DataType FLOAT = |
|
OnnxModel_DataType_FLOAT; |
|
static constexpr DataType FLOAT16 = |
|
OnnxModel_DataType_FLOAT16; |
|
static constexpr DataType BFLOAT16 = |
|
OnnxModel_DataType_BFLOAT16; |
|
static constexpr std::array<DataType,4> DataType_AllValues = { |
|
UNKNOWN_DATATYPE, |
|
FLOAT, |
|
FLOAT16, |
|
BFLOAT16, |
|
}; |
|
static std::string DataType_Name(DataType val) { |
|
return OnnxModel_DataType_Name(val); |
|
} |
|
|
|
bool has_model() const; |
|
std::string_view model() const; |
|
void set_model(std::string_view val); |
|
|
|
bool has_data_type() const; |
|
OnnxModel_DataType data_type() const; |
|
void set_data_type(OnnxModel_DataType val); |
|
|
|
bool has_input_planes() const; |
|
std::string_view input_planes() const; |
|
void set_input_planes(std::string_view val); |
|
|
|
bool has_output_value() const; |
|
std::string_view output_value() const; |
|
void set_output_value(std::string_view val); |
|
|
|
bool has_output_wdl() const; |
|
std::string_view output_wdl() const; |
|
void set_output_wdl(std::string_view val); |
|
|
|
bool has_output_policy() const; |
|
std::string_view output_policy() const; |
|
void set_output_policy(std::string_view val); |
|
|
|
bool has_output_mlh() const; |
|
std::string_view output_mlh() const; |
|
void set_output_mlh(std::string_view val); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetString(int field_id, std::string_view val) final; |
|
void SetVarInt(int field_id, std::uint64_t val) final; |
|
|
|
bool has_model_{}; |
|
std::string model_{}; |
|
bool has_data_type_{}; |
|
OnnxModel_DataType data_type_{}; |
|
bool has_input_planes_{}; |
|
std::string input_planes_{}; |
|
bool has_output_value_{}; |
|
std::string output_value_{}; |
|
bool has_output_wdl_{}; |
|
std::string output_wdl_{}; |
|
bool has_output_policy_{}; |
|
std::string output_policy_{}; |
|
bool has_output_mlh_{}; |
|
std::string output_mlh_{}; |
|
}; |
|
|
|
class Net final : public lczero::ProtoMessage { |
|
public: |
|
|
|
bool has_magic() const; |
|
std::uint32_t magic() const; |
|
void set_magic(std::uint32_t val); |
|
|
|
bool has_license() const; |
|
std::string_view license() const; |
|
void set_license(std::string_view val); |
|
|
|
bool has_min_version() const; |
|
const EngineVersion& min_version() const; |
|
EngineVersion* mutable_min_version(); |
|
|
|
bool has_format() const; |
|
const Format& format() const; |
|
Format* mutable_format(); |
|
|
|
bool has_training_params() const; |
|
const TrainingParams& training_params() const; |
|
TrainingParams* mutable_training_params(); |
|
|
|
bool has_weights() const; |
|
const Weights& weights() const; |
|
Weights* mutable_weights(); |
|
|
|
bool has_onnx_model() const; |
|
const OnnxModel& onnx_model() const; |
|
OnnxModel* mutable_onnx_model(); |
|
|
|
std::string OutputAsString() const final; |
|
std::string OutputAsJson() const final; |
|
void Clear() final; |
|
|
|
private: |
|
void SetInt32(int field_id, std::uint32_t val) final; |
|
void SetString(int field_id, std::string_view val) final; |
|
|
|
bool has_magic_{}; |
|
std::uint32_t magic_{}; |
|
bool has_license_{}; |
|
std::string license_{}; |
|
bool has_min_version_{}; |
|
EngineVersion min_version_{}; |
|
bool has_format_{}; |
|
Format format_{}; |
|
bool has_training_params_{}; |
|
TrainingParams training_params_{}; |
|
bool has_weights_{}; |
|
Weights weights_{}; |
|
bool has_onnx_model_{}; |
|
OnnxModel onnx_model_{}; |
|
}; |
|
|
|
|
|
|
|
inline std::string EngineVersion::OutputAsString() const { |
|
std::string out; |
|
if (has_major_) AppendVarInt(1, major_, &out); |
|
if (has_minor_) AppendVarInt(2, minor_, &out); |
|
if (has_patch_) AppendVarInt(3, patch_, &out); |
|
return out; |
|
} |
|
inline std::string EngineVersion::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_major_) AppendJsonField("major", major_, &first, &out); |
|
if (has_minor_) AppendJsonField("minor", minor_, &first, &out); |
|
if (has_patch_) AppendJsonField("patch", patch_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and all values are
// value-initialized.
inline void EngineVersion::Clear() {
  has_major_ = false;
  has_minor_ = false;
  has_patch_ = false;

  major_ = {};
  minor_ = {};
  patch_ = {};
}
|
inline void EngineVersion::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 1: set_major(static_cast<std::uint32_t>(val)); break; |
|
case 2: set_minor(static_cast<std::uint32_t>(val)); break; |
|
case 3: set_patch(static_cast<std::uint32_t>(val)); break; |
|
} |
|
} |
|
inline bool EngineVersion::has_major() const { return has_major_; } |
|
inline std::uint32_t EngineVersion::major() const { return major_; } |
|
inline void EngineVersion::set_major(std::uint32_t val) { |
|
has_major_ = true; |
|
major_ = val; |
|
} |
|
inline bool EngineVersion::has_minor() const { return has_minor_; } |
|
inline std::uint32_t EngineVersion::minor() const { return minor_; } |
|
inline void EngineVersion::set_minor(std::uint32_t val) { |
|
has_minor_ = true; |
|
minor_ = val; |
|
} |
|
inline bool EngineVersion::has_patch() const { return has_patch_; } |
|
inline std::uint32_t EngineVersion::patch() const { return patch_; } |
|
inline void EngineVersion::set_patch(std::uint32_t val) { |
|
has_patch_ = true; |
|
patch_ = val; |
|
} |
|
inline std::string Weights_Layer::OutputAsString() const { |
|
std::string out; |
|
if (has_min_val_) AppendInt32(1, bit_cast<std::uint32_t>(min_val_), &out); |
|
if (has_max_val_) AppendInt32(2, bit_cast<std::uint32_t>(max_val_), &out); |
|
if (has_params_) AppendString(3, params_, &out); |
|
return out; |
|
} |
|
inline std::string Weights_Layer::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_min_val_) AppendJsonField("min_val", min_val_, &first, &out); |
|
if (has_max_val_) AppendJsonField("max_val", max_val_, &first, &out); |
|
if (has_params_) AppendJsonField("params", params_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and all values are
// value-initialized.
inline void Weights_Layer::Clear() {
  has_min_val_ = false;
  has_max_val_ = false;
  has_params_ = false;

  min_val_ = {};
  max_val_ = {};
  params_ = {};
}
|
inline void Weights_Layer::SetInt32(int field_id, std::uint32_t val) { |
|
switch (field_id) { |
|
case 1: set_min_val(bit_cast<float>(val)); break; |
|
case 2: set_max_val(bit_cast<float>(val)); break; |
|
} |
|
} |
|
inline void Weights_Layer::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 3: set_params(val); break; |
|
} |
|
} |
|
inline bool Weights_Layer::has_min_val() const { return has_min_val_; } |
|
inline float Weights_Layer::min_val() const { return min_val_; } |
|
inline void Weights_Layer::set_min_val(float val) { |
|
has_min_val_ = true; |
|
min_val_ = val; |
|
} |
|
inline bool Weights_Layer::has_max_val() const { return has_max_val_; } |
|
inline float Weights_Layer::max_val() const { return max_val_; } |
|
inline void Weights_Layer::set_max_val(float val) { |
|
has_max_val_ = true; |
|
max_val_ = val; |
|
} |
|
inline bool Weights_Layer::has_params() const { return has_params_; } |
|
inline std::string_view Weights_Layer::params() const { return params_; } |
|
inline void Weights_Layer::set_params(std::string_view val) { |
|
has_params_ = true; |
|
params_ = val; |
|
} |
|
inline std::string Weights_ConvBlock::OutputAsString() const { |
|
std::string out; |
|
if (has_weights_) AppendString(1, weights_.OutputAsString(), &out); |
|
if (has_biases_) AppendString(2, biases_.OutputAsString(), &out); |
|
if (has_bn_means_) AppendString(3, bn_means_.OutputAsString(), &out); |
|
if (has_bn_stddivs_) AppendString(4, bn_stddivs_.OutputAsString(), &out); |
|
if (has_bn_gammas_) AppendString(5, bn_gammas_.OutputAsString(), &out); |
|
if (has_bn_betas_) AppendString(6, bn_betas_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_ConvBlock::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_weights_) AppendJsonField("weights", weights_, &first, &out); |
|
if (has_biases_) AppendJsonField("biases", biases_, &first, &out); |
|
if (has_bn_means_) AppendJsonField("bn_means", bn_means_, &first, &out); |
|
if (has_bn_stddivs_) AppendJsonField("bn_stddivs", bn_stddivs_, &first, &out); |
|
if (has_bn_gammas_) AppendJsonField("bn_gammas", bn_gammas_, &first, &out); |
|
if (has_bn_betas_) AppendJsonField("bn_betas", bn_betas_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and every sub-message is
// replaced by a value-initialized one.
inline void Weights_ConvBlock::Clear() {
  has_weights_ = false;
  has_biases_ = false;
  has_bn_means_ = false;
  has_bn_stddivs_ = false;
  has_bn_gammas_ = false;
  has_bn_betas_ = false;

  weights_ = {};
  biases_ = {};
  bn_means_ = {};
  bn_stddivs_ = {};
  bn_gammas_ = {};
  bn_betas_ = {};
}
|
inline void Weights_ConvBlock::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_weights()->MergeFromString(val); break; |
|
case 2: mutable_biases()->MergeFromString(val); break; |
|
case 3: mutable_bn_means()->MergeFromString(val); break; |
|
case 4: mutable_bn_stddivs()->MergeFromString(val); break; |
|
case 5: mutable_bn_gammas()->MergeFromString(val); break; |
|
case 6: mutable_bn_betas()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_ConvBlock::has_weights() const { return has_weights_; } |
|
inline const Weights_Layer& Weights_ConvBlock::weights() const { return weights_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_weights() { |
|
has_weights_ = true; |
|
return &weights_; |
|
} |
|
inline bool Weights_ConvBlock::has_biases() const { return has_biases_; } |
|
inline const Weights_Layer& Weights_ConvBlock::biases() const { return biases_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_biases() { |
|
has_biases_ = true; |
|
return &biases_; |
|
} |
|
inline bool Weights_ConvBlock::has_bn_means() const { return has_bn_means_; } |
|
inline const Weights_Layer& Weights_ConvBlock::bn_means() const { return bn_means_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_bn_means() { |
|
has_bn_means_ = true; |
|
return &bn_means_; |
|
} |
|
inline bool Weights_ConvBlock::has_bn_stddivs() const { return has_bn_stddivs_; } |
|
inline const Weights_Layer& Weights_ConvBlock::bn_stddivs() const { return bn_stddivs_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_bn_stddivs() { |
|
has_bn_stddivs_ = true; |
|
return &bn_stddivs_; |
|
} |
|
inline bool Weights_ConvBlock::has_bn_gammas() const { return has_bn_gammas_; } |
|
inline const Weights_Layer& Weights_ConvBlock::bn_gammas() const { return bn_gammas_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_bn_gammas() { |
|
has_bn_gammas_ = true; |
|
return &bn_gammas_; |
|
} |
|
inline bool Weights_ConvBlock::has_bn_betas() const { return has_bn_betas_; } |
|
inline const Weights_Layer& Weights_ConvBlock::bn_betas() const { return bn_betas_; } |
|
inline Weights_Layer* Weights_ConvBlock::mutable_bn_betas() { |
|
has_bn_betas_ = true; |
|
return &bn_betas_; |
|
} |
|
inline std::string Weights_SEunit::OutputAsString() const { |
|
std::string out; |
|
if (has_w1_) AppendString(1, w1_.OutputAsString(), &out); |
|
if (has_b1_) AppendString(2, b1_.OutputAsString(), &out); |
|
if (has_w2_) AppendString(3, w2_.OutputAsString(), &out); |
|
if (has_b2_) AppendString(4, b2_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_SEunit::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_w1_) AppendJsonField("w1", w1_, &first, &out); |
|
if (has_b1_) AppendJsonField("b1", b1_, &first, &out); |
|
if (has_w2_) AppendJsonField("w2", w2_, &first, &out); |
|
if (has_b2_) AppendJsonField("b2", b2_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and every sub-message is
// replaced by a value-initialized one.
inline void Weights_SEunit::Clear() {
  has_w1_ = false;
  has_b1_ = false;
  has_w2_ = false;
  has_b2_ = false;

  w1_ = {};
  b1_ = {};
  w2_ = {};
  b2_ = {};
}
|
inline void Weights_SEunit::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_w1()->MergeFromString(val); break; |
|
case 2: mutable_b1()->MergeFromString(val); break; |
|
case 3: mutable_w2()->MergeFromString(val); break; |
|
case 4: mutable_b2()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_SEunit::has_w1() const { return has_w1_; } |
|
inline const Weights_Layer& Weights_SEunit::w1() const { return w1_; } |
|
inline Weights_Layer* Weights_SEunit::mutable_w1() { |
|
has_w1_ = true; |
|
return &w1_; |
|
} |
|
inline bool Weights_SEunit::has_b1() const { return has_b1_; } |
|
inline const Weights_Layer& Weights_SEunit::b1() const { return b1_; } |
|
inline Weights_Layer* Weights_SEunit::mutable_b1() { |
|
has_b1_ = true; |
|
return &b1_; |
|
} |
|
inline bool Weights_SEunit::has_w2() const { return has_w2_; } |
|
inline const Weights_Layer& Weights_SEunit::w2() const { return w2_; } |
|
inline Weights_Layer* Weights_SEunit::mutable_w2() { |
|
has_w2_ = true; |
|
return &w2_; |
|
} |
|
inline bool Weights_SEunit::has_b2() const { return has_b2_; } |
|
inline const Weights_Layer& Weights_SEunit::b2() const { return b2_; } |
|
inline Weights_Layer* Weights_SEunit::mutable_b2() { |
|
has_b2_ = true; |
|
return &b2_; |
|
} |
|
inline std::string Weights_Residual::OutputAsString() const { |
|
std::string out; |
|
if (has_conv1_) AppendString(1, conv1_.OutputAsString(), &out); |
|
if (has_conv2_) AppendString(2, conv2_.OutputAsString(), &out); |
|
if (has_se_) AppendString(3, se_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_Residual::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_conv1_) AppendJsonField("conv1", conv1_, &first, &out); |
|
if (has_conv2_) AppendJsonField("conv2", conv2_, &first, &out); |
|
if (has_se_) AppendJsonField("se", se_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and every sub-message is
// replaced by a value-initialized one.
inline void Weights_Residual::Clear() {
  has_conv1_ = false;
  has_conv2_ = false;
  has_se_ = false;

  conv1_ = {};
  conv2_ = {};
  se_ = {};
}
|
inline void Weights_Residual::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_conv1()->MergeFromString(val); break; |
|
case 2: mutable_conv2()->MergeFromString(val); break; |
|
case 3: mutable_se()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_Residual::has_conv1() const { return has_conv1_; } |
|
inline const Weights_ConvBlock& Weights_Residual::conv1() const { return conv1_; } |
|
inline Weights_ConvBlock* Weights_Residual::mutable_conv1() { |
|
has_conv1_ = true; |
|
return &conv1_; |
|
} |
|
inline bool Weights_Residual::has_conv2() const { return has_conv2_; } |
|
inline const Weights_ConvBlock& Weights_Residual::conv2() const { return conv2_; } |
|
inline Weights_ConvBlock* Weights_Residual::mutable_conv2() { |
|
has_conv2_ = true; |
|
return &conv2_; |
|
} |
|
inline bool Weights_Residual::has_se() const { return has_se_; } |
|
inline const Weights_SEunit& Weights_Residual::se() const { return se_; } |
|
inline Weights_SEunit* Weights_Residual::mutable_se() { |
|
has_se_ = true; |
|
return &se_; |
|
} |
|
inline std::string Weights_Smolgen::OutputAsString() const { |
|
std::string out; |
|
if (has_compress_) AppendString(1, compress_.OutputAsString(), &out); |
|
if (has_dense1_w_) AppendString(2, dense1_w_.OutputAsString(), &out); |
|
if (has_dense1_b_) AppendString(3, dense1_b_.OutputAsString(), &out); |
|
if (has_ln1_gammas_) AppendString(4, ln1_gammas_.OutputAsString(), &out); |
|
if (has_ln1_betas_) AppendString(5, ln1_betas_.OutputAsString(), &out); |
|
if (has_dense2_w_) AppendString(6, dense2_w_.OutputAsString(), &out); |
|
if (has_dense2_b_) AppendString(7, dense2_b_.OutputAsString(), &out); |
|
if (has_ln2_gammas_) AppendString(8, ln2_gammas_.OutputAsString(), &out); |
|
if (has_ln2_betas_) AppendString(9, ln2_betas_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_Smolgen::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_compress_) AppendJsonField("compress", compress_, &first, &out); |
|
if (has_dense1_w_) AppendJsonField("dense1_w", dense1_w_, &first, &out); |
|
if (has_dense1_b_) AppendJsonField("dense1_b", dense1_b_, &first, &out); |
|
if (has_ln1_gammas_) AppendJsonField("ln1_gammas", ln1_gammas_, &first, &out); |
|
if (has_ln1_betas_) AppendJsonField("ln1_betas", ln1_betas_, &first, &out); |
|
if (has_dense2_w_) AppendJsonField("dense2_w", dense2_w_, &first, &out); |
|
if (has_dense2_b_) AppendJsonField("dense2_b", dense2_b_, &first, &out); |
|
if (has_ln2_gammas_) AppendJsonField("ln2_gammas", ln2_gammas_, &first, &out); |
|
if (has_ln2_betas_) AppendJsonField("ln2_betas", ln2_betas_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and every sub-message is
// replaced by a value-initialized one.
inline void Weights_Smolgen::Clear() {
  has_compress_ = false;
  has_dense1_w_ = false;
  has_dense1_b_ = false;
  has_ln1_gammas_ = false;
  has_ln1_betas_ = false;
  has_dense2_w_ = false;
  has_dense2_b_ = false;
  has_ln2_gammas_ = false;
  has_ln2_betas_ = false;

  compress_ = {};
  dense1_w_ = {};
  dense1_b_ = {};
  ln1_gammas_ = {};
  ln1_betas_ = {};
  dense2_w_ = {};
  dense2_b_ = {};
  ln2_gammas_ = {};
  ln2_betas_ = {};
}
|
inline void Weights_Smolgen::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_compress()->MergeFromString(val); break; |
|
case 2: mutable_dense1_w()->MergeFromString(val); break; |
|
case 3: mutable_dense1_b()->MergeFromString(val); break; |
|
case 4: mutable_ln1_gammas()->MergeFromString(val); break; |
|
case 5: mutable_ln1_betas()->MergeFromString(val); break; |
|
case 6: mutable_dense2_w()->MergeFromString(val); break; |
|
case 7: mutable_dense2_b()->MergeFromString(val); break; |
|
case 8: mutable_ln2_gammas()->MergeFromString(val); break; |
|
case 9: mutable_ln2_betas()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_Smolgen::has_compress() const { return has_compress_; } |
|
inline const Weights_Layer& Weights_Smolgen::compress() const { return compress_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_compress() { |
|
has_compress_ = true; |
|
return &compress_; |
|
} |
|
inline bool Weights_Smolgen::has_dense1_w() const { return has_dense1_w_; } |
|
inline const Weights_Layer& Weights_Smolgen::dense1_w() const { return dense1_w_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_dense1_w() { |
|
has_dense1_w_ = true; |
|
return &dense1_w_; |
|
} |
|
inline bool Weights_Smolgen::has_dense1_b() const { return has_dense1_b_; } |
|
inline const Weights_Layer& Weights_Smolgen::dense1_b() const { return dense1_b_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_dense1_b() { |
|
has_dense1_b_ = true; |
|
return &dense1_b_; |
|
} |
|
inline bool Weights_Smolgen::has_ln1_gammas() const { return has_ln1_gammas_; } |
|
inline const Weights_Layer& Weights_Smolgen::ln1_gammas() const { return ln1_gammas_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_ln1_gammas() { |
|
has_ln1_gammas_ = true; |
|
return &ln1_gammas_; |
|
} |
|
inline bool Weights_Smolgen::has_ln1_betas() const { return has_ln1_betas_; } |
|
inline const Weights_Layer& Weights_Smolgen::ln1_betas() const { return ln1_betas_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_ln1_betas() { |
|
has_ln1_betas_ = true; |
|
return &ln1_betas_; |
|
} |
|
inline bool Weights_Smolgen::has_dense2_w() const { return has_dense2_w_; } |
|
inline const Weights_Layer& Weights_Smolgen::dense2_w() const { return dense2_w_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_dense2_w() { |
|
has_dense2_w_ = true; |
|
return &dense2_w_; |
|
} |
|
inline bool Weights_Smolgen::has_dense2_b() const { return has_dense2_b_; } |
|
inline const Weights_Layer& Weights_Smolgen::dense2_b() const { return dense2_b_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_dense2_b() { |
|
has_dense2_b_ = true; |
|
return &dense2_b_; |
|
} |
|
inline bool Weights_Smolgen::has_ln2_gammas() const { return has_ln2_gammas_; } |
|
inline const Weights_Layer& Weights_Smolgen::ln2_gammas() const { return ln2_gammas_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_ln2_gammas() { |
|
has_ln2_gammas_ = true; |
|
return &ln2_gammas_; |
|
} |
|
inline bool Weights_Smolgen::has_ln2_betas() const { return has_ln2_betas_; } |
|
inline const Weights_Layer& Weights_Smolgen::ln2_betas() const { return ln2_betas_; } |
|
inline Weights_Layer* Weights_Smolgen::mutable_ln2_betas() { |
|
has_ln2_betas_ = true; |
|
return &ln2_betas_; |
|
} |
|
inline std::string Weights_MHA::OutputAsString() const { |
|
std::string out; |
|
if (has_q_w_) AppendString(1, q_w_.OutputAsString(), &out); |
|
if (has_q_b_) AppendString(2, q_b_.OutputAsString(), &out); |
|
if (has_k_w_) AppendString(3, k_w_.OutputAsString(), &out); |
|
if (has_k_b_) AppendString(4, k_b_.OutputAsString(), &out); |
|
if (has_v_w_) AppendString(5, v_w_.OutputAsString(), &out); |
|
if (has_v_b_) AppendString(6, v_b_.OutputAsString(), &out); |
|
if (has_dense_w_) AppendString(7, dense_w_.OutputAsString(), &out); |
|
if (has_dense_b_) AppendString(8, dense_b_.OutputAsString(), &out); |
|
if (has_smolgen_) AppendString(9, smolgen_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_MHA::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_q_w_) AppendJsonField("q_w", q_w_, &first, &out); |
|
if (has_q_b_) AppendJsonField("q_b", q_b_, &first, &out); |
|
if (has_k_w_) AppendJsonField("k_w", k_w_, &first, &out); |
|
if (has_k_b_) AppendJsonField("k_b", k_b_, &first, &out); |
|
if (has_v_w_) AppendJsonField("v_w", v_w_, &first, &out); |
|
if (has_v_b_) AppendJsonField("v_b", v_b_, &first, &out); |
|
if (has_dense_w_) AppendJsonField("dense_w", dense_w_, &first, &out); |
|
if (has_dense_b_) AppendJsonField("dense_b", dense_b_, &first, &out); |
|
if (has_smolgen_) AppendJsonField("smolgen", smolgen_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: all presence flags become false and every sub-message is
// replaced by a value-initialized one.
inline void Weights_MHA::Clear() {
  has_q_w_ = false;
  has_q_b_ = false;
  has_k_w_ = false;
  has_k_b_ = false;
  has_v_w_ = false;
  has_v_b_ = false;
  has_dense_w_ = false;
  has_dense_b_ = false;
  has_smolgen_ = false;

  q_w_ = {};
  q_b_ = {};
  k_w_ = {};
  k_b_ = {};
  v_w_ = {};
  v_b_ = {};
  dense_w_ = {};
  dense_b_ = {};
  smolgen_ = {};
}
|
inline void Weights_MHA::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_q_w()->MergeFromString(val); break; |
|
case 2: mutable_q_b()->MergeFromString(val); break; |
|
case 3: mutable_k_w()->MergeFromString(val); break; |
|
case 4: mutable_k_b()->MergeFromString(val); break; |
|
case 5: mutable_v_w()->MergeFromString(val); break; |
|
case 6: mutable_v_b()->MergeFromString(val); break; |
|
case 7: mutable_dense_w()->MergeFromString(val); break; |
|
case 8: mutable_dense_b()->MergeFromString(val); break; |
|
case 9: mutable_smolgen()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_MHA::has_q_w() const { return has_q_w_; } |
|
inline const Weights_Layer& Weights_MHA::q_w() const { return q_w_; } |
|
inline Weights_Layer* Weights_MHA::mutable_q_w() { |
|
has_q_w_ = true; |
|
return &q_w_; |
|
} |
|
inline bool Weights_MHA::has_q_b() const { return has_q_b_; } |
|
inline const Weights_Layer& Weights_MHA::q_b() const { return q_b_; } |
|
inline Weights_Layer* Weights_MHA::mutable_q_b() { |
|
has_q_b_ = true; |
|
return &q_b_; |
|
} |
|
inline bool Weights_MHA::has_k_w() const { return has_k_w_; } |
|
inline const Weights_Layer& Weights_MHA::k_w() const { return k_w_; } |
|
inline Weights_Layer* Weights_MHA::mutable_k_w() { |
|
has_k_w_ = true; |
|
return &k_w_; |
|
} |
|
inline bool Weights_MHA::has_k_b() const { return has_k_b_; } |
|
inline const Weights_Layer& Weights_MHA::k_b() const { return k_b_; } |
|
inline Weights_Layer* Weights_MHA::mutable_k_b() { |
|
has_k_b_ = true; |
|
return &k_b_; |
|
} |
|
inline bool Weights_MHA::has_v_w() const { return has_v_w_; } |
|
inline const Weights_Layer& Weights_MHA::v_w() const { return v_w_; } |
|
inline Weights_Layer* Weights_MHA::mutable_v_w() { |
|
has_v_w_ = true; |
|
return &v_w_; |
|
} |
|
inline bool Weights_MHA::has_v_b() const { return has_v_b_; } |
|
inline const Weights_Layer& Weights_MHA::v_b() const { return v_b_; } |
|
inline Weights_Layer* Weights_MHA::mutable_v_b() { |
|
has_v_b_ = true; |
|
return &v_b_; |
|
} |
|
inline bool Weights_MHA::has_dense_w() const { return has_dense_w_; } |
|
inline const Weights_Layer& Weights_MHA::dense_w() const { return dense_w_; } |
|
inline Weights_Layer* Weights_MHA::mutable_dense_w() { |
|
has_dense_w_ = true; |
|
return &dense_w_; |
|
} |
|
inline bool Weights_MHA::has_dense_b() const { return has_dense_b_; } |
|
inline const Weights_Layer& Weights_MHA::dense_b() const { return dense_b_; } |
|
inline Weights_Layer* Weights_MHA::mutable_dense_b() { |
|
has_dense_b_ = true; |
|
return &dense_b_; |
|
} |
|
inline bool Weights_MHA::has_smolgen() const { return has_smolgen_; } |
|
inline const Weights_Smolgen& Weights_MHA::smolgen() const { return smolgen_; } |
|
inline Weights_Smolgen* Weights_MHA::mutable_smolgen() { |
|
has_smolgen_ = true; |
|
return &smolgen_; |
|
} |
|
inline std::string Weights_FFN::OutputAsString() const { |
|
std::string out; |
|
if (has_dense1_w_) AppendString(1, dense1_w_.OutputAsString(), &out); |
|
if (has_dense1_b_) AppendString(2, dense1_b_.OutputAsString(), &out); |
|
if (has_dense2_w_) AppendString(3, dense2_w_.OutputAsString(), &out); |
|
if (has_dense2_b_) AppendString(4, dense2_b_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_FFN::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_dense1_w_) AppendJsonField("dense1_w", dense1_w_, &first, &out); |
|
if (has_dense1_b_) AppendJsonField("dense1_b", dense1_b_, &first, &out); |
|
if (has_dense2_w_) AppendJsonField("dense2_w", dense2_w_, &first, &out); |
|
if (has_dense2_b_) AppendJsonField("dense2_b", dense2_b_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets every field of Weights_FFN: each value is replaced by a
// value-initialized one and its presence flag is cleared.
inline void Weights_FFN::Clear() {
  dense1_w_ = {};
  has_dense1_w_ = false;

  dense1_b_ = {};
  has_dense1_b_ = false;

  dense2_w_ = {};
  has_dense2_w_ = false;

  dense2_b_ = {};
  has_dense2_b_ = false;
}
|
inline void Weights_FFN::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_dense1_w()->MergeFromString(val); break; |
|
case 2: mutable_dense1_b()->MergeFromString(val); break; |
|
case 3: mutable_dense2_w()->MergeFromString(val); break; |
|
case 4: mutable_dense2_b()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_FFN::has_dense1_w() const { return has_dense1_w_; } |
|
inline const Weights_Layer& Weights_FFN::dense1_w() const { return dense1_w_; } |
|
inline Weights_Layer* Weights_FFN::mutable_dense1_w() { |
|
has_dense1_w_ = true; |
|
return &dense1_w_; |
|
} |
|
inline bool Weights_FFN::has_dense1_b() const { return has_dense1_b_; } |
|
inline const Weights_Layer& Weights_FFN::dense1_b() const { return dense1_b_; } |
|
inline Weights_Layer* Weights_FFN::mutable_dense1_b() { |
|
has_dense1_b_ = true; |
|
return &dense1_b_; |
|
} |
|
inline bool Weights_FFN::has_dense2_w() const { return has_dense2_w_; } |
|
inline const Weights_Layer& Weights_FFN::dense2_w() const { return dense2_w_; } |
|
inline Weights_Layer* Weights_FFN::mutable_dense2_w() { |
|
has_dense2_w_ = true; |
|
return &dense2_w_; |
|
} |
|
inline bool Weights_FFN::has_dense2_b() const { return has_dense2_b_; } |
|
inline const Weights_Layer& Weights_FFN::dense2_b() const { return dense2_b_; } |
|
inline Weights_Layer* Weights_FFN::mutable_dense2_b() { |
|
has_dense2_b_ = true; |
|
return &dense2_b_; |
|
} |
|
inline std::string Weights_EncoderLayer::OutputAsString() const { |
|
std::string out; |
|
if (has_mha_) AppendString(1, mha_.OutputAsString(), &out); |
|
if (has_ln1_gammas_) AppendString(2, ln1_gammas_.OutputAsString(), &out); |
|
if (has_ln1_betas_) AppendString(3, ln1_betas_.OutputAsString(), &out); |
|
if (has_ffn_) AppendString(4, ffn_.OutputAsString(), &out); |
|
if (has_ln2_gammas_) AppendString(5, ln2_gammas_.OutputAsString(), &out); |
|
if (has_ln2_betas_) AppendString(6, ln2_betas_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights_EncoderLayer::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_mha_) AppendJsonField("mha", mha_, &first, &out); |
|
if (has_ln1_gammas_) AppendJsonField("ln1_gammas", ln1_gammas_, &first, &out); |
|
if (has_ln1_betas_) AppendJsonField("ln1_betas", ln1_betas_, &first, &out); |
|
if (has_ffn_) AppendJsonField("ffn", ffn_, &first, &out); |
|
if (has_ln2_gammas_) AppendJsonField("ln2_gammas", ln2_gammas_, &first, &out); |
|
if (has_ln2_betas_) AppendJsonField("ln2_betas", ln2_betas_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets every field of Weights_EncoderLayer: each value is replaced by a
// value-initialized one and its presence flag is cleared.
inline void Weights_EncoderLayer::Clear() {
  mha_ = {};
  has_mha_ = false;

  ln1_gammas_ = {};
  has_ln1_gammas_ = false;

  ln1_betas_ = {};
  has_ln1_betas_ = false;

  ffn_ = {};
  has_ffn_ = false;

  ln2_gammas_ = {};
  has_ln2_gammas_ = false;

  ln2_betas_ = {};
  has_ln2_betas_ = false;
}
|
inline void Weights_EncoderLayer::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_mha()->MergeFromString(val); break; |
|
case 2: mutable_ln1_gammas()->MergeFromString(val); break; |
|
case 3: mutable_ln1_betas()->MergeFromString(val); break; |
|
case 4: mutable_ffn()->MergeFromString(val); break; |
|
case 5: mutable_ln2_gammas()->MergeFromString(val); break; |
|
case 6: mutable_ln2_betas()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Weights_EncoderLayer::has_mha() const { return has_mha_; } |
|
inline const Weights_MHA& Weights_EncoderLayer::mha() const { return mha_; } |
|
inline Weights_MHA* Weights_EncoderLayer::mutable_mha() { |
|
has_mha_ = true; |
|
return &mha_; |
|
} |
|
inline bool Weights_EncoderLayer::has_ln1_gammas() const { return has_ln1_gammas_; } |
|
inline const Weights_Layer& Weights_EncoderLayer::ln1_gammas() const { return ln1_gammas_; } |
|
inline Weights_Layer* Weights_EncoderLayer::mutable_ln1_gammas() { |
|
has_ln1_gammas_ = true; |
|
return &ln1_gammas_; |
|
} |
|
inline bool Weights_EncoderLayer::has_ln1_betas() const { return has_ln1_betas_; } |
|
inline const Weights_Layer& Weights_EncoderLayer::ln1_betas() const { return ln1_betas_; } |
|
inline Weights_Layer* Weights_EncoderLayer::mutable_ln1_betas() { |
|
has_ln1_betas_ = true; |
|
return &ln1_betas_; |
|
} |
|
inline bool Weights_EncoderLayer::has_ffn() const { return has_ffn_; } |
|
inline const Weights_FFN& Weights_EncoderLayer::ffn() const { return ffn_; } |
|
inline Weights_FFN* Weights_EncoderLayer::mutable_ffn() { |
|
has_ffn_ = true; |
|
return &ffn_; |
|
} |
|
inline bool Weights_EncoderLayer::has_ln2_gammas() const { return has_ln2_gammas_; } |
|
inline const Weights_Layer& Weights_EncoderLayer::ln2_gammas() const { return ln2_gammas_; } |
|
inline Weights_Layer* Weights_EncoderLayer::mutable_ln2_gammas() { |
|
has_ln2_gammas_ = true; |
|
return &ln2_gammas_; |
|
} |
|
inline bool Weights_EncoderLayer::has_ln2_betas() const { return has_ln2_betas_; } |
|
inline const Weights_Layer& Weights_EncoderLayer::ln2_betas() const { return ln2_betas_; } |
|
inline Weights_Layer* Weights_EncoderLayer::mutable_ln2_betas() { |
|
has_ln2_betas_ = true; |
|
return &ln2_betas_; |
|
} |
|
inline std::string Weights::OutputAsString() const { |
|
std::string out; |
|
if (has_input_) AppendString(1, input_.OutputAsString(), &out); |
|
for (const auto& x : residual_) AppendString(2, x.OutputAsString(), &out); |
|
if (has_policy_) AppendString(3, policy_.OutputAsString(), &out); |
|
if (has_ip_pol_w_) AppendString(4, ip_pol_w_.OutputAsString(), &out); |
|
if (has_ip_pol_b_) AppendString(5, ip_pol_b_.OutputAsString(), &out); |
|
if (has_value_) AppendString(6, value_.OutputAsString(), &out); |
|
if (has_ip1_val_w_) AppendString(7, ip1_val_w_.OutputAsString(), &out); |
|
if (has_ip1_val_b_) AppendString(8, ip1_val_b_.OutputAsString(), &out); |
|
if (has_ip2_val_w_) AppendString(9, ip2_val_w_.OutputAsString(), &out); |
|
if (has_ip2_val_b_) AppendString(10, ip2_val_b_.OutputAsString(), &out); |
|
if (has_policy1_) AppendString(11, policy1_.OutputAsString(), &out); |
|
if (has_moves_left_) AppendString(12, moves_left_.OutputAsString(), &out); |
|
if (has_ip1_mov_w_) AppendString(13, ip1_mov_w_.OutputAsString(), &out); |
|
if (has_ip1_mov_b_) AppendString(14, ip1_mov_b_.OutputAsString(), &out); |
|
if (has_ip2_mov_w_) AppendString(15, ip2_mov_w_.OutputAsString(), &out); |
|
if (has_ip2_mov_b_) AppendString(16, ip2_mov_b_.OutputAsString(), &out); |
|
if (has_ip2_pol_w_) AppendString(17, ip2_pol_w_.OutputAsString(), &out); |
|
if (has_ip2_pol_b_) AppendString(18, ip2_pol_b_.OutputAsString(), &out); |
|
if (has_ip3_pol_w_) AppendString(19, ip3_pol_w_.OutputAsString(), &out); |
|
if (has_ip3_pol_b_) AppendString(20, ip3_pol_b_.OutputAsString(), &out); |
|
for (const auto& x : pol_encoder_) AppendString(21, x.OutputAsString(), &out); |
|
if (has_ip4_pol_w_) AppendString(22, ip4_pol_w_.OutputAsString(), &out); |
|
if (has_pol_headcount_) AppendVarInt(24, pol_headcount_, &out); |
|
if (has_ip_emb_w_) AppendString(25, ip_emb_w_.OutputAsString(), &out); |
|
if (has_ip_emb_b_) AppendString(26, ip_emb_b_.OutputAsString(), &out); |
|
for (const auto& x : encoder_) AppendString(27, x.OutputAsString(), &out); |
|
if (has_headcount_) AppendVarInt(28, headcount_, &out); |
|
if (has_ip_val_w_) AppendString(29, ip_val_w_.OutputAsString(), &out); |
|
if (has_ip_val_b_) AppendString(30, ip_val_b_.OutputAsString(), &out); |
|
if (has_ip_mov_w_) AppendString(31, ip_mov_w_.OutputAsString(), &out); |
|
if (has_ip_mov_b_) AppendString(32, ip_mov_b_.OutputAsString(), &out); |
|
if (has_ip_mult_gate_) AppendString(33, ip_mult_gate_.OutputAsString(), &out); |
|
if (has_ip_add_gate_) AppendString(34, ip_add_gate_.OutputAsString(), &out); |
|
if (has_smolgen_w_) AppendString(35, smolgen_w_.OutputAsString(), &out); |
|
if (has_smolgen_b_) AppendString(36, smolgen_b_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Weights::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_input_) AppendJsonField("input", input_, &first, &out); |
|
if (!residual_.empty()) AppendJsonRepeatedField("residual", residual_, &first, &out); |
|
if (has_ip_emb_w_) AppendJsonField("ip_emb_w", ip_emb_w_, &first, &out); |
|
if (has_ip_emb_b_) AppendJsonField("ip_emb_b", ip_emb_b_, &first, &out); |
|
if (has_ip_mult_gate_) AppendJsonField("ip_mult_gate", ip_mult_gate_, &first, &out); |
|
if (has_ip_add_gate_) AppendJsonField("ip_add_gate", ip_add_gate_, &first, &out); |
|
if (!encoder_.empty()) AppendJsonRepeatedField("encoder", encoder_, &first, &out); |
|
if (has_headcount_) AppendJsonField("headcount", headcount_, &first, &out); |
|
if (!pol_encoder_.empty()) AppendJsonRepeatedField("pol_encoder", pol_encoder_, &first, &out); |
|
if (has_pol_headcount_) AppendJsonField("pol_headcount", pol_headcount_, &first, &out); |
|
if (has_policy1_) AppendJsonField("policy1", policy1_, &first, &out); |
|
if (has_policy_) AppendJsonField("policy", policy_, &first, &out); |
|
if (has_ip_pol_w_) AppendJsonField("ip_pol_w", ip_pol_w_, &first, &out); |
|
if (has_ip_pol_b_) AppendJsonField("ip_pol_b", ip_pol_b_, &first, &out); |
|
if (has_ip2_pol_w_) AppendJsonField("ip2_pol_w", ip2_pol_w_, &first, &out); |
|
if (has_ip2_pol_b_) AppendJsonField("ip2_pol_b", ip2_pol_b_, &first, &out); |
|
if (has_ip3_pol_w_) AppendJsonField("ip3_pol_w", ip3_pol_w_, &first, &out); |
|
if (has_ip3_pol_b_) AppendJsonField("ip3_pol_b", ip3_pol_b_, &first, &out); |
|
if (has_ip4_pol_w_) AppendJsonField("ip4_pol_w", ip4_pol_w_, &first, &out); |
|
if (has_value_) AppendJsonField("value", value_, &first, &out); |
|
if (has_ip_val_w_) AppendJsonField("ip_val_w", ip_val_w_, &first, &out); |
|
if (has_ip_val_b_) AppendJsonField("ip_val_b", ip_val_b_, &first, &out); |
|
if (has_ip1_val_w_) AppendJsonField("ip1_val_w", ip1_val_w_, &first, &out); |
|
if (has_ip1_val_b_) AppendJsonField("ip1_val_b", ip1_val_b_, &first, &out); |
|
if (has_ip2_val_w_) AppendJsonField("ip2_val_w", ip2_val_w_, &first, &out); |
|
if (has_ip2_val_b_) AppendJsonField("ip2_val_b", ip2_val_b_, &first, &out); |
|
if (has_moves_left_) AppendJsonField("moves_left", moves_left_, &first, &out); |
|
if (has_ip_mov_w_) AppendJsonField("ip_mov_w", ip_mov_w_, &first, &out); |
|
if (has_ip_mov_b_) AppendJsonField("ip_mov_b", ip_mov_b_, &first, &out); |
|
if (has_ip1_mov_w_) AppendJsonField("ip1_mov_w", ip1_mov_w_, &first, &out); |
|
if (has_ip1_mov_b_) AppendJsonField("ip1_mov_b", ip1_mov_b_, &first, &out); |
|
if (has_ip2_mov_w_) AppendJsonField("ip2_mov_w", ip2_mov_w_, &first, &out); |
|
if (has_ip2_mov_b_) AppendJsonField("ip2_mov_b", ip2_mov_b_, &first, &out); |
|
if (has_smolgen_w_) AppendJsonField("smolgen_w", smolgen_w_, &first, &out); |
|
if (has_smolgen_b_) AppendJsonField("smolgen_b", smolgen_b_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
inline void Weights::Clear() { |
|
has_input_ = false; |
|
input_ = {}; |
|
residual_.clear(); |
|
has_ip_emb_w_ = false; |
|
ip_emb_w_ = {}; |
|
has_ip_emb_b_ = false; |
|
ip_emb_b_ = {}; |
|
has_ip_mult_gate_ = false; |
|
ip_mult_gate_ = {}; |
|
has_ip_add_gate_ = false; |
|
ip_add_gate_ = {}; |
|
encoder_.clear(); |
|
has_headcount_ = false; |
|
headcount_ = {}; |
|
pol_encoder_.clear(); |
|
has_pol_headcount_ = false; |
|
pol_headcount_ = {}; |
|
has_policy1_ = false; |
|
policy1_ = {}; |
|
has_policy_ = false; |
|
policy_ = {}; |
|
has_ip_pol_w_ = false; |
|
ip_pol_w_ = {}; |
|
has_ip_pol_b_ = false; |
|
ip_pol_b_ = {}; |
|
has_ip2_pol_w_ = false; |
|
ip2_pol_w_ = {}; |
|
has_ip2_pol_b_ = false; |
|
ip2_pol_b_ = {}; |
|
has_ip3_pol_w_ = false; |
|
ip3_pol_w_ = {}; |
|
has_ip3_pol_b_ = false; |
|
ip3_pol_b_ = {}; |
|
has_ip4_pol_w_ = false; |
|
ip4_pol_w_ = {}; |
|
has_value_ = false; |
|
value_ = {}; |
|
has_ip_val_w_ = false; |
|
ip_val_w_ = {}; |
|
has_ip_val_b_ = false; |
|
ip_val_b_ = {}; |
|
has_ip1_val_w_ = false; |
|
ip1_val_w_ = {}; |
|
has_ip1_val_b_ = false; |
|
ip1_val_b_ = {}; |
|
has_ip2_val_w_ = false; |
|
ip2_val_w_ = {}; |
|
has_ip2_val_b_ = false; |
|
ip2_val_b_ = {}; |
|
has_moves_left_ = false; |
|
moves_left_ = {}; |
|
has_ip_mov_w_ = false; |
|
ip_mov_w_ = {}; |
|
has_ip_mov_b_ = false; |
|
ip_mov_b_ = {}; |
|
has_ip1_mov_w_ = false; |
|
ip1_mov_w_ = {}; |
|
has_ip1_mov_b_ = false; |
|
ip1_mov_b_ = {}; |
|
has_ip2_mov_w_ = false; |
|
ip2_mov_w_ = {}; |
|
has_ip2_mov_b_ = false; |
|
ip2_mov_b_ = {}; |
|
has_smolgen_w_ = false; |
|
smolgen_w_ = {}; |
|
has_smolgen_b_ = false; |
|
smolgen_b_ = {}; |
|
} |
|
inline void Weights::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: mutable_input()->MergeFromString(val); break; |
|
case 2: add_residual()->MergeFromString(val); break; |
|
case 25: mutable_ip_emb_w()->MergeFromString(val); break; |
|
case 26: mutable_ip_emb_b()->MergeFromString(val); break; |
|
case 33: mutable_ip_mult_gate()->MergeFromString(val); break; |
|
case 34: mutable_ip_add_gate()->MergeFromString(val); break; |
|
case 27: add_encoder()->MergeFromString(val); break; |
|
case 21: add_pol_encoder()->MergeFromString(val); break; |
|
case 11: mutable_policy1()->MergeFromString(val); break; |
|
case 3: mutable_policy()->MergeFromString(val); break; |
|
case 4: mutable_ip_pol_w()->MergeFromString(val); break; |
|
case 5: mutable_ip_pol_b()->MergeFromString(val); break; |
|
case 17: mutable_ip2_pol_w()->MergeFromString(val); break; |
|
case 18: mutable_ip2_pol_b()->MergeFromString(val); break; |
|
case 19: mutable_ip3_pol_w()->MergeFromString(val); break; |
|
case 20: mutable_ip3_pol_b()->MergeFromString(val); break; |
|
case 22: mutable_ip4_pol_w()->MergeFromString(val); break; |
|
case 6: mutable_value()->MergeFromString(val); break; |
|
case 29: mutable_ip_val_w()->MergeFromString(val); break; |
|
case 30: mutable_ip_val_b()->MergeFromString(val); break; |
|
case 7: mutable_ip1_val_w()->MergeFromString(val); break; |
|
case 8: mutable_ip1_val_b()->MergeFromString(val); break; |
|
case 9: mutable_ip2_val_w()->MergeFromString(val); break; |
|
case 10: mutable_ip2_val_b()->MergeFromString(val); break; |
|
case 12: mutable_moves_left()->MergeFromString(val); break; |
|
case 31: mutable_ip_mov_w()->MergeFromString(val); break; |
|
case 32: mutable_ip_mov_b()->MergeFromString(val); break; |
|
case 13: mutable_ip1_mov_w()->MergeFromString(val); break; |
|
case 14: mutable_ip1_mov_b()->MergeFromString(val); break; |
|
case 15: mutable_ip2_mov_w()->MergeFromString(val); break; |
|
case 16: mutable_ip2_mov_b()->MergeFromString(val); break; |
|
case 35: mutable_smolgen_w()->MergeFromString(val); break; |
|
case 36: mutable_smolgen_b()->MergeFromString(val); break; |
|
} |
|
} |
|
inline void Weights::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 28: set_headcount(static_cast<std::uint32_t>(val)); break; |
|
case 24: set_pol_headcount(static_cast<std::uint32_t>(val)); break; |
|
} |
|
} |
|
inline bool Weights::has_input() const { return has_input_; } |
|
inline const Weights_ConvBlock& Weights::input() const { return input_; } |
|
inline Weights_ConvBlock* Weights::mutable_input() { |
|
has_input_ = true; |
|
return &input_; |
|
} |
|
inline Weights_Residual* Weights::add_residual() { return &residual_.emplace_back(); } |
|
inline const std::vector<Weights_Residual>& Weights::residual() const { return residual_; } |
|
inline std::vector<Weights_Residual>* Weights::mutable_residual() { return &residual_; } |
|
inline const Weights_Residual& Weights::residual(size_t idx) const { return residual_[idx]; } |
|
inline Weights_Residual* Weights::mutable_residual(size_t idx) { return &residual_[idx]; } |
|
inline size_t Weights::residual_size() const { return residual_.size(); } |
|
inline bool Weights::has_ip_emb_w() const { return has_ip_emb_w_; } |
|
inline const Weights_Layer& Weights::ip_emb_w() const { return ip_emb_w_; } |
|
inline Weights_Layer* Weights::mutable_ip_emb_w() { |
|
has_ip_emb_w_ = true; |
|
return &ip_emb_w_; |
|
} |
|
inline bool Weights::has_ip_emb_b() const { return has_ip_emb_b_; } |
|
inline const Weights_Layer& Weights::ip_emb_b() const { return ip_emb_b_; } |
|
inline Weights_Layer* Weights::mutable_ip_emb_b() { |
|
has_ip_emb_b_ = true; |
|
return &ip_emb_b_; |
|
} |
|
inline bool Weights::has_ip_mult_gate() const { return has_ip_mult_gate_; } |
|
inline const Weights_Layer& Weights::ip_mult_gate() const { return ip_mult_gate_; } |
|
inline Weights_Layer* Weights::mutable_ip_mult_gate() { |
|
has_ip_mult_gate_ = true; |
|
return &ip_mult_gate_; |
|
} |
|
inline bool Weights::has_ip_add_gate() const { return has_ip_add_gate_; } |
|
inline const Weights_Layer& Weights::ip_add_gate() const { return ip_add_gate_; } |
|
inline Weights_Layer* Weights::mutable_ip_add_gate() { |
|
has_ip_add_gate_ = true; |
|
return &ip_add_gate_; |
|
} |
|
inline Weights_EncoderLayer* Weights::add_encoder() { return &encoder_.emplace_back(); } |
|
inline const std::vector<Weights_EncoderLayer>& Weights::encoder() const { return encoder_; } |
|
inline std::vector<Weights_EncoderLayer>* Weights::mutable_encoder() { return &encoder_; } |
|
inline const Weights_EncoderLayer& Weights::encoder(size_t idx) const { return encoder_[idx]; } |
|
inline Weights_EncoderLayer* Weights::mutable_encoder(size_t idx) { return &encoder_[idx]; } |
|
inline size_t Weights::encoder_size() const { return encoder_.size(); } |
|
inline bool Weights::has_headcount() const { return has_headcount_; } |
|
inline std::uint32_t Weights::headcount() const { return headcount_; } |
|
inline void Weights::set_headcount(std::uint32_t val) { |
|
has_headcount_ = true; |
|
headcount_ = val; |
|
} |
|
inline Weights_EncoderLayer* Weights::add_pol_encoder() { return &pol_encoder_.emplace_back(); } |
|
inline const std::vector<Weights_EncoderLayer>& Weights::pol_encoder() const { return pol_encoder_; } |
|
inline std::vector<Weights_EncoderLayer>* Weights::mutable_pol_encoder() { return &pol_encoder_; } |
|
inline const Weights_EncoderLayer& Weights::pol_encoder(size_t idx) const { return pol_encoder_[idx]; } |
|
inline Weights_EncoderLayer* Weights::mutable_pol_encoder(size_t idx) { return &pol_encoder_[idx]; } |
|
inline size_t Weights::pol_encoder_size() const { return pol_encoder_.size(); } |
|
inline bool Weights::has_pol_headcount() const { return has_pol_headcount_; } |
|
inline std::uint32_t Weights::pol_headcount() const { return pol_headcount_; } |
|
inline void Weights::set_pol_headcount(std::uint32_t val) { |
|
has_pol_headcount_ = true; |
|
pol_headcount_ = val; |
|
} |
|
inline bool Weights::has_policy1() const { return has_policy1_; } |
|
inline const Weights_ConvBlock& Weights::policy1() const { return policy1_; } |
|
inline Weights_ConvBlock* Weights::mutable_policy1() { |
|
has_policy1_ = true; |
|
return &policy1_; |
|
} |
|
inline bool Weights::has_policy() const { return has_policy_; } |
|
inline const Weights_ConvBlock& Weights::policy() const { return policy_; } |
|
inline Weights_ConvBlock* Weights::mutable_policy() { |
|
has_policy_ = true; |
|
return &policy_; |
|
} |
|
inline bool Weights::has_ip_pol_w() const { return has_ip_pol_w_; } |
|
inline const Weights_Layer& Weights::ip_pol_w() const { return ip_pol_w_; } |
|
inline Weights_Layer* Weights::mutable_ip_pol_w() { |
|
has_ip_pol_w_ = true; |
|
return &ip_pol_w_; |
|
} |
|
inline bool Weights::has_ip_pol_b() const { return has_ip_pol_b_; } |
|
inline const Weights_Layer& Weights::ip_pol_b() const { return ip_pol_b_; } |
|
inline Weights_Layer* Weights::mutable_ip_pol_b() { |
|
has_ip_pol_b_ = true; |
|
return &ip_pol_b_; |
|
} |
|
inline bool Weights::has_ip2_pol_w() const { return has_ip2_pol_w_; } |
|
inline const Weights_Layer& Weights::ip2_pol_w() const { return ip2_pol_w_; } |
|
inline Weights_Layer* Weights::mutable_ip2_pol_w() { |
|
has_ip2_pol_w_ = true; |
|
return &ip2_pol_w_; |
|
} |
|
inline bool Weights::has_ip2_pol_b() const { return has_ip2_pol_b_; } |
|
inline const Weights_Layer& Weights::ip2_pol_b() const { return ip2_pol_b_; } |
|
inline Weights_Layer* Weights::mutable_ip2_pol_b() { |
|
has_ip2_pol_b_ = true; |
|
return &ip2_pol_b_; |
|
} |
|
inline bool Weights::has_ip3_pol_w() const { return has_ip3_pol_w_; } |
|
inline const Weights_Layer& Weights::ip3_pol_w() const { return ip3_pol_w_; } |
|
inline Weights_Layer* Weights::mutable_ip3_pol_w() { |
|
has_ip3_pol_w_ = true; |
|
return &ip3_pol_w_; |
|
} |
|
inline bool Weights::has_ip3_pol_b() const { return has_ip3_pol_b_; } |
|
inline const Weights_Layer& Weights::ip3_pol_b() const { return ip3_pol_b_; } |
|
inline Weights_Layer* Weights::mutable_ip3_pol_b() { |
|
has_ip3_pol_b_ = true; |
|
return &ip3_pol_b_; |
|
} |
|
inline bool Weights::has_ip4_pol_w() const { return has_ip4_pol_w_; } |
|
inline const Weights_Layer& Weights::ip4_pol_w() const { return ip4_pol_w_; } |
|
inline Weights_Layer* Weights::mutable_ip4_pol_w() { |
|
has_ip4_pol_w_ = true; |
|
return &ip4_pol_w_; |
|
} |
|
inline bool Weights::has_value() const { return has_value_; } |
|
inline const Weights_ConvBlock& Weights::value() const { return value_; } |
|
inline Weights_ConvBlock* Weights::mutable_value() { |
|
has_value_ = true; |
|
return &value_; |
|
} |
|
inline bool Weights::has_ip_val_w() const { return has_ip_val_w_; } |
|
inline const Weights_Layer& Weights::ip_val_w() const { return ip_val_w_; } |
|
inline Weights_Layer* Weights::mutable_ip_val_w() { |
|
has_ip_val_w_ = true; |
|
return &ip_val_w_; |
|
} |
|
inline bool Weights::has_ip_val_b() const { return has_ip_val_b_; } |
|
inline const Weights_Layer& Weights::ip_val_b() const { return ip_val_b_; } |
|
inline Weights_Layer* Weights::mutable_ip_val_b() { |
|
has_ip_val_b_ = true; |
|
return &ip_val_b_; |
|
} |
|
inline bool Weights::has_ip1_val_w() const { return has_ip1_val_w_; } |
|
inline const Weights_Layer& Weights::ip1_val_w() const { return ip1_val_w_; } |
|
inline Weights_Layer* Weights::mutable_ip1_val_w() { |
|
has_ip1_val_w_ = true; |
|
return &ip1_val_w_; |
|
} |
|
inline bool Weights::has_ip1_val_b() const { return has_ip1_val_b_; } |
|
inline const Weights_Layer& Weights::ip1_val_b() const { return ip1_val_b_; } |
|
inline Weights_Layer* Weights::mutable_ip1_val_b() { |
|
has_ip1_val_b_ = true; |
|
return &ip1_val_b_; |
|
} |
|
inline bool Weights::has_ip2_val_w() const { return has_ip2_val_w_; } |
|
inline const Weights_Layer& Weights::ip2_val_w() const { return ip2_val_w_; } |
|
inline Weights_Layer* Weights::mutable_ip2_val_w() { |
|
has_ip2_val_w_ = true; |
|
return &ip2_val_w_; |
|
} |
|
inline bool Weights::has_ip2_val_b() const { return has_ip2_val_b_; } |
|
inline const Weights_Layer& Weights::ip2_val_b() const { return ip2_val_b_; } |
|
inline Weights_Layer* Weights::mutable_ip2_val_b() { |
|
has_ip2_val_b_ = true; |
|
return &ip2_val_b_; |
|
} |
|
inline bool Weights::has_moves_left() const { return has_moves_left_; } |
|
inline const Weights_ConvBlock& Weights::moves_left() const { return moves_left_; } |
|
inline Weights_ConvBlock* Weights::mutable_moves_left() { |
|
has_moves_left_ = true; |
|
return &moves_left_; |
|
} |
|
inline bool Weights::has_ip_mov_w() const { return has_ip_mov_w_; } |
|
inline const Weights_Layer& Weights::ip_mov_w() const { return ip_mov_w_; } |
|
inline Weights_Layer* Weights::mutable_ip_mov_w() { |
|
has_ip_mov_w_ = true; |
|
return &ip_mov_w_; |
|
} |
|
inline bool Weights::has_ip_mov_b() const { return has_ip_mov_b_; } |
|
inline const Weights_Layer& Weights::ip_mov_b() const { return ip_mov_b_; } |
|
inline Weights_Layer* Weights::mutable_ip_mov_b() { |
|
has_ip_mov_b_ = true; |
|
return &ip_mov_b_; |
|
} |
|
inline bool Weights::has_ip1_mov_w() const { return has_ip1_mov_w_; } |
|
inline const Weights_Layer& Weights::ip1_mov_w() const { return ip1_mov_w_; } |
|
inline Weights_Layer* Weights::mutable_ip1_mov_w() { |
|
has_ip1_mov_w_ = true; |
|
return &ip1_mov_w_; |
|
} |
|
inline bool Weights::has_ip1_mov_b() const { return has_ip1_mov_b_; } |
|
inline const Weights_Layer& Weights::ip1_mov_b() const { return ip1_mov_b_; } |
|
inline Weights_Layer* Weights::mutable_ip1_mov_b() { |
|
has_ip1_mov_b_ = true; |
|
return &ip1_mov_b_; |
|
} |
|
inline bool Weights::has_ip2_mov_w() const { return has_ip2_mov_w_; } |
|
inline const Weights_Layer& Weights::ip2_mov_w() const { return ip2_mov_w_; } |
|
inline Weights_Layer* Weights::mutable_ip2_mov_w() { |
|
has_ip2_mov_w_ = true; |
|
return &ip2_mov_w_; |
|
} |
|
inline bool Weights::has_ip2_mov_b() const { return has_ip2_mov_b_; } |
|
inline const Weights_Layer& Weights::ip2_mov_b() const { return ip2_mov_b_; } |
|
inline Weights_Layer* Weights::mutable_ip2_mov_b() { |
|
has_ip2_mov_b_ = true; |
|
return &ip2_mov_b_; |
|
} |
|
inline bool Weights::has_smolgen_w() const { return has_smolgen_w_; } |
|
inline const Weights_Layer& Weights::smolgen_w() const { return smolgen_w_; } |
|
inline Weights_Layer* Weights::mutable_smolgen_w() { |
|
has_smolgen_w_ = true; |
|
return &smolgen_w_; |
|
} |
|
inline bool Weights::has_smolgen_b() const { return has_smolgen_b_; } |
|
inline const Weights_Layer& Weights::smolgen_b() const { return smolgen_b_; } |
|
inline Weights_Layer* Weights::mutable_smolgen_b() { |
|
has_smolgen_b_ = true; |
|
return &smolgen_b_; |
|
} |
|
inline std::string TrainingParams::OutputAsString() const { |
|
std::string out; |
|
if (has_training_steps_) AppendVarInt(1, training_steps_, &out); |
|
if (has_learning_rate_) AppendInt32(2, bit_cast<std::uint32_t>(learning_rate_), &out); |
|
if (has_mse_loss_) AppendInt32(3, bit_cast<std::uint32_t>(mse_loss_), &out); |
|
if (has_policy_loss_) AppendInt32(4, bit_cast<std::uint32_t>(policy_loss_), &out); |
|
if (has_accuracy_) AppendInt32(5, bit_cast<std::uint32_t>(accuracy_), &out); |
|
if (has_lc0_params_) AppendString(6, lc0_params_, &out); |
|
return out; |
|
} |
|
inline std::string TrainingParams::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_training_steps_) AppendJsonField("training_steps", training_steps_, &first, &out); |
|
if (has_learning_rate_) AppendJsonField("learning_rate", learning_rate_, &first, &out); |
|
if (has_mse_loss_) AppendJsonField("mse_loss", mse_loss_, &first, &out); |
|
if (has_policy_loss_) AppendJsonField("policy_loss", policy_loss_, &first, &out); |
|
if (has_accuracy_) AppendJsonField("accuracy", accuracy_, &first, &out); |
|
if (has_lc0_params_) AppendJsonField("lc0_params", lc0_params_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets every field of TrainingParams: each value is replaced by a
// value-initialized one and its presence flag is cleared.
inline void TrainingParams::Clear() {
  training_steps_ = {};
  has_training_steps_ = false;

  learning_rate_ = {};
  has_learning_rate_ = false;

  mse_loss_ = {};
  has_mse_loss_ = false;

  policy_loss_ = {};
  has_policy_loss_ = false;

  accuracy_ = {};
  has_accuracy_ = false;

  lc0_params_ = {};
  has_lc0_params_ = false;
}
|
inline void TrainingParams::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 1: set_training_steps(static_cast<std::uint32_t>(val)); break; |
|
} |
|
} |
|
inline void TrainingParams::SetInt32(int field_id, std::uint32_t val) { |
|
switch (field_id) { |
|
case 2: set_learning_rate(bit_cast<float>(val)); break; |
|
case 3: set_mse_loss(bit_cast<float>(val)); break; |
|
case 4: set_policy_loss(bit_cast<float>(val)); break; |
|
case 5: set_accuracy(bit_cast<float>(val)); break; |
|
} |
|
} |
|
inline void TrainingParams::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 6: set_lc0_params(val); break; |
|
} |
|
} |
|
inline bool TrainingParams::has_training_steps() const { return has_training_steps_; } |
|
inline std::uint32_t TrainingParams::training_steps() const { return training_steps_; } |
|
inline void TrainingParams::set_training_steps(std::uint32_t val) { |
|
has_training_steps_ = true; |
|
training_steps_ = val; |
|
} |
|
inline bool TrainingParams::has_learning_rate() const { return has_learning_rate_; } |
|
inline float TrainingParams::learning_rate() const { return learning_rate_; } |
|
inline void TrainingParams::set_learning_rate(float val) { |
|
has_learning_rate_ = true; |
|
learning_rate_ = val; |
|
} |
|
inline bool TrainingParams::has_mse_loss() const { return has_mse_loss_; } |
|
inline float TrainingParams::mse_loss() const { return mse_loss_; } |
|
inline void TrainingParams::set_mse_loss(float val) { |
|
has_mse_loss_ = true; |
|
mse_loss_ = val; |
|
} |
|
inline bool TrainingParams::has_policy_loss() const { return has_policy_loss_; } |
|
inline float TrainingParams::policy_loss() const { return policy_loss_; } |
|
inline void TrainingParams::set_policy_loss(float val) { |
|
has_policy_loss_ = true; |
|
policy_loss_ = val; |
|
} |
|
inline bool TrainingParams::has_accuracy() const { return has_accuracy_; } |
|
inline float TrainingParams::accuracy() const { return accuracy_; } |
|
inline void TrainingParams::set_accuracy(float val) { |
|
has_accuracy_ = true; |
|
accuracy_ = val; |
|
} |
|
inline bool TrainingParams::has_lc0_params() const { return has_lc0_params_; } |
|
inline std::string_view TrainingParams::lc0_params() const { return lc0_params_; } |
|
inline void TrainingParams::set_lc0_params(std::string_view val) { |
|
has_lc0_params_ = true; |
|
lc0_params_ = val; |
|
} |
|
inline std::string NetworkFormat::OutputAsString() const { |
|
std::string out; |
|
if (has_input_) AppendVarInt(1, input_, &out); |
|
if (has_output_) AppendVarInt(2, output_, &out); |
|
if (has_network_) AppendVarInt(3, network_, &out); |
|
if (has_policy_) AppendVarInt(4, policy_, &out); |
|
if (has_value_) AppendVarInt(5, value_, &out); |
|
if (has_moves_left_) AppendVarInt(6, moves_left_, &out); |
|
if (has_default_activation_) AppendVarInt(7, default_activation_, &out); |
|
if (has_smolgen_activation_) AppendVarInt(8, smolgen_activation_, &out); |
|
if (has_ffn_activation_) AppendVarInt(9, ffn_activation_, &out); |
|
return out; |
|
} |
|
inline std::string NetworkFormat::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_input_) AppendJsonField("input", NetworkFormat_InputFormat_Name(input_), &first, &out); |
|
if (has_output_) AppendJsonField("output", NetworkFormat_OutputFormat_Name(output_), &first, &out); |
|
if (has_network_) AppendJsonField("network", NetworkFormat_NetworkStructure_Name(network_), &first, &out); |
|
if (has_policy_) AppendJsonField("policy", NetworkFormat_PolicyFormat_Name(policy_), &first, &out); |
|
if (has_value_) AppendJsonField("value", NetworkFormat_ValueFormat_Name(value_), &first, &out); |
|
if (has_moves_left_) AppendJsonField("moves_left", NetworkFormat_MovesLeftFormat_Name(moves_left_), &first, &out); |
|
if (has_default_activation_) AppendJsonField("default_activation", NetworkFormat_DefaultActivation_Name(default_activation_), &first, &out); |
|
if (has_smolgen_activation_) AppendJsonField("smolgen_activation", NetworkFormat_ActivationFunction_Name(smolgen_activation_), &first, &out); |
|
if (has_ffn_activation_) AppendJsonField("ffn_activation", NetworkFormat_ActivationFunction_Name(ffn_activation_), &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: every presence flag is lowered and every field is
// value-initialized.
inline void NetworkFormat::Clear() {
  // Presence flags.
  has_input_ = false;
  has_output_ = false;
  has_network_ = false;
  has_policy_ = false;
  has_value_ = false;
  has_moves_left_ = false;
  has_default_activation_ = false;
  has_smolgen_activation_ = false;
  has_ffn_activation_ = false;

  // Field values.
  input_ = {};
  output_ = {};
  network_ = {};
  policy_ = {};
  value_ = {};
  moves_left_ = {};
  default_activation_ = {};
  smolgen_activation_ = {};
  ffn_activation_ = {};
}
|
inline void NetworkFormat::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 1: set_input(static_cast<NetworkFormat_InputFormat>(val)); break; |
|
case 2: set_output(static_cast<NetworkFormat_OutputFormat>(val)); break; |
|
case 3: set_network(static_cast<NetworkFormat_NetworkStructure>(val)); break; |
|
case 4: set_policy(static_cast<NetworkFormat_PolicyFormat>(val)); break; |
|
case 5: set_value(static_cast<NetworkFormat_ValueFormat>(val)); break; |
|
case 6: set_moves_left(static_cast<NetworkFormat_MovesLeftFormat>(val)); break; |
|
case 7: set_default_activation(static_cast<NetworkFormat_DefaultActivation>(val)); break; |
|
case 8: set_smolgen_activation(static_cast<NetworkFormat_ActivationFunction>(val)); break; |
|
case 9: set_ffn_activation(static_cast<NetworkFormat_ActivationFunction>(val)); break; |
|
} |
|
} |
|
inline bool NetworkFormat::has_input() const { return has_input_; } |
|
inline NetworkFormat_InputFormat NetworkFormat::input() const { return input_; } |
|
inline void NetworkFormat::set_input(NetworkFormat_InputFormat val) { |
|
has_input_ = true; |
|
input_ = val; |
|
} |
|
inline bool NetworkFormat::has_output() const { return has_output_; } |
|
inline NetworkFormat_OutputFormat NetworkFormat::output() const { return output_; } |
|
inline void NetworkFormat::set_output(NetworkFormat_OutputFormat val) { |
|
has_output_ = true; |
|
output_ = val; |
|
} |
|
inline bool NetworkFormat::has_network() const { return has_network_; } |
|
inline NetworkFormat_NetworkStructure NetworkFormat::network() const { return network_; } |
|
inline void NetworkFormat::set_network(NetworkFormat_NetworkStructure val) { |
|
has_network_ = true; |
|
network_ = val; |
|
} |
|
inline bool NetworkFormat::has_policy() const { return has_policy_; } |
|
inline NetworkFormat_PolicyFormat NetworkFormat::policy() const { return policy_; } |
|
inline void NetworkFormat::set_policy(NetworkFormat_PolicyFormat val) { |
|
has_policy_ = true; |
|
policy_ = val; |
|
} |
|
inline bool NetworkFormat::has_value() const { return has_value_; } |
|
inline NetworkFormat_ValueFormat NetworkFormat::value() const { return value_; } |
|
inline void NetworkFormat::set_value(NetworkFormat_ValueFormat val) { |
|
has_value_ = true; |
|
value_ = val; |
|
} |
|
inline bool NetworkFormat::has_moves_left() const { return has_moves_left_; } |
|
inline NetworkFormat_MovesLeftFormat NetworkFormat::moves_left() const { return moves_left_; } |
|
inline void NetworkFormat::set_moves_left(NetworkFormat_MovesLeftFormat val) { |
|
has_moves_left_ = true; |
|
moves_left_ = val; |
|
} |
|
inline bool NetworkFormat::has_default_activation() const { return has_default_activation_; } |
|
inline NetworkFormat_DefaultActivation NetworkFormat::default_activation() const { return default_activation_; } |
|
inline void NetworkFormat::set_default_activation(NetworkFormat_DefaultActivation val) { |
|
has_default_activation_ = true; |
|
default_activation_ = val; |
|
} |
|
inline bool NetworkFormat::has_smolgen_activation() const { return has_smolgen_activation_; } |
|
inline NetworkFormat_ActivationFunction NetworkFormat::smolgen_activation() const { return smolgen_activation_; } |
|
inline void NetworkFormat::set_smolgen_activation(NetworkFormat_ActivationFunction val) { |
|
has_smolgen_activation_ = true; |
|
smolgen_activation_ = val; |
|
} |
|
inline bool NetworkFormat::has_ffn_activation() const { return has_ffn_activation_; } |
|
inline NetworkFormat_ActivationFunction NetworkFormat::ffn_activation() const { return ffn_activation_; } |
|
inline void NetworkFormat::set_ffn_activation(NetworkFormat_ActivationFunction val) { |
|
has_ffn_activation_ = true; |
|
ffn_activation_ = val; |
|
} |
|
inline std::string Format::OutputAsString() const { |
|
std::string out; |
|
if (has_weights_encoding_) AppendVarInt(1, weights_encoding_, &out); |
|
if (has_network_format_) AppendString(2, network_format_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Format::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_weights_encoding_) AppendJsonField("weights_encoding", Format_Encoding_Name(weights_encoding_), &first, &out); |
|
if (has_network_format_) AppendJsonField("network_format", network_format_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: both presence flags are lowered and both fields are
// value-initialized.
inline void Format::Clear() {
  // Presence flags.
  has_weights_encoding_ = false;
  has_network_format_ = false;

  // Field values.
  weights_encoding_ = {};
  network_format_ = {};
}
|
inline void Format::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 1: set_weights_encoding(static_cast<Format_Encoding>(val)); break; |
|
} |
|
} |
|
inline void Format::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 2: mutable_network_format()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Format::has_weights_encoding() const { return has_weights_encoding_; } |
|
inline Format_Encoding Format::weights_encoding() const { return weights_encoding_; } |
|
inline void Format::set_weights_encoding(Format_Encoding val) { |
|
has_weights_encoding_ = true; |
|
weights_encoding_ = val; |
|
} |
|
inline bool Format::has_network_format() const { return has_network_format_; } |
|
inline const NetworkFormat& Format::network_format() const { return network_format_; } |
|
inline NetworkFormat* Format::mutable_network_format() { |
|
has_network_format_ = true; |
|
return &network_format_; |
|
} |
|
inline std::string OnnxModel::OutputAsString() const { |
|
std::string out; |
|
if (has_model_) AppendString(1, model_, &out); |
|
if (has_data_type_) AppendVarInt(2, data_type_, &out); |
|
if (has_input_planes_) AppendString(3, input_planes_, &out); |
|
if (has_output_value_) AppendString(4, output_value_, &out); |
|
if (has_output_wdl_) AppendString(5, output_wdl_, &out); |
|
if (has_output_policy_) AppendString(6, output_policy_, &out); |
|
if (has_output_mlh_) AppendString(7, output_mlh_, &out); |
|
return out; |
|
} |
|
inline std::string OnnxModel::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_model_) AppendJsonField("model", model_, &first, &out); |
|
if (has_data_type_) AppendJsonField("data_type", OnnxModel_DataType_Name(data_type_), &first, &out); |
|
if (has_input_planes_) AppendJsonField("input_planes", input_planes_, &first, &out); |
|
if (has_output_value_) AppendJsonField("output_value", output_value_, &first, &out); |
|
if (has_output_wdl_) AppendJsonField("output_wdl", output_wdl_, &first, &out); |
|
if (has_output_policy_) AppendJsonField("output_policy", output_policy_, &first, &out); |
|
if (has_output_mlh_) AppendJsonField("output_mlh", output_mlh_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: every presence flag is lowered and every field is
// value-initialized.
inline void OnnxModel::Clear() {
  // Presence flags.
  has_model_ = false;
  has_data_type_ = false;
  has_input_planes_ = false;
  has_output_value_ = false;
  has_output_wdl_ = false;
  has_output_policy_ = false;
  has_output_mlh_ = false;

  // Field values.
  model_ = {};
  data_type_ = {};
  input_planes_ = {};
  output_value_ = {};
  output_wdl_ = {};
  output_policy_ = {};
  output_mlh_ = {};
}
|
inline void OnnxModel::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 1: set_model(val); break; |
|
case 3: set_input_planes(val); break; |
|
case 4: set_output_value(val); break; |
|
case 5: set_output_wdl(val); break; |
|
case 6: set_output_policy(val); break; |
|
case 7: set_output_mlh(val); break; |
|
} |
|
} |
|
inline void OnnxModel::SetVarInt(int field_id, std::uint64_t val) { |
|
switch (field_id) { |
|
case 2: set_data_type(static_cast<OnnxModel_DataType>(val)); break; |
|
} |
|
} |
|
inline bool OnnxModel::has_model() const { return has_model_; } |
|
inline std::string_view OnnxModel::model() const { return model_; } |
|
inline void OnnxModel::set_model(std::string_view val) { |
|
has_model_ = true; |
|
model_ = val; |
|
} |
|
inline bool OnnxModel::has_data_type() const { return has_data_type_; } |
|
inline OnnxModel_DataType OnnxModel::data_type() const { return data_type_; } |
|
inline void OnnxModel::set_data_type(OnnxModel_DataType val) { |
|
has_data_type_ = true; |
|
data_type_ = val; |
|
} |
|
inline bool OnnxModel::has_input_planes() const { return has_input_planes_; } |
|
inline std::string_view OnnxModel::input_planes() const { return input_planes_; } |
|
inline void OnnxModel::set_input_planes(std::string_view val) { |
|
has_input_planes_ = true; |
|
input_planes_ = val; |
|
} |
|
inline bool OnnxModel::has_output_value() const { return has_output_value_; } |
|
inline std::string_view OnnxModel::output_value() const { return output_value_; } |
|
inline void OnnxModel::set_output_value(std::string_view val) { |
|
has_output_value_ = true; |
|
output_value_ = val; |
|
} |
|
inline bool OnnxModel::has_output_wdl() const { return has_output_wdl_; } |
|
inline std::string_view OnnxModel::output_wdl() const { return output_wdl_; } |
|
inline void OnnxModel::set_output_wdl(std::string_view val) { |
|
has_output_wdl_ = true; |
|
output_wdl_ = val; |
|
} |
|
inline bool OnnxModel::has_output_policy() const { return has_output_policy_; } |
|
inline std::string_view OnnxModel::output_policy() const { return output_policy_; } |
|
inline void OnnxModel::set_output_policy(std::string_view val) { |
|
has_output_policy_ = true; |
|
output_policy_ = val; |
|
} |
|
inline bool OnnxModel::has_output_mlh() const { return has_output_mlh_; } |
|
inline std::string_view OnnxModel::output_mlh() const { return output_mlh_; } |
|
inline void OnnxModel::set_output_mlh(std::string_view val) { |
|
has_output_mlh_ = true; |
|
output_mlh_ = val; |
|
} |
|
inline std::string Net::OutputAsString() const { |
|
std::string out; |
|
if (has_magic_) AppendInt32(1, magic_, &out); |
|
if (has_license_) AppendString(2, license_, &out); |
|
if (has_min_version_) AppendString(3, min_version_.OutputAsString(), &out); |
|
if (has_format_) AppendString(4, format_.OutputAsString(), &out); |
|
if (has_training_params_) AppendString(5, training_params_.OutputAsString(), &out); |
|
if (has_weights_) AppendString(10, weights_.OutputAsString(), &out); |
|
if (has_onnx_model_) AppendString(11, onnx_model_.OutputAsString(), &out); |
|
return out; |
|
} |
|
inline std::string Net::OutputAsJson() const { |
|
bool first = true; |
|
std::string out = "{"; |
|
if (has_magic_) AppendJsonField("magic", magic_, &first, &out); |
|
if (has_license_) AppendJsonField("license", license_, &first, &out); |
|
if (has_min_version_) AppendJsonField("min_version", min_version_, &first, &out); |
|
if (has_format_) AppendJsonField("format", format_, &first, &out); |
|
if (has_training_params_) AppendJsonField("training_params", training_params_, &first, &out); |
|
if (has_weights_) AppendJsonField("weights", weights_, &first, &out); |
|
if (has_onnx_model_) AppendJsonField("onnx_model", onnx_model_, &first, &out); |
|
out += "}"; |
|
return out; |
|
} |
|
// Resets the message: every presence flag is lowered and every field,
// including the nested messages, is value-initialized.
inline void Net::Clear() {
  // Presence flags.
  has_magic_ = false;
  has_license_ = false;
  has_min_version_ = false;
  has_format_ = false;
  has_training_params_ = false;
  has_weights_ = false;
  has_onnx_model_ = false;

  // Field values.
  magic_ = {};
  license_ = {};
  min_version_ = {};
  format_ = {};
  training_params_ = {};
  weights_ = {};
  onnx_model_ = {};
}
|
inline void Net::SetInt32(int field_id, std::uint32_t val) { |
|
switch (field_id) { |
|
case 1: set_magic(static_cast<std::uint32_t>(val)); break; |
|
} |
|
} |
|
inline void Net::SetString(int field_id, std::string_view val) { |
|
switch (field_id) { |
|
case 2: set_license(val); break; |
|
case 3: mutable_min_version()->MergeFromString(val); break; |
|
case 4: mutable_format()->MergeFromString(val); break; |
|
case 5: mutable_training_params()->MergeFromString(val); break; |
|
case 10: mutable_weights()->MergeFromString(val); break; |
|
case 11: mutable_onnx_model()->MergeFromString(val); break; |
|
} |
|
} |
|
inline bool Net::has_magic() const { return has_magic_; } |
|
inline std::uint32_t Net::magic() const { return magic_; } |
|
inline void Net::set_magic(std::uint32_t val) { |
|
has_magic_ = true; |
|
magic_ = val; |
|
} |
|
inline bool Net::has_license() const { return has_license_; } |
|
inline std::string_view Net::license() const { return license_; } |
|
inline void Net::set_license(std::string_view val) { |
|
has_license_ = true; |
|
license_ = val; |
|
} |
|
inline bool Net::has_min_version() const { return has_min_version_; } |
|
inline const EngineVersion& Net::min_version() const { return min_version_; } |
|
inline EngineVersion* Net::mutable_min_version() { |
|
has_min_version_ = true; |
|
return &min_version_; |
|
} |
|
inline bool Net::has_format() const { return has_format_; } |
|
inline const Format& Net::format() const { return format_; } |
|
inline Format* Net::mutable_format() { |
|
has_format_ = true; |
|
return &format_; |
|
} |
|
inline bool Net::has_training_params() const { return has_training_params_; } |
|
inline const TrainingParams& Net::training_params() const { return training_params_; } |
|
inline TrainingParams* Net::mutable_training_params() { |
|
has_training_params_ = true; |
|
return &training_params_; |
|
} |
|
inline bool Net::has_weights() const { return has_weights_; } |
|
inline const Weights& Net::weights() const { return weights_; } |
|
inline Weights* Net::mutable_weights() { |
|
has_weights_ = true; |
|
return &weights_; |
|
} |
|
inline bool Net::has_onnx_model() const { return has_onnx_model_; } |
|
inline const OnnxModel& Net::onnx_model() const { return onnx_model_; } |
|
inline OnnxModel* Net::mutable_onnx_model() { |
|
has_onnx_model_ = true; |
|
return &onnx_model_; |
|
} |
|
} |
|
|