1+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ # SPDX-License-Identifier: Apache-2.0
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+
116import math
217from typing import Optional
318
@@ -13,11 +28,7 @@ def scaled_dot_product_attention(
1328) -> tp .Tensor :
1429 dtype = query .dtype
1530 if attn_mask is not None and attn_mask .dtype == tp .bool :
16- attn_mask = tp .where (
17- (attn_mask == 0 ),
18- tp .ones_like (attn_mask , dtype = dtype ) * - float ("inf" ),
19- tp .zeros_like (attn_mask , dtype = dtype ),
20- )
31+ attn_mask = tp .where ((attn_mask == 0 ), tp .cast (tp .Tensor (- float ("inf" )), dtype = dtype ), 0.0 )
2132 if attn_mask is not None :
2233 attn_mask = tp .cast (attn_mask , dtype )
2334 k_t = tp .transpose (key , - 2 , - 1 )
@@ -26,4 +37,4 @@ def scaled_dot_product_attention(
2637
2738
def clamp(tensor: tp.Tensor, min: int, max: int):
    """Limit every element of ``tensor`` to the range ``[min, max]``.

    Args:
        tensor: Input tensor to clamp.
        min: Lower bound, applied first through ``tp.maximum``.
        max: Upper bound, applied second through ``tp.minimum``.

    Returns:
        The result of ``tp.minimum(tp.maximum(tensor, min), max)``. Because the
        upper bound is applied last, it wins if ``min`` is greater than ``max``.
    """
    # The bounds are handed over as plain scalars instead of being expanded
    # with ``tp.ones_like(tensor) * bound`` first.
    # NOTE(review): relies on tp.maximum / tp.minimum accepting Python scalars
    # as the second operand - confirm against the tp API in use.
    return tp.minimum(tp.maximum(tensor, min), max)
0 commit comments