#ifndef __MMDIT_HPP__
#define __MMDIT_HPP__

#include "ggml_extend.hpp"
#include "model.h"

#define MMDIT_GRAPH_SIZE 10240

struct Mlp : public GGMLBlock {
public:
    Mlp(int64_t in_features,
        int64_t hidden_features = -1,
        int64_t out_features    = -1,
        bool bias               = true) {
        // act_layer is always lambda: nn.GELU(approximate="tanh")
        // norm_layer is always None
        // use_conv is always False
        if (hidden_features == -1) {
            hidden_features = in_features;
        }
        if (out_features == -1) {
            out_features = in_features;
        }
        blocks["fc1"] = std::shared_ptr<GGMLBlock>(new Linear(in_features, hidden_features, bias));
        blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, n_token, in_features]
        auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
        auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);

        x = fc1->forward(ctx, x);
        x = ggml_gelu_inplace(ctx, x);
        x = fc2->forward(ctx, x);
        return x;
    }
};

struct PatchEmbed : public GGMLBlock {
    // 2D Image to Patch Embedding
protected:
    bool flatten;
    bool dynamic_img_pad;
    int patch_size;

public:
    PatchEmbed(int64_t img_size     = 224,
               int patch_size       = 16,
               int64_t in_chans     = 3,
               int64_t embed_dim    = 1536,
               bool bias            = true,
               bool flatten         = true,
               bool dynamic_img_pad = true)
        : patch_size(patch_size), flatten(flatten), dynamic_img_pad(dynamic_img_pad) {
        // img_size is always None
        // patch_size is always 2
        // in_chans is always 16
        // norm_layer is always False
        // strict_img_size is always true, but not used
        blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_chans,
                                                               embed_dim,
                                                               {patch_size, patch_size},
                                                               {patch_size, patch_size},
                                                               {0, 0},
                                                               {1, 1},
                                                               bias));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, C, H, W]
        // return: [N, H*W, embed_dim]
        auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);

        if (dynamic_img_pad) {
            int64_t W = x->ne[0];
            int64_t H = x->ne[1];
            int pad_h = (patch_size - H % patch_size) % patch_size;
            int pad_w = (patch_size - W % patch_size) % patch_size;
            x         = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // TODO: reflect pad mode
        }
        x = proj->forward(ctx, x);
        if (flatten) {
            x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
            x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
        }
        return x;
    }
};
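/*
 * Note on the shape comments used throughout this file: ggml stores
 * dimensions innermost-first, so a tensor annotated [N, n_token, C] has
 * ne[0] = C, ne[1] = n_token, ne[2] = N. A worked example for PatchEmbed
 * (values assumed, matching the SD3 defaults above): a 128x128x16 latent
 * with patch_size = 2 becomes
 *
 *   x:    ne = [128, 128, 16, N]   // [N, C=16, H=128, W=128]
 *   proj: ne = [64, 64, 1536, N]   // [N, embed_dim, H/2, W/2]
 *   out:  ne = [1536, 4096, N]     // [N, (H/2)*(W/2) = 4096 tokens, embed_dim]
 */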
struct TimestepEmbedder : public GGMLBlock {
    // Embeds scalar timesteps into vector representations.
protected:
    int64_t frequency_embedding_size;

public:
    TimestepEmbedder(int64_t hidden_size,
                     int64_t frequency_embedding_size = 256)
        : frequency_embedding_size(frequency_embedding_size) {
        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true));
        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
        // t: [N, ]
        // return: [N, hidden_size]
        auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
        auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);

        auto t_freq = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size);  // [N, frequency_embedding_size]

        auto t_emb = mlp_0->forward(ctx, t_freq);
        t_emb      = ggml_silu_inplace(ctx, t_emb);
        t_emb      = mlp_2->forward(ctx, t_emb);
        return t_emb;
    }
};

struct VectorEmbedder : public GGMLBlock {
    // Embeds a flat vector of dimension input_dim
public:
    VectorEmbedder(int64_t input_dim, int64_t hidden_size) {
        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(input_dim, hidden_size, true, true));
        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, input_dim]
        // return: [N, hidden_size]
        auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
        auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);

        x = mlp_0->forward(ctx, x);
        x = ggml_silu_inplace(ctx, x);
        x = mlp_2->forward(ctx, x);
        return x;
    }
};
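/*
 * RMSNorm below implements root-mean-square layer normalization
 * (Zhang & Sennrich, 2019) over the innermost dimension:
 *
 *   y_i = w_i * x_i / sqrt(mean_j(x_j^2) + eps)
 *
 * ggml_rms_norm performs the normalization; the learned per-channel weight w
 * is applied by the ggml_mul in forward().
 */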
class RMSNorm : public UnaryBlock {
protected:
    int64_t hidden_size;
    float eps;

    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
        params["weight"]     = ggml_new_tensor_1d(ctx, wtype, hidden_size);
    }

public:
    RMSNorm(int64_t hidden_size, float eps = 1e-06f)
        : hidden_size(hidden_size), eps(eps) {}

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        struct ggml_tensor* w = params["weight"];

        x = ggml_rms_norm(ctx, x, eps);
        x = ggml_mul(ctx, x, w);
        return x;
    }
};

class SelfAttention : public GGMLBlock {
public:
    int64_t num_heads;
    bool pre_only;
    std::string qk_norm;

public:
    SelfAttention(int64_t dim,
                  int64_t num_heads   = 8,
                  std::string qk_norm = "",
                  bool qkv_bias       = false,
                  bool pre_only       = false)
        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
        int64_t d_head = dim / num_heads;
        blocks["qkv"]  = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
        if (!pre_only) {
            blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
        }
        if (qk_norm == "rms") {
            blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
            blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
        } else if (qk_norm == "ln") {
            blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
            blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
        }
    }

    std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);

        auto qkv         = qkv_proj->forward(ctx, x);
        auto qkv_vec     = split_qkv(ctx, qkv);
        int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
        auto q           = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
        auto k           = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
        auto v           = qkv_vec[2];                                                                                   // [N, n_token, n_head*d_head]

        if (qk_norm == "rms" || qk_norm == "ln") {
            auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
            auto ln_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_k"]);
            q         = ln_q->forward(ctx, q);
            k         = ln_k->forward(ctx, k);
        }

        q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]);  // [N, n_token, n_head*d_head]
        k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]);  // [N, n_token, n_head*d_head]

        return {q, k, v};
    }

    struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
        GGML_ASSERT(!pre_only);

        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);

        x = proj->forward(ctx, x);  // [N, n_token, dim]
        return x;
    }

    // x: [N, n_token, dim]
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto qkv = pre_attention(ctx, x);
        x        = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
        x        = post_attention(ctx, x);                                        // [N, n_token, dim]
        return x;
    }
};

__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
                                               struct ggml_tensor* x,
                                               struct ggml_tensor* shift,
                                               struct ggml_tensor* scale) {
    // x: [N, L, C]
    // scale: [N, C]
    // shift: [N, C]
    scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
    shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);  // [N, 1, C]
    x     = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
    x     = ggml_add(ctx, x, shift);
    return x;
}
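/*
 * modulate() is the standard adaLN modulation, x * (1 + scale) + shift, with
 * scale and shift broadcast over the token dimension. A minimal numeric
 * sketch (values assumed): for x = 2.0, scale = 0.5, shift = 1.0 the result
 * is 2.0 * 1.5 + 1.0 = 4.0.
 */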
struct DismantledBlock : public GGMLBlock {
    // A DiT block with gated adaptive layer norm (adaLN) conditioning.
public:
    int64_t num_heads;
    bool pre_only;
    bool self_attn;

public:
    DismantledBlock(int64_t hidden_size,
                    int64_t num_heads,
                    float mlp_ratio     = 4.0,
                    std::string qk_norm = "",
                    bool qkv_bias       = false,
                    bool pre_only       = false,
                    bool self_attn      = false)
        : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
        // rmsnorm is always False
        // scale_mod_only is always False
        // swiglu is always False
        blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
        if (self_attn) {
            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
        }
        if (!pre_only) {
            blocks["norm2"]        = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
            int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
            blocks["mlp"]          = std::shared_ptr<GGMLBlock>(new Mlp(hidden_size, mlp_hidden_dim));
        }
        int64_t n_mods = 6;
        if (pre_only) {
            n_mods = 2;
        }
        if (self_attn) {
            n_mods = 9;
        }
        blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
    }

    std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>>
    pre_attention_x(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
        GGML_ASSERT(self_attn);
        // x: [N, n_token, hidden_size]
        // c: [N, hidden_size]
        auto norm1              = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
        auto attn               = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto attn2              = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        int64_t n_mods = 9;
        auto m         = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, n_mods * hidden_size]
        m              = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);  // [N, n_mods, hidden_size]
        m              = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));     // [n_mods, N, hidden_size]

        int64_t offset  = m->nb[1] * m->ne[1];
        auto shift_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
        auto scale_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
        auto gate_msa   = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);  // [N, hidden_size]
        auto shift_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);  // [N, hidden_size]
        auto scale_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);  // [N, hidden_size]
        auto gate_mlp   = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);  // [N, hidden_size]
        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);  // [N, hidden_size]
        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);  // [N, hidden_size]
        auto gate_msa2  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);  // [N, hidden_size]

        auto x_norm = norm1->forward(ctx, x);

        auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
        auto qkv     = attn->pre_attention(ctx, attn_in);

        auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
        auto qkv2     = attn2->pre_attention(ctx, attn2_in);

        return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
    }
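    // Layout note: the single adaLN_modulation.1 projection above produces
    // all n_mods modulation vectors at once; after the reshape/permute they
    // sit contiguously, so each shift/scale/gate is just a [N, hidden_size]
    // view at a multiple of `offset` bytes. pre_attention() below does the
    // same with n_mods = 6 (or 2 when pre_only).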
    std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>>
    pre_attention(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
        // x: [N, n_token, hidden_size]
        // c: [N, hidden_size]
        auto norm1              = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
        auto attn               = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        int64_t n_mods = 6;
        if (pre_only) {
            n_mods = 2;
        }

        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, n_mods * hidden_size]
        m      = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);  // [N, n_mods, hidden_size]
        m      = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));     // [n_mods, N, hidden_size]

        int64_t offset = m->nb[1] * m->ne[1];
        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
        if (!pre_only) {
            auto gate_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);  // [N, hidden_size]
            auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);  // [N, hidden_size]
            auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);  // [N, hidden_size]
            auto gate_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);  // [N, hidden_size]

            auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
            auto qkv     = attn->pre_attention(ctx, attn_in);

            return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
        } else {
            auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
            auto qkv     = attn->pre_attention(ctx, attn_in);

            return {qkv, {NULL, NULL, NULL, NULL, NULL}};
        }
    }

    struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
                                         struct ggml_tensor* attn_out,
                                         struct ggml_tensor* attn2_out,
                                         struct ggml_tensor* x,
                                         struct ggml_tensor* gate_msa,
                                         struct ggml_tensor* shift_mlp,
                                         struct ggml_tensor* scale_mlp,
                                         struct ggml_tensor* gate_mlp,
                                         struct ggml_tensor* gate_msa2) {
        // attn_out: [N, n_token, hidden_size]
        // x: [N, n_token, hidden_size]
        // gate_msa: [N, hidden_size]
        // shift_mlp: [N, hidden_size]
        // scale_mlp: [N, hidden_size]
        // gate_mlp: [N, hidden_size]
        // return: [N, n_token, hidden_size]
        GGML_ASSERT(!pre_only);

        auto attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
        auto mlp   = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);

        gate_msa  = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);     // [N, 1, hidden_size]
        gate_mlp  = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);     // [N, 1, hidden_size]
        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);  // [N, 1, hidden_size]

        attn_out  = attn->post_attention(ctx, attn_out);
        attn2_out = attn2->post_attention(ctx, attn2_out);

        x            = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
        x            = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
        x            = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));

        return x;
    }
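    // Both post-attention paths apply gated residual updates,
    // x += gate * branch, with each [N, hidden_size] gate reshaped to
    // [N, 1, hidden_size] so that it broadcasts over the token dimension.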
    struct ggml_tensor* post_attention(struct ggml_context* ctx,
                                       struct ggml_tensor* attn_out,
                                       struct ggml_tensor* x,
                                       struct ggml_tensor* gate_msa,
                                       struct ggml_tensor* shift_mlp,
                                       struct ggml_tensor* scale_mlp,
                                       struct ggml_tensor* gate_mlp) {
        // attn_out: [N, n_token, hidden_size]
        // x: [N, n_token, hidden_size]
        // gate_msa: [N, hidden_size]
        // shift_mlp: [N, hidden_size]
        // scale_mlp: [N, hidden_size]
        // gate_mlp: [N, hidden_size]
        // return: [N, n_token, hidden_size]
        GGML_ASSERT(!pre_only);

        auto attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
        auto mlp   = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);

        gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);  // [N, 1, hidden_size]
        gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);  // [N, 1, hidden_size]

        attn_out = attn->post_attention(ctx, attn_out);

        x            = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
        x            = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));

        return x;
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
        // x: [N, n_token, hidden_size]
        // c: [N, hidden_size]
        // return: [N, n_token, hidden_size]
        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);

        if (self_attn) {
            auto qkv_intermediates = pre_attention_x(ctx, x, c);  // a tuple now, no longer a pair
            auto qkv               = std::get<0>(qkv_intermediates);
            auto qkv2              = std::get<1>(qkv_intermediates);
            auto intermediates     = std::get<2>(qkv_intermediates);

            auto attn_out  = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
            x              = post_attention_x(ctx,
                                              attn_out,
                                              attn2_out,
                                              intermediates[0],
                                              intermediates[1],
                                              intermediates[2],
                                              intermediates[3],
                                              intermediates[4],
                                              intermediates[5]);
            return x;  // [N, n_token, dim]
        } else {
            auto qkv_intermediates = pre_attention(ctx, x, c);
            auto qkv               = qkv_intermediates.first;
            auto intermediates     = qkv_intermediates.second;

            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
            x             = post_attention(ctx,
                                           attn_out,
                                           intermediates[0],
                                           intermediates[1],
                                           intermediates[2],
                                           intermediates[3],
                                           intermediates[4]);
            return x;  // [N, n_token, dim]
        }
    }
};
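/*
 * block_mixing() below implements MMDiT joint attention: the context (text)
 * stream and the x (image latent) stream are projected to q/k/v by their own
 * DismantledBlocks, concatenated along the token dimension, attended to
 * jointly, then split back and routed through each stream's post-attention.
 * Conceptually (pseudo-code, names assumed):
 *
 *   q = concat(q_ctx, q_x); k = concat(k_ctx, k_x); v = concat(v_ctx, v_x);
 *   attn    = attention(q, k, v);       // [N, n_context + n_token, hidden]
 *   ctx_out = attn[:, :n_context, :];
 *   x_out   = attn[:, n_context:, :];
 */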
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
             struct ggml_tensor* context,
             struct ggml_tensor* x,
             struct ggml_tensor* c,
             std::shared_ptr<DismantledBlock> context_block,
             std::shared_ptr<DismantledBlock> x_block) {
    // context: [N, n_context, hidden_size]
    // x: [N, n_token, hidden_size]
    // c: [N, hidden_size]
    auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
    auto context_qkv               = context_qkv_intermediates.first;
    auto context_intermediates     = context_qkv_intermediates.second;

    std::vector<struct ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
    if (x_block->self_attn) {
        auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
        x_qkv                    = std::get<0>(x_qkv_intermediates);
        x_qkv2                   = std::get<1>(x_qkv_intermediates);
        x_intermediates          = std::get<2>(x_qkv_intermediates);
    } else {
        auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
        x_qkv                    = x_qkv_intermediates.first;
        x_intermediates          = x_qkv_intermediates.second;
    }

    std::vector<struct ggml_tensor*> qkv;
    for (int i = 0; i < 3; i++) {
        qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
    }

    auto attn = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                     // [n_context + n_token, N, hidden_size]

    auto context_attn = ggml_view_3d(ctx,
                                     attn,
                                     attn->ne[0],
                                     attn->ne[1],
                                     context->ne[1],
                                     attn->nb[1],
                                     attn->nb[2],
                                     0);                                              // [n_context, N, hidden_size]
    context_attn      = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3));  // [N, n_context, hidden_size]

    auto x_attn = ggml_view_3d(ctx,
                               attn,
                               attn->ne[0],
                               attn->ne[1],
                               x->ne[1],
                               attn->nb[1],
                               attn->nb[2],
                               attn->nb[2] * context->ne[1]);               // [n_token, N, hidden_size]
    x_attn      = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3));    // [N, n_token, hidden_size]

    if (!context_block->pre_only) {
        context = context_block->post_attention(ctx,
                                                context_attn,
                                                context_intermediates[0],
                                                context_intermediates[1],
                                                context_intermediates[2],
                                                context_intermediates[3],
                                                context_intermediates[4]);
    } else {
        context = NULL;
    }

    if (x_block->self_attn) {
        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]

        x = x_block->post_attention_x(ctx,
                                      x_attn,
                                      attn2,
                                      x_intermediates[0],
                                      x_intermediates[1],
                                      x_intermediates[2],
                                      x_intermediates[3],
                                      x_intermediates[4],
                                      x_intermediates[5]);
    } else {
        x = x_block->post_attention(ctx,
                                    x_attn,
                                    x_intermediates[0],
                                    x_intermediates[1],
                                    x_intermediates[2],
                                    x_intermediates[3],
                                    x_intermediates[4]);
    }

    return {context, x};
}

struct JointBlock : public GGMLBlock {
public:
    JointBlock(int64_t hidden_size,
               int64_t num_heads,
               float mlp_ratio     = 4.0,
               std::string qk_norm = "",
               bool qkv_bias       = false,
               bool pre_only       = false,
               bool self_attn_x    = false) {
        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
    }

    std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
                                                                struct ggml_tensor* context,
                                                                struct ggml_tensor* x,
                                                                struct ggml_tensor* c) {
        auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
        auto x_block       = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);

        return block_mixing(ctx, context, x, c, context_block, x_block);
    }
};

struct FinalLayer : public GGMLBlock {
    // The final layer of DiT.
public:
    FinalLayer(int64_t hidden_size,
               int64_t patch_size,
               int64_t out_channels) {
        // total_out_channels is always None
        blocks["norm_final"]         = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
        blocks["linear"]             = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
        blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
        // x: [N, n_token, hidden_size]
        // c: [N, hidden_size]
        // return: [N, n_token, patch_size * patch_size * out_channels]
        auto norm_final         = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
        auto linear             = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, 2 * hidden_size]
        m      = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]);       // [N, 2, hidden_size]
        m      = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));     // [2, N, hidden_size]

        int64_t offset = m->nb[1] * m->ne[1];
        auto shift     = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
        auto scale     = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]

        x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
        x = linear->forward(ctx, x);

        return x;
    }
};
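/*
 * Sizing note for the MMDiT struct below: the constructor derives the model
 * shape from the checkpoint's tensor names, then sets hidden_size = 64 * depth
 * and num_heads = depth, i.e. a fixed head dimension of 64. Under that scheme
 * (assumed here), depth 24 gives hidden_size 1536, which matches the
 * SD3-medium configuration.
 */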
struct MMDiT : public GGMLBlock {
    // Diffusion model with a Transformer backbone.
protected:
    int64_t input_size               = -1;
    int64_t patch_size               = 2;
    int64_t in_channels              = 16;
    int64_t d_self                   = -1;  // >= 0 for MMDiT-X
    int64_t depth                    = 24;
    float mlp_ratio                  = 4.0f;
    int64_t adm_in_channels          = 2048;
    int64_t out_channels             = 16;
    int64_t pos_embed_max_size       = 192;
    int64_t num_patchs               = 36864;  // 192 * 192
    int64_t context_size             = 4096;
    int64_t context_embedder_out_dim = 1536;
    int64_t hidden_size;
    std::string qk_norm;

    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32;
        params["pos_embed"]  = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
    }

public:
    MMDiT(std::map<std::string, enum ggml_type>& tensor_types) {
        // input_size is always None
        // learn_sigma is always False
        // register_length is always 0
        // rmsnorm is always False
        // scale_mod_only is always False
        // swiglu is always False
        // qkv_bias is always True
        // context_processor_layers is always None
        // pos_embed_scaling_factor is not used
        // pos_embed_offset is not used
        // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}

        // infer depth, qk_norm and d_self from the tensor names
        for (auto pair : tensor_types) {
            std::string tensor_name = pair.first;
            if (tensor_name.find("model.diffusion_model.") == std::string::npos)
                continue;
            size_t jb = tensor_name.find("joint_blocks.");
            if (jb != std::string::npos) {
                tensor_name     = tensor_name.substr(jb);  // remove prefix
                int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str());
                if (block_depth + 1 > depth) {
                    depth = block_depth + 1;
                }
                if (tensor_name.find("attn.ln") != std::string::npos) {
                    if (tensor_name.find(".bias") != std::string::npos) {
                        qk_norm = "ln";
                    } else {
                        qk_norm = "rms";
                    }
                }
                if (tensor_name.find("attn2") != std::string::npos) {
                    if (block_depth > d_self) {
                        d_self = block_depth;
                    }
                }
            }
        }

        if (d_self >= 0) {
            pos_embed_max_size *= 2;
            num_patchs *= 4;
        }

        LOG_INFO("MMDiT layers: %d (including %d MMDiT-X layers)", (int)depth, (int)(d_self + 1));

        int64_t default_out_channels = in_channels;
        hidden_size                  = 64 * depth;
        context_embedder_out_dim     = 64 * depth;
        int64_t num_heads            = depth;

        blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
        blocks["t_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedder(hidden_size));

        if (adm_in_channels != -1) {
            blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
        }

        blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true));

        for (int i = 0; i < depth; i++) {
            blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,
                                                                                                    num_heads,
                                                                                                    mlp_ratio,
                                                                                                    qk_norm,
                                                                                                    true,
                                                                                                    i == depth - 1,
                                                                                                    i <= d_self));
        }

        blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
    }
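    /*
     * cropped_pos_embed() below reproduces the reference center-crop of the
     * learned positional-embedding grid: the pos_embed_max_size^2 grid is
     * viewed as [grid, grid, hidden] and an h x w window is taken from its
     * center. A worked example (values assumed): a 128x128 latent with
     * patch_size = 2 gives h = w = 64, so top = left = (192 - 64) / 2 = 64.
     */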
    struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, int64_t h, int64_t w) {
        auto pos_embed = params["pos_embed"];

        h = (h + 1) / patch_size;
        w = (w + 1) / patch_size;

        GGML_ASSERT(h <= pos_embed_max_size && h > 0);
        GGML_ASSERT(w <= pos_embed_max_size && w > 0);

        int64_t top  = (pos_embed_max_size - h) / 2;
        int64_t left = (pos_embed_max_size - w) / 2;

        auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size);

        // spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = ggml_view_3d(ctx,
                                         spatial_pos_embed,
                                         hidden_size,
                                         pos_embed_max_size,
                                         h,
                                         spatial_pos_embed->nb[1],
                                         spatial_pos_embed->nb[2],
                                         spatial_pos_embed->nb[2] * top);                      // [h, pos_embed_max_size, hidden_size]
        spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3));  // [pos_embed_max_size, h, hidden_size]
        spatial_pos_embed = ggml_view_3d(ctx,
                                         spatial_pos_embed,
                                         hidden_size,
                                         h,
                                         w,
                                         spatial_pos_embed->nb[1],
                                         spatial_pos_embed->nb[2],
                                         spatial_pos_embed->nb[2] * left);                     // [w, h, hidden_size]
        spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3));  // [h, w, hidden_size]
        spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1);    // [1, h*w, hidden_size]
        return spatial_pos_embed;
    }

    struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* x, int64_t h, int64_t w) {
        // x: [N, H*W, patch_size * patch_size * C]
        // return: [N, C, H, W]
        int64_t n = x->ne[2];
        int64_t c = out_channels;
        int64_t p = patch_size;
        h         = (h + 1) / p;
        w         = (w + 1) / p;

        GGML_ASSERT(h * w == x->ne[1]);

        x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n);       // [N, H*W, P*P, C]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, C, H*W, P*P]
        x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n);       // [N*C*H, W, P, P]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*H, P, W, P]
        x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n);       // [N, C, H*P, W*P]

        return x;
    }

    struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* c_mod,
                                                 struct ggml_tensor* context,
                                                 std::vector<int> skip_layers = std::vector<int>()) {
        // x: [N, H*W, hidden_size]
        // context: [N, n_context, d_context]
        // c_mod: [N, hidden_size]
        // return: [N, H*W, patch_size * patch_size * out_channels]
        auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);

        for (int i = 0; i < depth; i++) {
            // skip this iteration if i is in skip_layers
            if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
                continue;
            }

            auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);

            auto context_x = block->forward(ctx, context, x, c_mod);
            context        = context_x.first;
            x              = context_x.second;
        }

        x = final_layer->forward(ctx, x, c_mod);  // (N, T, patch_size ** 2 * out_channels)

        return x;
    }
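    // forward() below stitches the pieces together: patch + cropped position
    // embeddings for x, timestep embedding (plus the optional pooled y vector)
    // for the conditioning c, a linear context_embedder for the text tokens,
    // then the joint blocks and the final unpatchify back to latent space.
    // skip_layers lets callers drop whole joint blocks (assumed to be used
    // for skip-layer guidance; see forward_core_with_concat above).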
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* t,
                                struct ggml_tensor* y        = NULL,
                                struct ggml_tensor* context  = NULL,
                                std::vector<int> skip_layers = std::vector<int>()) {
        // Forward pass of DiT.
        // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        // t: (N,) tensor of diffusion timesteps
        // y: (N, adm_in_channels) tensor of class labels
        // context: (N, L, D)
        // return: (N, C, H, W)
        auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]);
        auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);

        int64_t w = x->ne[0];
        int64_t h = x->ne[1];

        auto patch_embed = x_embedder->forward(ctx, x);            // [N, H*W, hidden_size]
        auto pos_embed   = cropped_pos_embed(ctx, h, w);           // [1, H*W, hidden_size]
        x                = ggml_add(ctx, patch_embed, pos_embed);  // [N, H*W, hidden_size]

        auto c = t_embedder->forward(ctx, t);  // [N, hidden_size]
        if (y != NULL && adm_in_channels != -1) {
            auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);

            y = y_embedder->forward(ctx, y);  // [N, hidden_size]
            c = ggml_add(ctx, c, y);
        }

        if (context != NULL) {
            auto context_embedder = std::dynamic_pointer_cast<Linear>(blocks["context_embedder"]);

            context = context_embedder->forward(ctx, context);  // [N, L, D] aka [N, L, 1536]
        }

        x = forward_core_with_concat(ctx, x, c, context, skip_layers);  // (N, H*W, patch_size ** 2 * out_channels)

        x = unpatchify(ctx, x, h, w);  // [N, C, H, W]

        return x;
    }
};

struct MMDiTRunner : public GGMLRunner {
    MMDiT mmdit;

    static std::map<std::string, enum ggml_type> empty_tensor_types;

    MMDiTRunner(ggml_backend_t backend,
                std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                const std::string prefix                            = "")
        : GGMLRunner(backend), mmdit(tensor_types) {
        mmdit.init(params_ctx, tensor_types, prefix);
    }

    std::string get_desc() {
        return "mmdit";
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        mmdit.get_param_tensors(tensors, prefix);
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                    struct ggml_tensor* timesteps,
                                    struct ggml_tensor* context,
                                    struct ggml_tensor* y,
                                    std::vector<int> skip_layers = std::vector<int>()) {
        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);

        x         = to_backend(x);
        context   = to_backend(context);
        y         = to_backend(y);
        timesteps = to_backend(timesteps);

        struct ggml_tensor* out = mmdit.forward(compute_ctx, x, timesteps, y, context, skip_layers);

        ggml_build_forward_expand(gf, out);

        return gf;
    }

    void compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* y,
                 struct ggml_tensor** output     = NULL,
                 struct ggml_context* output_ctx = NULL,
                 std::vector<int> skip_layers    = std::vector<int>()) {
        // x: [N, in_channels, h, w]
        // timesteps: [N, ]
        // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size]
        // y: [N, adm_in_channels] or [1, adm_in_channels]
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, timesteps, context, y, skip_layers);
        };

        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
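    /*
     * Minimal usage sketch (an assumed workflow mirroring test() and
     * load_from_file_and_test() below, not an additional API):
     *
     *   ggml_backend_t backend = ggml_backend_cpu_init();
     *   MMDiTRunner runner(backend, tensor_types);
     *   runner.alloc_params_buffer();
     *   // ... load tensors via ModelLoader, then per denoising step:
     *   struct ggml_tensor* out = NULL;
     *   runner.compute(n_threads, x, timesteps, context, y, &out, work_ctx);
     */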
    void test() {
        struct ggml_init_params params;
        params.mem_size   = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
        params.mem_buffer = NULL;
        params.no_alloc   = false;

        struct ggml_context* work_ctx = ggml_init(params);
        GGML_ASSERT(work_ctx != NULL);

        {
            // cpu f16: pass
            // cpu f32: pass
            // cuda f16: pass
            // cuda f32: pass
            auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 128, 128, 16, 1);
            std::vector<float> timesteps_vec(1, 999.f);
            auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
            ggml_set_f32(x, 0.01f);
            // print_ggml_tensor(x);

            auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 154, 1);
            ggml_set_f32(context, 0.01f);
            // print_ggml_tensor(context);

            auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 2048, 1);
            ggml_set_f32(y, 0.01f);
            // print_ggml_tensor(y);

            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            compute(8, x, timesteps, context, y, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
            LOG_DEBUG("mmdit test done in %dms", t1 - t0);
        }
    }

    static void load_from_file_and_test(const std::string& file_path) {
        // ggml_backend_t backend = ggml_backend_cuda_init(0);
        ggml_backend_t backend             = ggml_backend_cpu_init();
        ggml_type model_data_type          = GGML_TYPE_F16;
        std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend));
        {
            LOG_INFO("loading from '%s'", file_path.c_str());

            mmdit->alloc_params_buffer();
            std::map<std::string, struct ggml_tensor*> tensors;
            mmdit->get_param_tensors(tensors, "model.diffusion_model");

            ModelLoader model_loader;
            if (!model_loader.init_from_file(file_path)) {
                LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
                return;
            }

            bool success = model_loader.load_tensors(tensors, backend);
            if (!success) {
                LOG_ERROR("load tensors from model loader failed");
                return;
            }

            LOG_INFO("mmdit model loaded");
        }

        mmdit->test();
    }
};

#endif  // __MMDIT_HPP__