|
| 1 | +import math |
| 2 | +import torch |
| 3 | +import torch.nn as nn |
| 4 | +import torch.nn.functional as F |
| 5 | + |
| 6 | +from comfy.ldm.modules.attention import optimized_attention |
| 7 | +import comfy.model_management |
| 8 | + |
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build the cosine/sine rotary tables for one positional axis.

    Args:
        pos: Position values. The last dimension is the one expanded against
            the frequencies; any leading dimensions are kept as-is.
        dim: Number of channels assigned to this axis. Must be even;
            ``dim // 2`` frequencies are generated.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor on ``pos.device`` with shape
        ``(2, *pos.shape, dim // 2)``; index 0 holds the cosines and
        index 1 the sines.
    """
    assert dim % 2 == 0, "rope dim must be even"

    # Frequencies are computed in float64; devices that cannot do fp64 fall
    # back to the CPU for this computation.
    if comfy.model_management.supports_fp64(pos.device):
        device = pos.device
    else:
        device = torch.device("cpu")

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
    omega = 1.0 / (theta**scale)
    # pos has to sit on the same device as omega; without the move the CPU
    # fallback would mix devices inside einsum and fail.
    out = torch.einsum("...n,d->...nd", pos.to(device=device), omega)
    out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
    # Hand the table back on the caller's device in float32.
    return out.to(dtype=torch.float32, device=pos.device)
| 21 | + |
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the leading channels of ``x_in`` with a stacked cos/sin table.

    Args:
        x_in: Input tensor; its last dimension holds the channels.
        freqs_cis: Table whose first axis has the cosines at index 0 and the
            sines at index 1. Its last dimension decides how many leading
            channels of ``x_in`` are rotated.

    Returns:
        The rotated channels concatenated with the untouched trailing
        channels along the last dimension.
    """
    cos_table, sin_table = freqs_cis[0], freqs_cis[1]
    width = freqs_cis.shape[-1]

    rotating = x_in[..., :width]
    untouched = x_in[..., width:]

    # Rotate-half layout: (a, b) -> (-b, a)
    first_half, second_half = rotating.chunk(2, dim=-1)
    swapped = torch.cat((-second_half, first_half), dim=-1)

    mixed = rotating * cos_table + swapped * sin_table
    return torch.cat((mixed, untouched), dim=-1)
| 30 | + |
class ErnieImageEmbedND3(nn.Module):
    """Rotary position table built from three id axes.

    Each of the three id axes gets its own ``rope`` table sized by the
    matching entry of ``axes_dim``; the tables are joined on the last
    dimension.
    """

    def __init__(self, dim: int, theta: int, axes_dim: tuple):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the stacked cos/sin table for ``ids``.

        Args:
            ids: Position ids whose last dimension indexes the three axes
                (assumed ``[B, S, 3]`` - confirm against the caller).

        Returns:
            For ``[B, S, 3]`` ids, a tensor shaped
            ``[2, B, S, 1, sum(axes_dim)]`` where every frequency appears
            twice in a row along the last dimension.
        """
        per_axis = []
        for axis in range(3):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        # [2, B, S, sum(axes_dim) // 2] -> [2, B, S, 1, sum(axes_dim) // 2]
        table = torch.cat(per_axis, dim=-1).unsqueeze(3)

        # Duplicate each frequency next to itself: f0, f0, f1, f1, ...
        doubled = torch.stack([table, table], dim=-1)
        return doubled.reshape(*table.shape[:-1], -1)
| 42 | + |
class ErnieImagePatchEmbedDynamic(nn.Module):
    """Turn an image-like tensor into a token sequence with a strided conv.

    ``operations`` supplies the ``Conv2d`` class; kernel size and stride
    both equal ``patch_size``.
    """

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
        super().__init__()
        self.patch_size = patch_size
        self.proj = operations.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return contiguous tokens of shape ``[B, H' * W', embed_dim]``."""
        feature_map = self.proj(x)
        n, channels, rows, cols = feature_map.shape
        # Merge the spatial grid row-major, then move channels last.
        tokens = feature_map.reshape(n, channels, rows * cols)
        return tokens.transpose(1, 2).contiguous()
| 53 | + |
class Timesteps(nn.Module):
    """Sinusoidal embedding of scalar timesteps.

    Produces ``2 * (num_channels // 2)`` features per timestep using
    frequencies ``exp(-ln(10000) * i / half_dim)``.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed a 1-D tensor of timesteps.

        Returns:
            A float32 tensor ``[len(timesteps), 2 * half_dim]``; the sine
            half comes first unless ``flip_sin_to_cos`` is set, in which
            case the cosine half leads.
        """
        half_dim = self.num_channels // 2
        positions = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        freqs = torch.exp(-math.log(10000) * positions / half_dim)
        angles = timesteps[:, None].float() * freqs[None, :]

        sin_part, cos_part = torch.sin(angles), torch.cos(angles)
        ordered = (cos_part, sin_part) if self.flip_sin_to_cos else (sin_part, cos_part)
        return torch.cat(ordered, dim=-1)
| 70 | + |
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep embedding.

    ``operations`` supplies the ``Linear`` class used for both layers.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, bias=True, **factory)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, bias=True, **factory)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Map ``[..., in_channels]`` to ``[..., time_embed_dim]``."""
        return self.linear_2(self.act(self.linear_1(sample)))
| 84 | + |
class ErnieImageAttention(nn.Module):
    """Self-attention with per-head RMS-normalised queries and keys.

    Queries and keys are normalised over ``dim_head``, optionally rotated
    with a rotary table, and then handed to ``optimized_attention``
    together with the un-normalised values.
    """

    def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = heads * dim_head

        factory = {"device": device, "dtype": dtype}

        # Bias-free input projections.
        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)
        self.to_k = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)
        self.to_v = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)

        # Norms act on a single head's channels.
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, **factory)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, **factory)

        # Kept as a one-element ModuleList so the projection is addressed as to_out[0].
        self.to_out = nn.ModuleList([operations.Linear(self.inner_dim, query_dim, bias=False, **factory)])

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Input of shape ``[B, S, query_dim]``.
            attention_mask: Optional mask forwarded to ``optimized_attention``.
            image_rotary_emb: Optional cos/sin table applied to queries and keys.

        Returns:
            The attention output after the ``to_out[0]`` projection.
        """
        batch, seq_len, _ = x.shape
        per_head = (batch, seq_len, self.heads, self.head_dim)

        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)

        query = self.norm_q(query.view(per_head))
        key = self.norm_k(key.view(per_head))

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Cast back to the input dtype and merge heads into one channel axis again.
        query = query.to(x.dtype).reshape(batch, seq_len, -1)
        key = key.to(x.dtype).reshape(batch, seq_len, -1)

        attended = optimized_attention(query, key, value, self.heads, mask=attention_mask)
        return self.to_out[0](attended)
