
Commit 5142a42

Fix NaN bug in normalization for fp16
1 parent 9602126 commit 5142a42
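
The NaNs came from running normalization in fp16: every LayerNorm/GroupNorm is now constructed in float32, activations are cast to float32 going into the norm and back to the model dtype coming out, while the surrounding Linear/Conv layers keep the configured dtype. A minimal sketch of the pattern applied throughout the diff (illustrative only; run_norm_in_fp32 is a hypothetical helper name, not part of this commit):

import tripy as tp

def run_norm_in_fp32(norm, x):
    # norm was constructed with dtype=tp.float32; x may be fp16.
    # Upcast, normalize in fp32, then cast back to the original dtype.
    return tp.cast(norm(tp.cast(x, norm.dtype)), x.dtype)

# e.g. norm = tp.LayerNorm(embedding_size, dtype=tp.float32)
#      hidden_states = run_norm_in_fp32(norm, hidden_states)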

7 files changed: +84 −77 lines changed


tripy/examples/diffusion/clip_model.py

Lines changed: 9 additions & 8 deletions

@@ -29,7 +29,7 @@ class CLIPConfig:
     num_heads: int = 12
     max_seq_len: int = 77
     num_hidden_layers: int = 12
-    dtype: tp.dtype = tp.float16
+    dtype: tp.dtype = tp.float32

 class CLIPMLP(tp.Module):
     def __init__(self, config: CLIPConfig):
@@ -52,6 +52,7 @@ def __init__(self, config: CLIPConfig):
         self.v_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
         self.q_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
         self.out_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
+        self.dtype = config.dtype

     def __call__(self, hidden_states, causal_attention_mask):
         bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]
@@ -65,7 +66,7 @@ def __call__(self, hidden_states, causal_attention_mask):
             for x in (q, k, v)
         ]
         attn_output = scaled_dot_product_attention(
-            q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask
+            q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask, dtype=self.dtype,
         )
         out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim)))
         return out
@@ -74,18 +75,18 @@ def __call__(self, hidden_states, causal_attention_mask):
 class CLIPEncoderLayer(tp.Module):
     def __init__(self, config: CLIPConfig):
         self.self_attn = CLIPAttention(config)
-        self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
+        self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
         self.mlp = CLIPMLP(config)
-        self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
+        self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)

     def __call__(self, hidden_states, causal_attention_mask):
         residual = hidden_states
-        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = tp.cast(self.layer_norm1(tp.cast(hidden_states, self.layer_norm1.dtype)), hidden_states.dtype)
         hidden_states = self.self_attn(hidden_states, causal_attention_mask)
         hidden_states = residual + hidden_states

         residual = hidden_states
-        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = tp.cast(self.layer_norm2(tp.cast(hidden_states, self.layer_norm2.dtype)), hidden_states.dtype)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

@@ -115,10 +116,10 @@ class CLIPTextTransformer(tp.Module):
     def __init__(self, config: CLIPConfig):
         self.embeddings = CLIPTextEmbeddings(config)
         self.encoder = CLIPEncoder(config)
-        self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=config.dtype)
+        self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
         self.max_seq_len = config.max_seq_len

     def __call__(self, input_ids):
         x = self.embeddings(input_ids, tp.reshape(tp.iota((input_ids.shape[1],), dtype=tp.int32), (1, -1)))
         x = self.encoder(x, tp.triu(tp.full((1, 1, self.max_seq_len, self.max_seq_len), float("-inf")), 1))
-        return self.final_layer_norm(x)
+        return tp.cast(self.final_layer_norm(tp.cast(x, self.final_layer_norm.dtype)), x.dtype)

tripy/examples/diffusion/example.py

Lines changed: 32 additions & 26 deletions

@@ -52,7 +52,7 @@ def compile_clip(model, dtype=tp.int32, verbose=False):
     return compile_model(model, inputs, verbose=verbose)


-def compile_unet(model, dtype=tp.float16, verbose=False):
+def compile_unet(model, dtype, verbose=False):
     unconditional_context_shape = (1, 77, 768)
     conditional_context_shape = (1, 77, 768)
     latent_shape = (1, 4, 64, 64)
@@ -68,16 +68,16 @@ def compile_unet(model, dtype=tp.float16, verbose=False):
     return compile_model(model, inputs, verbose=verbose)


-def compile_vae(model, dtype=tp.float16, verbose=False):
+def compile_vae(model, dtype, verbose=False):
     inputs = (tp.InputInfo((1, 4, 64, 64), dtype=dtype),)
     return compile_model(model, inputs, verbose=verbose)


-def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance):
+def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype):
     timesteps = list(range(1, 1000, 1000 // steps))
-    print(f"[I] Running diffusion for {timesteps} timesteps...")
-    alphas = get_alphas_cumprod()[tp.Tensor(timesteps)]
-    alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
+    print(f"[I] Running diffusion for {steps} timesteps...")
+    alphas = get_alphas_cumprod(dtype=dtype)[tp.Tensor(timesteps)]
+    alphas_prev = tp.concatenate([tp.Tensor([1.0], dtype=dtype), alphas[:-1]], dim=0)

     for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
         t.set_description("idx: %1d, timestep: %3d" % (index, timestep))
@@ -86,32 +86,34 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance):
             unconditional_context,
             context,
             latent,
-            tp.cast(tp.Tensor([timestep]), tp.float32),
+            tp.Tensor([timestep], dtype=dtype),
             alphas[tid],
             alphas_prev[tid],
-            tp.Tensor([guidance]),
+            tp.Tensor([guidance], dtype=dtype),
         )
     return latent


 def tripy_diffusion(args):
     run_start_time = time.perf_counter()

-    if os.path.isdir("engines"):
+    dtype, torch_dtype = (tp.float16, torch.float16) if args.fp16 else (tp.float32, torch.float32)
+
+    if os.path.isdir(args.engine_dir):
         print("[I] Loading cached engines from disk...")
         clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
         unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
         vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
     else:
-        model = StableDiffusion(StableDiffusionConfig(dtype=tp.float16))
+        model = StableDiffusion(StableDiffusionConfig(dtype=dtype))
         print("[I] Loading model weights...", flush=True)
-        load_from_diffusers(model, tp.float16, debug=True)
+        load_from_diffusers(model, dtype, args.hf_token, debug=True)
         clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
-        unet_compiled = compile_unet(model, verbose=True)
-        vae_compiled = compile_vae(model.decode, verbose=True)
+        unet_compiled = compile_unet(model, dtype, verbose=True)
+        vae_compiled = compile_vae(model.decode, dtype, verbose=True)

-        os.mkdir("engines")
-        print("[I] Saving engines to disk...")
+        os.mkdir(args.engine_dir)
+        print(f"[I] Saving engines to {args.engine_dir}...")
         clip_compiled.save(os.path.join("engines", "clip_executable.json"))
         unet_compiled.save(os.path.join("engines", "unet_executable.json"))
         vae_compiled.save(os.path.join("engines", "vae_executable.json"))
@@ -135,11 +137,11 @@ def tripy_diffusion(args):
     # Backbone of diffusion - the UNet
     if args.seed is not None:
         torch.manual_seed(args.seed)
-    torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
+    torch_latent = torch.randn((1, 4, 64, 64), dtype=torch_dtype).to("cuda")
     latent = tp.Tensor(torch_latent)

     diffusion_run_start = time.perf_counter()
-    latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance)
+    latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype)
     diffusion_run_end = time.perf_counter()
     print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")

@@ -173,15 +175,17 @@ def hf_diffusion(args):

     run_start_time = time.perf_counter()

+    dtype = torch.float16 if args.fp16 else torch.float32
+    model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if args.fp16 else {}
+
     # Initialize models
-    model_id = "CompVis/stable-diffusion-v1-4" #"benjamin-paine/stable-diffusion-v1-5" #"runwayml/stable-diffusion-v1-5"
-    clip_id = "openai/clip-vit-large-patch14"
+    model_id = "KiwiXR/stable-diffusion-v1-5"

     print("[I] Loading models...")
-    hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id)
-    hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda")
-    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
-    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")
+    hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+    hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
+    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts).to("cuda")
+    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to("cuda")
     scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

     # Run through CLIP to get context from prompt
@@ -192,19 +196,20 @@ def hf_diffusion(args):
     uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
     text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
     uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
     clip_run_end = time.perf_counter()
     print(f"took {clip_run_end - clip_run_start} seconds.")

     # Backbone of diffusion - the UNet
     if args.seed is not None:
         torch.manual_seed(args.seed)
-    torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
+    torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype).to("cuda")
     torch_latent *= scheduler.init_noise_sigma

     scheduler.set_timesteps(args.steps)

     diffusion_run_start = time.perf_counter()
+    print(f"[I] Running diffusion for {args.steps} timesteps...")
     for t in tqdm(scheduler.timesteps):
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         latent_model_input = torch.cat([torch_latent] * 2)
@@ -267,7 +272,6 @@ def print_summary(denoising_steps, times):


 # TODO: Add torch compilation modes
-# TODO: Add fp16 support
 # TODO: Add Timing context
 def main():
     default_prompt = "a horse sized cat eating a bagel"
@@ -282,6 +286,8 @@ def main():
     parser.add_argument("--seed", type=int, help="Set the random latent seed")
     parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
     parser.add_argument('--torch-inference', action='store_true', help="Run inference with PyTorch (eager mode) instead of TensorRT.")
+    parser.add_argument('--hf-token', type=str, default='', help="HuggingFace API access token for downloading model checkpoints")
+    parser.add_argument('--engine-dir', type=str, default='engines', help="Output directory for TensorRT engines")
     args = parser.parse_args()

     if args.torch_inference:
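
With the new options, a typical invocation might look like the following (hypothetical command line; it assumes an --fp16 switch is defined elsewhere in main() to drive args.fp16, which this diff does not show):

# assumes --fp16 exists in the parser; --hf-token and --engine-dir are added by this commit
python3 example.py --steps 50 --seed 420 --fp16 --hf-token <your HF token> --engine-dir engines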

tripy/examples/diffusion/helper.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ def scaled_dot_product_attention(
     embedding_dim: Optional[int] = None,
     attn_mask: Optional[tp.Tensor] = None,
     is_causal: bool = False,
-    dtype: tp.dtype = tp.float16
+    dtype: tp.dtype = tp.float32
 ) -> tp.Tensor:
     """
     Computes scaled dot-product attention.

tripy/examples/diffusion/model.py

Lines changed: 3 additions & 3 deletions

@@ -33,7 +33,7 @@

 @dataclass
 class StableDiffusionConfig:
-    dtype: tp.dtype = tp.float16
+    dtype: tp.dtype = tp.float32
     clip_config: Optional[CLIPConfig] = field(default=None, init=False)
     unet_config: Optional[UNetConfig] = field(default=None, init=False)
     vae_config: Optional[VAEConfig] = field(default=None, init=False)
@@ -44,11 +44,11 @@ def __post_init__(self):
         self.vae_config = VAEConfig(dtype=self.dtype)

 # equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000, dtype=tp.float32):
     betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2
     alphas = 1.0 - betas
     alphas_cumprod = np.cumprod(alphas, axis=0)
-    return tp.Tensor(alphas_cumprod)
+    return tp.cast(tp.Tensor(alphas_cumprod), dtype)


 class StableDiffusion(tp.Module):
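
Note that the noise schedule is still accumulated in float32 NumPy and only cast to the requested dtype at the very end, so the cumulative product never runs in fp16. A standalone sketch of the same arithmetic (plain NumPy, shown only to illustrate the ordering; the float16 cast stands in for the tp.cast above):

import numpy as np

beta_start, beta_end, n_training_steps = 0.00085, 0.0120, 1000
betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)           # accumulate in float32
alphas_cumprod_fp16 = alphas_cumprod.astype(np.float16)    # cast once, at the end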

tripy/examples/diffusion/unet_model.py

Lines changed: 21 additions & 17 deletions

@@ -18,6 +18,7 @@
 import math
 from typing import List, Tuple

+import torch
 import tripy as tp
 from dataclasses import dataclass

@@ -33,28 +34,30 @@ class UNetConfig:
     num_heads: int = 8
     context_dim: int = 768
     emb_channels: int = 1280
-    dtype: tp.dtype = tp.float16
+    dtype: tp.dtype = tp.float32


 # Used for UNet, not to be confused with ResnetBlock, called ResnetBlock2D in HF diffusers
 class ResBlock(tp.Module):
     def __init__(self, config: UNetConfig, channels, emb_channels, out_channels):
-        self.norm1 = tp.GroupNorm(32, channels, dtype=config.dtype)
+        self.norm1 = tp.GroupNorm(32, channels, dtype=tp.float32)
         self.conv1 = tp.Conv(channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
         self.time_emb_proj = tp.Linear(emb_channels, out_channels, dtype=config.dtype)
-        self.norm2 = tp.GroupNorm(32, out_channels, dtype=config.dtype)
+        self.norm2 = tp.GroupNorm(32, out_channels, dtype=tp.float32)
         self.conv2 = tp.Conv(out_channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
         self.nonlinearity = tp.silu
         self.conv_shortcut = tp.Conv(channels, out_channels, (1, 1), dtype=config.dtype) if channels != out_channels else lambda x: x

     def __call__(self, x, emb):
-        h = self.conv1(self.nonlinearity(self.norm1(x)))
+        h = tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype)
+        h = self.conv1(self.nonlinearity(h))
         emb_out = self.time_emb_proj(self.nonlinearity(emb))
         target_shape = emb_out.shape + (1, 1)
         # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape
         target_shape.trace_tensor.shape = (emb_out.rank + 2,)
         h = h + tp.reshape(emb_out, target_shape)
-        h = self.conv2(self.nonlinearity(self.norm2(h)))
+        h = tp.cast(self.norm2(tp.cast(h, self.norm2.dtype)), h.dtype)
+        h = self.conv2(self.nonlinearity(h))
         ret = self.conv_shortcut(x) + h
         return ret

@@ -67,14 +70,15 @@ def __init__(self, config: UNetConfig, query_dim, context_dim, n_heads, d_head):
         self.num_heads = n_heads
         self.head_size = d_head
         self.to_out = [tp.Linear(n_heads * d_head, query_dim, dtype=config.dtype)]
+        self.dtype = config.dtype

     def __call__(self, x, context=None):
         context = x if context is None else context
         q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
         q, k, v = [
             tp.transpose(tp.reshape(y, (x.shape[0], -1, self.num_heads, self.head_size)), 1, 2) for y in (q, k, v)
         ]
-        attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size), 1, 2)
+        attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size, dtype=self.dtype), 1, 2)
         h_ = tp.reshape(attention, (x.shape[0], -1, self.num_heads * self.head_size))
         out = sequential(h_, self.to_out)
         return out
@@ -116,20 +120,20 @@ def __init__(self, config, dim, context_dim, n_heads, d_head):
         self.attn1 = CrossAttention(config, dim, dim, n_heads, d_head)
         self.ff = FeedForward(config, dim)
         self.attn2 = CrossAttention(config, dim, context_dim, n_heads, d_head)
-        self.norm1 = tp.LayerNorm(dim, dtype=config.dtype)
-        self.norm2 = tp.LayerNorm(dim, dtype=config.dtype)
-        self.norm3 = tp.LayerNorm(dim, dtype=config.dtype)
+        self.norm1 = tp.LayerNorm(dim, dtype=tp.float32)
+        self.norm2 = tp.LayerNorm(dim, dtype=tp.float32)
+        self.norm3 = tp.LayerNorm(dim, dtype=tp.float32)

     def __call__(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
+        x = self.attn1(tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype)) + x
+        x = self.attn2(tp.cast(self.norm2(tp.cast(x, self.norm2.dtype)), x.dtype), context=context) + x
+        x = self.ff(tp.cast(self.norm3(tp.cast(x, self.norm3.dtype)), x.dtype)) + x
         return x


 class SpatialTransformer(tp.Module): # Transformer2dModel in HF diffusers
     def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head):
-        self.norm = tp.GroupNorm(32, channels, dtype=config.dtype)
+        self.norm = tp.GroupNorm(32, channels, dtype=tp.float32)
         assert channels == n_heads * d_head
         self.proj_in = tp.Conv(channels, n_heads * d_head, (1, 1), dtype=config.dtype)
         self.transformer_blocks = [BasicTransformerBlock(config, channels, context_dim, n_heads, d_head)]
@@ -138,7 +142,7 @@ def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head):
     def __call__(self, x, context=None):
         b, c, h, w = x.shape
         x_in = x
-        x = self.norm(x)
+        x = tp.cast(self.norm(tp.cast(x, self.norm.dtype)), x.dtype)
         x = self.proj_in(x)
         x = tp.permute(tp.reshape(x, (b, c, h * w)), (0, 2, 1))
         for block in self.transformer_blocks:
@@ -272,15 +276,14 @@ def __init__(self, config: UNetConfig):
             CrossAttnUpBlock2D(config, up_channels[2:5], down_channels[2]),
             CrossAttnUpBlock2D(config, up_channels[4:7], down_channels[1], use_upsampler=False),
         ]
-        self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=config.dtype)
+        self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=tp.float32)
         self.conv_act = tp.silu
         self.conv_out = tp.Conv(config.model_channels, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)

     def __call__(self, x, timesteps=None, context=None):
         # TODO: real time embedding
         t_emb = timestep_embedding(timesteps, self.config.model_channels, self.config.dtype)
         emb = self.time_embedding(t_emb)
-
         x = self.conv_in(x)
         saved_inputs = [x]

@@ -301,6 +304,7 @@ def __call__(self, x, timesteps=None, context=None):
             else:
                 x = block(x, emb, context, partial_inputs)

-        act = self.conv_out(self.conv_act(self.conv_norm_out(x)))
+        act = tp.cast(self.conv_norm_out(tp.cast(x, self.conv_norm_out.dtype)), x.dtype)
+        act = self.conv_out(self.conv_act(act))
         return act
