Skip to content

Commit 31283d2

Browse files
Implement Ernie Image model. (Comfy-Org#13369)
1 parent 55ebd28 commit 31283d2

8 files changed

Lines changed: 433 additions & 3 deletions

File tree

comfy/ldm/ernie/model.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
6+
from comfy.ldm.modules.attention import optimized_attention
7+
import comfy.model_management
8+
9+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
10+
assert dim % 2 == 0
11+
if not comfy.model_management.supports_fp64(pos.device):
12+
device = torch.device("cpu")
13+
else:
14+
device = pos.device
15+
16+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
17+
omega = 1.0 / (theta**scale)
18+
out = torch.einsum("...n,d->...nd", pos, omega)
19+
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
20+
return out.to(dtype=torch.float32, device=pos.device)
21+
22+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
23+
rot_dim = freqs_cis.shape[-1]
24+
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
25+
cos_ = freqs_cis[0]
26+
sin_ = freqs_cis[1]
27+
x1, x2 = x.chunk(2, dim=-1)
28+
x_rotated = torch.cat((-x2, x1), dim=-1)
29+
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
30+
31+
class ErnieImageEmbedND3(nn.Module):
32+
def __init__(self, dim: int, theta: int, axes_dim: tuple):
33+
super().__init__()
34+
self.dim = dim
35+
self.theta = theta
36+
self.axes_dim = list(axes_dim)
37+
38+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
39+
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
40+
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
41+
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
42+
43+
class ErnieImagePatchEmbedDynamic(nn.Module):
44+
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
45+
super().__init__()
46+
self.patch_size = patch_size
47+
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype)
48+
49+
def forward(self, x: torch.Tensor) -> torch.Tensor:
50+
x = self.proj(x)
51+
batch_size, dim, height, width = x.shape
52+
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
53+
54+
class Timesteps(nn.Module):
55+
def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
56+
super().__init__()
57+
self.num_channels = num_channels
58+
self.flip_sin_to_cos = flip_sin_to_cos
59+
60+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
61+
half_dim = self.num_channels // 2
62+
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
63+
emb = torch.exp(exponent)
64+
emb = timesteps[:, None].float() * emb[None, :]
65+
if self.flip_sin_to_cos:
66+
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
67+
else:
68+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
69+
return emb
70+
71+
class TimestepEmbedding(nn.Module):
72+
def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
73+
super().__init__()
74+
Linear = operations.Linear
75+
self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype)
76+
self.act = nn.SiLU()
77+
self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype)
78+
79+
def forward(self, sample: torch.Tensor) -> torch.Tensor:
80+
sample = self.linear_1(sample)
81+
sample = self.act(sample)
82+
sample = self.linear_2(sample)
83+
return sample
84+
85+
class ErnieImageAttention(nn.Module):
86+
def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
87+
super().__init__()
88+
self.heads = heads
89+
self.head_dim = dim_head
90+
self.inner_dim = heads * dim_head
91+
92+
Linear = operations.Linear
93+
RMSNorm = operations.RMSNorm
94+
95+
self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
96+
self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
97+
self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
98+
99+
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
100+
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
101+
102+
self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)])
103+
104+
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
105+
B, S, _ = x.shape
106+
107+
q_flat = self.to_q(x)
108+
k_flat = self.to_k(x)
109+
v_flat = self.to_v(x)
110+
111+
query = q_flat.view(B, S, self.heads, self.head_dim)
112+
key = k_flat.view(B, S, self.heads, self.head_dim)
113+
114+
query = self.norm_q(query)
115+
key = self.norm_k(key)
116+
117+
if image_rotary_emb is not None:
118+
query = apply_rotary_emb(query, image_rotary_emb)
119+
key = apply_rotary_emb(key, image_rotary_emb)
120+
121+
query, key = query.to(x.dtype), key.to(x.dtype)
122+
123+
q_flat = query.reshape(B, S, -1)
124+
k_flat = key.reshape(B, S, -1)
125+
126+
hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask)
127+
128+
return self.to_out[0](hidden_states)
129+
130+
class ErnieImageFeedForward(nn.Module):
131+
def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
132+
super().__init__()
133+
Linear = operations.Linear
134+
self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
135+
self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
136+
self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
137+
138+
def forward(self, x: torch.Tensor) -> torch.Tensor:
139+
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
140+
141+
class ErnieImageSharedAdaLNBlock(nn.Module):
142+
def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
143+
super().__init__()
144+
RMSNorm = operations.RMSNorm
145+
146+
self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
147+
self.self_attention = ErnieImageAttention(
148+
query_dim=hidden_size,
149+
dim_head=hidden_size // num_heads,
150+
heads=num_heads,
151+
eps=eps,
152+
operations=operations,
153+
device=device,
154+
dtype=dtype
155+
)
156+
self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
157+
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype)
158+
159+
def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
160+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
161+
162+
residual = x
163+
x_norm = self.adaLN_sa_ln(x)
164+
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
165+
166+
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
167+
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
168+
169+
residual = x
170+
x_norm = self.adaLN_mlp_ln(x)
171+
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
172+
173+
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
174+
175+
class ErnieImageAdaLNContinuous(nn.Module):
176+
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
177+
super().__init__()
178+
LayerNorm = operations.LayerNorm
179+
Linear = operations.Linear
180+
self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
181+
self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype)
182+
183+
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
184+
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
185+
x = self.norm(x)
186+
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
187+
return x
188+
189+
class ErnieImageModel(nn.Module):
190+
def __init__(
191+
self,
192+
hidden_size: int = 4096,
193+
num_attention_heads: int = 32,
194+
num_layers: int = 36,
195+
ffn_hidden_size: int = 12288,
196+
in_channels: int = 128,
197+
out_channels: int = 128,
198+
patch_size: int = 1,
199+
text_in_dim: int = 3072,
200+
rope_theta: int = 256,
201+
rope_axes_dim: tuple = (32, 48, 48),
202+
eps: float = 1e-6,
203+
qk_layernorm: bool = True,
204+
device=None,
205+
dtype=None,
206+
operations=None,
207+
**kwargs
208+
):
209+
super().__init__()
210+
self.dtype = dtype
211+
self.hidden_size = hidden_size
212+
self.num_heads = num_attention_heads
213+
self.head_dim = hidden_size // num_attention_heads
214+
self.patch_size = patch_size
215+
self.out_channels = out_channels
216+
217+
Linear = operations.Linear
218+
219+
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype)
220+
self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None
221+
222+
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
223+
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype)
224+
225+
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
226+
227+
self.adaLN_modulation = nn.Sequential(
228+
nn.SiLU(),
229+
Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype)
230+
)
231+
232+
self.layers = nn.ModuleList([
233+
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype)
234+
for _ in range(num_layers)
235+
])
236+
237+
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype)
238+
self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype)
239+
240+
def forward(self, x, timesteps, context, **kwargs):
241+
device, dtype = x.device, x.dtype
242+
B, C, H, W = x.shape
243+
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
244+
N_img = Hp * Wp
245+
246+
img_bsh = self.x_embedder(x)
247+
248+
text_bth = context
249+
if self.text_proj is not None and text_bth.numel() > 0:
250+
text_bth = self.text_proj(text_bth)
251+
Tmax = text_bth.shape[1]
252+
253+
hidden_states = torch.cat([img_bsh, text_bth], dim=1)
254+
255+
text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
256+
text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
257+
index = float(Tmax)
258+
259+
transformer_options = kwargs.get("transformer_options", {})
260+
rope_options = transformer_options.get("rope_options", None)
261+
262+
h_len, w_len = float(Hp), float(Wp)
263+
h_offset, w_offset = 0.0, 0.0
264+
265+
if rope_options is not None:
266+
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
267+
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
268+
index += rope_options.get("shift_t", 0.0)
269+
h_offset += rope_options.get("shift_y", 0.0)
270+
w_offset += rope_options.get("shift_x", 0.0)
271+
272+
image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
273+
image_ids[:, :, 0] = image_ids[:, :, 1] + index
274+
image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1)
275+
image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0)
276+
277+
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
278+
279+
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
280+
del image_ids, text_ids
281+
282+
sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype)
283+
c = self.time_embedding(sample)
284+
285+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
286+
t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
287+
]
288+
289+
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
290+
for layer in self.layers:
291+
hidden_states = layer(hidden_states, rotary_pos_emb, temb)
292+
293+
hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)
294+
295+
patches = self.final_linear(hidden_states)[:, :N_img, :]
296+
output = (
297+
patches.view(B, Hp, Wp, p, p, self.out_channels)
298+
.permute(0, 5, 1, 3, 2, 4)
299+
.contiguous()
300+
.view(B, self.out_channels, H, W)
301+
)
302+
303+
return output

