#ifndef __COMMON_HPP__ #define __COMMON_HPP__ #include "ggml_extend.hpp" class DownSampleBlock : public GGMLBlock { protected: int channels; int out_channels; bool vae_downsample; public: DownSampleBlock(int channels, int out_channels, bool vae_downsample = false) : channels(channels), out_channels(out_channels), vae_downsample(vae_downsample) { if (vae_downsample) { blocks["conv"] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0})); } else { blocks["op"] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1})); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] if (vae_downsample) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); x = ggml_pad(ctx, x, 1, 1, 0, 0); x = conv->forward(ctx, x); } else { auto conv = std::dynamic_pointer_cast(blocks["op"]); x = conv->forward(ctx, x); } return x; // [N, out_channels, h/2, w/2] } }; class UpSampleBlock : public GGMLBlock { protected: int channels; int out_channels; public: UpSampleBlock(int channels, int out_channels) : channels(channels), out_channels(out_channels) { blocks["conv"] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] auto conv = std::dynamic_pointer_cast(blocks["conv"]); x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] return x; } }; class ResBlock : public GGMLBlock { protected: // network hparams int64_t channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4) int64_t emb_channels; // time_embed_dim int64_t out_channels; // mult * model_channels std::pair kernel_size; int dims; bool skip_t_emb; bool exchange_temb_dims; std::shared_ptr conv_nd(int dims, int64_t in_channels, int64_t out_channels, std::pair kernel_size, std::pair padding) { GGML_ASSERT(dims == 2 || dims == 3); if (dims == 3) { return std::shared_ptr(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first)); } else { return std::shared_ptr(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding)); } } public: ResBlock(int64_t channels, int64_t emb_channels, int64_t out_channels, std::pair kernel_size = {3, 3}, int dims = 2, bool exchange_temb_dims = false, bool skip_t_emb = false) : channels(channels), emb_channels(emb_channels), out_channels(out_channels), kernel_size(kernel_size), dims(dims), skip_t_emb(skip_t_emb), exchange_temb_dims(exchange_temb_dims) { std::pair padding = {kernel_size.first / 2, kernel_size.second / 2}; blocks["in_layers.0"] = std::shared_ptr(new GroupNorm32(channels)); // in_layer_1 is nn.SILU() blocks["in_layers.2"] = conv_nd(dims, channels, out_channels, kernel_size, padding); if (!skip_t_emb) { // emb_layer_0 is nn.SILU() blocks["emb_layers.1"] = std::shared_ptr(new Linear(emb_channels, out_channels)); } blocks["out_layers.0"] = std::shared_ptr(new GroupNorm32(out_channels)); // out_layer_1 is nn.SILU() // out_layer_2 is nn.Dropout(), skip for inference blocks["out_layers.3"] = conv_nd(dims, out_channels, out_channels, kernel_size, padding); if (out_channels != channels) { blocks["skip_connection"] = conv_nd(dims, channels, out_channels, {1, 1}, {0, 0}); } } virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = NULL) { // For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml // [N, c, t, h, w] => [N, c, t, h * w] // x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w] // emb: [N, emb_channels] if dims == 2 else [N, t, emb_channels] auto in_layers_0 = std::dynamic_pointer_cast(blocks["in_layers.0"]); auto in_layers_2 = std::dynamic_pointer_cast(blocks["in_layers.2"]); auto out_layers_0 = std::dynamic_pointer_cast(blocks["out_layers.0"]); auto out_layers_3 = std::dynamic_pointer_cast(blocks["out_layers.3"]); if (emb == NULL) { GGML_ASSERT(skip_t_emb); } // in_layers auto h = in_layers_0->forward(ctx, x); h = ggml_silu_inplace(ctx, h); h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] // emb_layers if (!skip_t_emb) { auto emb_layer_1 = std::dynamic_pointer_cast(blocks["emb_layers.1"]); auto emb_out = ggml_silu(ctx, emb); emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels] if (dims == 2) { emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] } else { emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1] if (exchange_temb_dims) { // emb_out = rearrange(emb_out, "b t c ... -> b c t ...") emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1] } } h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } // out_layers h = out_layers_0->forward(ctx, h); h = ggml_silu_inplace(ctx, h); // dropout, skip for inference h = out_layers_3->forward(ctx, h); // skip connection if (out_channels != channels) { auto skip_connection = std::dynamic_pointer_cast(blocks["skip_connection"]); x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } h = ggml_add(ctx, h, x); return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } }; class GEGLU : public GGMLBlock { protected: int64_t dim_in; int64_t dim_out; void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32; enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32; params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2); } public: GEGLU(int64_t dim_in, int64_t dim_out) : dim_in(dim_in), dim_out(dim_out) {} struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [ne3, ne2, ne1, dim_in] // return: [ne3, ne2, ne1, dim_out] struct ggml_tensor* w = params["proj.weight"]; struct ggml_tensor* b = params["proj.bias"]; auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in] auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in] auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ] auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ] auto x_in = x; x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out] auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out] gate = ggml_gelu_inplace(ctx, gate); x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out] return x; } }; class FeedForward : public GGMLBlock { public: FeedForward(int64_t dim, int64_t dim_out, int64_t mult = 4) { int64_t inner_dim = dim * mult; blocks["net.0"] = std::shared_ptr(new GEGLU(dim, inner_dim)); // net_1 is nn.Dropout(), skip for inference blocks["net.2"] = std::shared_ptr(new Linear(inner_dim, dim_out)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [ne3, ne2, ne1, dim] // return: [ne3, ne2, ne1, dim_out] auto net_0 = std::dynamic_pointer_cast(blocks["net.0"]); auto net_2 = std::dynamic_pointer_cast(blocks["net.2"]); x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim] x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out] return x; } }; class CrossAttention : public GGMLBlock { protected: int64_t query_dim; int64_t context_dim; int64_t n_head; int64_t d_head; bool flash_attn; public: CrossAttention(int64_t query_dim, int64_t context_dim, int64_t n_head, int64_t d_head, bool flash_attn = false) : n_head(n_head), d_head(d_head), query_dim(query_dim), context_dim(context_dim), flash_attn(flash_attn) { int64_t inner_dim = d_head * n_head; blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); blocks["to_k"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); blocks["to_v"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); blocks["to_out.0"] = std::shared_ptr(new Linear(inner_dim, query_dim)); // to_out_1 is nn.Dropout(), skip for inference } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, n_token, query_dim] // context: [N, n_context, context_dim] // return: [N, n_token, query_dim] auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); int64_t n = x->ne[2]; int64_t n_token = x->ne[1]; int64_t n_context = context->ne[1]; int64_t inner_dim = d_head * n_head; auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim] auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; } }; class BasicTransformerBlock : public GGMLBlock { protected: int64_t n_head; int64_t d_head; bool ff_in; public: BasicTransformerBlock(int64_t dim, int64_t n_head, int64_t d_head, int64_t context_dim, bool ff_in = false, bool flash_attn = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False // disable_temporal_crossattention is always False // switch_temporal_ca_to_sa is always False // inner_dim is always None or equal to dim // gated_ff is always True blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); blocks["ff"] = std::shared_ptr(new FeedForward(dim, dim)); blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm3"] = std::shared_ptr(new LayerNorm(dim)); if (ff_in) { blocks["norm_in"] = std::shared_ptr(new LayerNorm(dim)); blocks["ff_in"] = std::shared_ptr(new FeedForward(dim, dim)); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, n_token, query_dim] // context: [N, n_context, context_dim] // return: [N, n_token, query_dim] auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto ff = std::dynamic_pointer_cast(blocks["ff"]); auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto norm3 = std::dynamic_pointer_cast(blocks["norm3"]); if (ff_in) { auto norm_in = std::dynamic_pointer_cast(blocks["norm_in"]); auto ff_in = std::dynamic_pointer_cast(blocks["ff_in"]); auto x_skip = x; x = norm_in->forward(ctx, x); x = ff_in->forward(ctx, x); // self.is_res is always True x = ggml_add(ctx, x, x_skip); } auto r = x; x = norm1->forward(ctx, x); x = attn1->forward(ctx, x, x); // self-attention x = ggml_add(ctx, x, r); r = x; x = norm2->forward(ctx, x); x = attn2->forward(ctx, x, context); // cross-attention x = ggml_add(ctx, x, r); r = x; x = norm3->forward(ctx, x); x = ff->forward(ctx, x); x = ggml_add(ctx, x, r); return x; } }; class SpatialTransformer : public GGMLBlock { protected: int64_t in_channels; // mult * model_channels int64_t n_head; int64_t d_head; int64_t depth = 1; // 1 int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2 public: SpatialTransformer(int64_t in_channels, int64_t n_head, int64_t d_head, int64_t depth, int64_t context_dim, bool flash_attn = false) : in_channels(in_channels), n_head(n_head), d_head(d_head), depth(depth), context_dim(context_dim) { // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False // disable_self_attn is always False int64_t inner_dim = n_head * d_head; // in_channels blocks["norm"] = std::shared_ptr(new GroupNorm32(in_channels)); blocks["proj_in"] = std::shared_ptr(new Conv2d(in_channels, inner_dim, {1, 1})); for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." + std::to_string(i); blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn)); } blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1})); } virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, in_channels, h, w] // context: [N, max_position(aka n_token), hidden_size(aka context_dim)] auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); auto x_in = x; int64_t n = x->ne[3]; int64_t h = x->ne[1]; int64_t w = x->ne[0]; int64_t inner_dim = n_head * d_head; x = norm->forward(ctx, x); x = proj_in->forward(ctx, x); // [N, inner_dim, h, w] x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim] x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim] for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." + std::to_string(i); auto transformer_block = std::dynamic_pointer_cast(blocks[name]); x = transformer_block->forward(ctx, x, context); } x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w] x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w] // proj_out x = proj_out->forward(ctx, x); // [N, in_channels, h, w] x = ggml_add(ctx, x, x_in); return x; } }; class AlphaBlender : public GGMLBlock { protected: void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { // image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,] // so learned_with_images is same as learned float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]); return sigmoid(alpha); } public: AlphaBlender() { // merge_strategy is always learned_with_images // for inference, we don't need to set alpha // since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x_spatial, struct ggml_tensor* x_temporal) { // image_only_indicator is always tensor([0.]) float alpha = get_alpha(); auto x = ggml_add(ctx, ggml_scale(ctx, x_spatial, alpha), ggml_scale(ctx, x_temporal, 1.0f - alpha)); return x; } }; class VideoResBlock : public ResBlock { public: VideoResBlock(int channels, int emb_channels, int out_channels, std::pair kernel_size = {3, 3}, int64_t video_kernel_size = 3, int dims = 2) // always 2 : ResBlock(channels, emb_channels, out_channels, kernel_size, dims) { blocks["time_stack"] = std::shared_ptr(new ResBlock(out_channels, emb_channels, out_channels, kernel_size, 3, true)); blocks["time_mixer"] = std::shared_ptr(new AlphaBlender()); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb, int num_video_frames) { // x: [N, channels, h, w] aka [b*t, channels, h, w] // emb: [N, emb_channels] aka [b*t, emb_channels] // image_only_indicator is always tensor([0.]) auto time_stack = std::dynamic_pointer_cast(blocks["time_stack"]); auto time_mixer = std::dynamic_pointer_cast(blocks["time_mixer"]); x = ResBlock::forward(ctx, x, emb); int64_t T = num_video_frames; int64_t B = x->ne[3] / T; int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) auto x_mix = x; emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ... x = time_stack->forward(ctx, x, emb); // b t c (h w) x = time_mixer->forward(ctx, x_mix, x); // b t c (h w) x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w return x; } }; #endif // __COMMON_HPP__