#ifndef __TAE_HPP__ #define __TAE_HPP__ #include "ggml_extend.hpp" #include "model.h" /* =================================== TinyAutoEncoder =================================== References: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py https://github.com/madebyollin/taesd/blob/main/taesd.py */ class TAEBlock : public UnaryBlock { protected: int n_in; int n_out; public: TAEBlock(int n_in, int n_out) : n_in(n_in), n_out(n_out) { blocks["conv.0"] = std::shared_ptr(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); if (n_in != n_out) { blocks["skip"] = std::shared_ptr(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false)); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [n, n_in, h, w] // return: [n, n_out, h, w] auto conv_0 = std::dynamic_pointer_cast(blocks["conv.0"]); auto conv_2 = std::dynamic_pointer_cast(blocks["conv.2"]); auto conv_4 = std::dynamic_pointer_cast(blocks["conv.4"]); auto h = conv_0->forward(ctx, x); h = ggml_relu_inplace(ctx, h); h = conv_2->forward(ctx, h); h = ggml_relu_inplace(ctx, h); h = conv_4->forward(ctx, h); if (n_in != n_out) { auto skip = std::dynamic_pointer_cast(blocks["skip"]); LOG_DEBUG("skip"); x = skip->forward(ctx, x); } h = ggml_add(ctx, h, x); h = ggml_relu_inplace(ctx, h); return h; } }; class TinyEncoder : public UnaryBlock { int in_channels = 3; int channels = 64; int z_channels = 4; int num_blocks = 3; public: TinyEncoder(int z_channels = 4) : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [n, in_channels, h, w] // return: [n, z_channels, h/8, w/8] for (int i = 0; i < num_blocks * 3 + 6; i++) { auto block = std::dynamic_pointer_cast(blocks[std::to_string(i)]); x = block->forward(ctx, x); } return x; } }; class TinyDecoder : public UnaryBlock { int z_channels = 4; int channels = 64; int out_channels = 3; int num_blocks = 3; public: TinyDecoder(int z_channels = 4) : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) { // z: [n, z_channels, h, w] // return: [n, out_channels, h*8, w*8] auto h = ggml_scale(ctx, z, 1.0f / 3.0f); h = ggml_tanh_inplace(ctx, h); h = ggml_scale(ctx, h, 3.0f); for (int i = 0; i < num_blocks * 3 + 10; i++) { if (blocks.find(std::to_string(i)) == blocks.end()) { if (i == 1) { h = ggml_relu_inplace(ctx, h); } else { h = ggml_upscale(ctx, h, 2); } continue; } auto block = std::dynamic_pointer_cast(blocks[std::to_string(i)]); h = block->forward(ctx, h); } return h; } }; class TAESD : public GGMLBlock { protected: bool decode_only; public: TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { int z_channels = 4; if (sd_version_is_dit(version)) { z_channels = 16; } blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); if (!decode_only) { blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); } } struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) { auto decoder = std::dynamic_pointer_cast(blocks["decoder.layers"]); return decoder->forward(ctx, z); } struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder.layers"]); return encoder->forward(ctx, x); } }; struct TinyAutoEncoder : public GGMLRunner { TAESD taesd; bool decode_only = false; TinyAutoEncoder(ggml_backend_t backend, std::map& tensor_types, const std::string prefix, bool decoder_only = true, SDVersion version = VERSION_SD1) : decode_only(decoder_only), taesd(decode_only, version), GGMLRunner(backend) { taesd.init(params_ctx, tensor_types, prefix); } std::string get_desc() { return "taesd"; } bool load_from_file(const std::string& file_path) { LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false"); alloc_params_buffer(); std::map taesd_tensors; taesd.get_param_tensors(taesd_tensors); std::set ignore_tensors; if (decode_only) { ignore_tensors.insert("encoder."); } ModelLoader model_loader; if (!model_loader.init_from_file(file_path)) { LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str()); return false; } bool success = model_loader.load_tensors(taesd_tensors, backend, ignore_tensors); if (!success) { LOG_ERROR("load tae tensors from model loader failed"); return false; } LOG_INFO("taesd model loaded"); return success; } struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); z = to_backend(z); struct ggml_tensor* out = decode_graph ? taesd.decode(compute_ctx, z) : taesd.encode(compute_ctx, z); ggml_build_forward_expand(gf, out); return gf; } void compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, struct ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } }; #endif // __TAE_HPP__