# Copyright 2025 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import torch from torch import nn class SinusPositionEmbedding(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x, scale=1000): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb class TimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = SinusPositionEmbedding(freq_embed_dim) self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) def forward(self, timestep): # noqa: F821 time_hidden = self.time_embed(timestep) time_hidden = time_hidden.to(timestep.dtype) time = self.time_mlp(time_hidden) # b d return time