comfy/model_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import comfy.ldm.anima.model
5454
import comfy.ldm.ace.ace_step15
5555
import comfy.ldm.rt_detr.rtdetr_v4
56+
import comfy.ldm.ernie.model
5657

5758
import comfy.model_management
5859
import comfy.patcher_extension
@@ -1962,3 +1963,14 @@ def concat_cond(self, **kwargs):
19621963
class RT_DETR_v4(BaseModel):
19631964
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
19641965
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
1966+
1967+
class ErnieImage(BaseModel):
1968+
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
1969+
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
1970+
1971+
def extra_conds(self, **kwargs):
1972+
out = super().extra_conds(**kwargs)
1973+
cross_attn = kwargs.get("cross_attn", None)
1974+
if cross_attn is not None:
1975+
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
1976+
return out

comfy/model_detection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
713713
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
714714
return dit_config
715715

716+
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
717+
dit_config = {}
718+
dit_config["image_model"] = "ernie"
719+
return dit_config
720+
716721
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
717722
return None
718723

comfy/sd.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import comfy.text_encoders.ace15
6363
import comfy.text_encoders.longcat_image
6464
import comfy.text_encoders.qwen35
65+
import comfy.text_encoders.ernie
6566

6667
import comfy.model_patcher
6768
import comfy.lora
@@ -1235,6 +1236,7 @@ class TEModel(Enum):
12351236
QWEN35_4B = 25
12361237
QWEN35_9B = 26
12371238
QWEN35_27B = 27
1239+
MINISTRAL_3_3B = 28
12381240

12391241

12401242
def detect_te_model(sd):
@@ -1301,6 +1303,8 @@ def detect_te_model(sd):
13011303
return TEModel.MISTRAL3_24B
13021304
else:
13031305
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
1306+
if weight.shape[0] == 3072:
1307+
return TEModel.MINISTRAL_3_3B
13041308

13051309
return TEModel.LLAMA3_8
13061310
return None
@@ -1458,6 +1462,10 @@ class EmptyClass:
14581462
elif te_model == TEModel.QWEN3_06B:
14591463
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
14601464
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
1465+
elif te_model == TEModel.MINISTRAL_3_3B:
1466+
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
1467+
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
1468+
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
14611469
else:
14621470
# clip_l
14631471
if clip_type == CLIPType.SD3:

0 commit comments

Comments
 (0)