Skip to content

Commit 79f6da5

Browse files
committed
small cleanup
1 parent d0fde97 commit 79f6da5

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

tripy/examples/diffusion/example.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@
3030
from examples.diffusion.model import StableDiffusion, StableDiffusionConfig
3131
from examples.diffusion.weight_loader import load_from_diffusers
3232

33-
import nvtx
34-
35-
tp.logger.verbosity = "ir"
36-
3733
batch = tp.NamedDimension("batch", 1, 1, 1)
3834
max_seq_len = tp.NamedDimension("max_seq_len", 77, 77, 77)
3935
embed_dim = tp.NamedDimension("embed_dim", 768, 768, 768)
@@ -159,9 +155,6 @@ def tripy_diffusion(args):
159155
unet_compiled.save(os.path.join(args.engine_dir, "unet_executable.tpymodel"))
160156
vae_compiled.save(os.path.join(args.engine_dir, "vae_executable.tpymodel"))
161157

162-
pr = nvtx.Profile()
163-
pr.enable()
164-
165158
# Run through CLIP to get context from prompt
166159
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
167160
torch_prompt = tokenizer(

tripy/examples/diffusion/helper.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");Add commentMore actions
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import math
217
from typing import Optional
318

@@ -13,11 +28,7 @@ def scaled_dot_product_attention(
1328
) -> tp.Tensor:
1429
dtype = query.dtype
1530
if attn_mask is not None and attn_mask.dtype == tp.bool:
16-
attn_mask = tp.where(
17-
(attn_mask == 0),
18-
tp.ones_like(attn_mask, dtype=dtype) * -float("inf"),
19-
tp.zeros_like(attn_mask, dtype=dtype),
20-
)
31+
attn_mask = tp.where((attn_mask == 0), tp.cast(tp.Tensor(-float("inf")), dtype=dtype), 0.0)
2132
if attn_mask is not None:
2233
attn_mask = tp.cast(attn_mask, dtype)
2334
k_t = tp.transpose(key, -2, -1)
@@ -26,4 +37,4 @@ def scaled_dot_product_attention(
2637

2738

2839
def clamp(tensor: tp.Tensor, min: int, max: int):
29-
return tp.minimum(tp.maximum(tensor, tp.ones_like(tensor) * min), tp.ones_like(tensor) * max)
40+
return tp.minimum(tp.maximum(tensor, min), max)

0 commit comments

Comments
 (0)