| 129 | + |
class ErnieImageFeedForward(nn.Module):
    """Gated feed-forward block: ``linear_fc2(up_proj(x) * gelu(gate_proj(x)))``.

    All three projections are bias-free and come from ``operations.Linear``.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.gate_proj = operations.Linear(hidden_size, ffn_hidden_size, bias=False, **factory)
        self.up_proj = operations.Linear(hidden_size, ffn_hidden_size, bias=False, **factory)
        self.linear_fc2 = operations.Linear(ffn_hidden_size, hidden_size, bias=False, **factory)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP; the output keeps the shape of ``x``."""
        value = self.up_proj(x)
        gate = F.gelu(self.gate_proj(x))
        return self.linear_fc2(value * gate)
| 140 | + |
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block whose shift/scale/gate values are supplied by the caller.

    The block owns two RMS norms, a self-attention module and a gated MLP.
    The modulation tensors are passed in through ``temb`` rather than
    computed here.
    """

    def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        self.adaLN_sa_ln = operations.RMSNorm(hidden_size, eps=eps, **factory)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            eps=eps,
            operations=operations,
            **factory,
        )
        self.adaLN_mlp_ln = operations.RMSNorm(hidden_size, eps=eps, **factory)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, **factory)

    @staticmethod
    def _modulate(normed, scale, shift, dtype):
        # The modulation is evaluated in float32 and cast back afterwards.
        return (normed.float() * (1 + scale.float()) + shift.float()).to(dtype)

    def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
        """Run the attention and MLP sub-layers with gated residuals.

        Args:
            x: Hidden states.
            rotary_pos_emb: Rotary table forwarded to the attention module.
            temb: Six tensors in the order shift/scale/gate for attention,
                then shift/scale/gate for the MLP.
            attention_mask: Optional mask forwarded to the attention module.

        Returns:
            Updated hidden states in the dtype of ``x``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb

        # Attention sub-layer.
        attn_in = self._modulate(self.adaLN_sa_ln(x), scale_msa, shift_msa, x.dtype)
        attn_out = self.self_attention(attn_in, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        x = x + (gate_msa.float() * attn_out.float()).to(x.dtype)

        # MLP sub-layer.
        mlp_in = self._modulate(self.adaLN_mlp_ln(x), scale_mlp, shift_mlp, x.dtype)
        mlp_out = self.mlp(mlp_in)
        return x + (gate_mlp.float() * mlp_out.float()).to(x.dtype)
| 174 | + |
class ErnieImageAdaLNContinuous(nn.Module):
    """LayerNorm without affine parameters, followed by a conditioning-driven scale and shift.

    A single linear layer maps the conditioning vector to ``2 * hidden_size``
    values; the first half is the scale and the second half the shift.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=eps, **factory)
        self.linear = operations.Linear(hidden_size, hidden_size * 2, **factory)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and modulate it with ``conditioning``.

        Args:
            x: Hidden states; dimension 1 is the one the modulation is
                broadcast over.
            conditioning: Conditioning vectors without that dimension.

        Returns:
            ``norm(x) * (1 + scale) + shift`` with scale/shift broadcast
            over dimension 1.
        """
        scale, shift = self.linear(conditioning).chunk(2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale[:, None]) + shift[:, None]
| 188 | + |
class ErnieImageModel(nn.Module):
    """Transformer over concatenated image and text tokens.

    Image patches and (optionally projected) text tokens are joined into a
    single sequence, image tokens first. Every layer receives the same six
    modulation tensors derived from the timestep embedding. Only the image
    tokens are projected back to an image-shaped output.

    ``qk_layernorm`` and any extra ``**kwargs`` are accepted by the
    constructor but not used in this block.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: tuple = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
        device=None,
        dtype=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.patch_size = patch_size
        self.out_channels = out_channels

        factory = {"device": device, "dtype": dtype}

        self.x_embedder = ErnieImagePatchEmbedDynamic(
            in_channels, hidden_size, patch_size, operations=operations, **factory
        )

        # Text tokens are only projected when their width differs from the model width.
        if text_in_dim != hidden_size:
            self.text_proj = operations.Linear(text_in_dim, hidden_size, bias=False, **factory)
        else:
            self.text_proj = None

        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations=operations, **factory)

        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)

        # One modulation head shared by all layers: six chunks of hidden_size each.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, **factory),
        )

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                ErnieImageSharedAdaLNBlock(
                    hidden_size=hidden_size,
                    num_heads=num_attention_heads,
                    ffn_hidden_size=ffn_hidden_size,
                    eps=eps,
                    operations=operations,
                    **factory,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps=eps, operations=operations, **factory)
        self.final_linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, **factory)

    @staticmethod
    def _position_ids(batch, text_len, grid_h, grid_w, device, rope_options):
        # Builds the [batch, grid_h * grid_w + text_len, 3] float32 ids fed to the rotary embedding.
        # Text tokens: axis 0 counts 0 .. text_len - 1, the other two axes stay zero.
        text_ids = torch.zeros((batch, text_len, 3), device=device, dtype=torch.float32)
        text_ids[:, :, 0] = torch.linspace(0, text_len - 1, steps=text_len, device=device, dtype=torch.float32)

        # Image tokens: axis 0 is one constant (text_len plus optional shift),
        # axes 1 and 2 run over the patch grid.
        t_index = float(text_len)
        span_h, span_w = float(grid_h), float(grid_w)
        start_h, start_w = 0.0, 0.0

        if rope_options is not None:
            span_h = (span_h - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            span_w = (span_w - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
            t_index += rope_options.get("shift_t", 0.0)
            start_h += rope_options.get("shift_y", 0.0)
            start_w += rope_options.get("shift_x", 0.0)

        rows = torch.linspace(start_h, span_h - 1 + start_h, steps=grid_h, device=device, dtype=torch.float32)
        cols = torch.linspace(start_w, span_w - 1 + start_w, steps=grid_w, device=device, dtype=torch.float32)

        image_ids = torch.zeros((grid_h, grid_w, 3), device=device, dtype=torch.float32)
        image_ids[:, :, 0] += t_index
        image_ids[:, :, 1] += rows.unsqueeze(1)
        image_ids[:, :, 2] += cols.unsqueeze(0)

        image_ids = image_ids.view(1, grid_h * grid_w, 3).expand(batch, -1, -1)
        return torch.cat([image_ids, text_ids], dim=1)

    def forward(self, x, timesteps, context, **kwargs):
        """Run the model on an image-shaped input.

        Args:
            x: Input of shape ``[B, C, H, W]``.
            timesteps: One timestep per batch element.
            context: Text tokens ``[B, T, text_in_dim]``.
            **kwargs: ``transformer_options["rope_options"]`` is read when
                present (keys ``scale_x``, ``scale_y``, ``shift_t``,
                ``shift_x``, ``shift_y``); everything else is ignored here.

        Returns:
            A tensor of shape ``[B, out_channels, H, W]``.
        """
        device, dtype = x.device, x.dtype
        batch, _, height, width = x.shape
        patch = self.patch_size
        grid_h, grid_w = height // patch, width // patch
        num_image_tokens = grid_h * grid_w

        image_tokens = self.x_embedder(x)

        text_tokens = context
        if self.text_proj is not None and text_tokens.numel() > 0:
            text_tokens = self.text_proj(text_tokens)
        text_len = text_tokens.shape[1]

        # Sequence layout: image tokens first, text tokens after.
        hidden_states = torch.cat([image_tokens, text_tokens], dim=1)

        rope_options = kwargs.get("transformer_options", {}).get("rope_options", None)
        ids = self._position_ids(batch, text_len, grid_h, grid_w, device, rope_options)
        rotary_pos_emb = self.pos_embed(ids).to(x.dtype)
        del ids

        # Timestep conditioning.
        sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype)
        c = self.time_embedding(sample)

        # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp - each [B, 1, hidden_size]
        temb = [part.unsqueeze(1).contiguous() for part in self.adaLN_modulation(c).chunk(6, dim=-1)]

        for layer in self.layers:
            hidden_states = layer(hidden_states, rotary_pos_emb, temb)

        hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)

        # Drop the text tokens and fold the patches back into an image.
        patches = self.final_linear(hidden_states)[:, :num_image_tokens, :]
        grid = patches.view(batch, grid_h, grid_w, patch, patch, self.out_channels)
        grid = grid.permute(0, 5, 1, 3, 2, 4).contiguous()
        return grid.view(batch, self.out_channels, height, width)
0 commit comments