Skip to content

Commit fdfef8e

Browse files
authored
Heaviside hotfix (#87)
* revert to torch heaviside, as the derivative is not implemented * force the value at x=0 in the numpy implementation
1 parent 6ff459d commit fdfef8e

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

modulus/sym/utils/sympy/numpy_printer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def _max_np(x, axis=None):
181181
return return_value
182182

183183

184-
def _heaviside_np(x, x2=0.5):
184+
def _heaviside_np(x, x2=0):
185+
# force x2 to 0
186+
x2 = 0
185187
return np.heaviside(x, x2)
186188

187189

modulus/sym/utils/sympy/torch_printer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ def _where_torch(conditions, x, y):
7272
return torch.where(conditions, x, y)
7373

7474

75-
def _heaviside_torch(x, values=0.5):
76-
values = torch.tensor([values]).to(x.device)
77-
return torch.heaviside(x, values)
75+
def _heaviside_torch(x, values=0):
76+
return torch.maximum(torch.sign(x), torch.zeros(1, device=x.device))
7877

7978

8079
def _sqrt_torch(x):

0 commit comments

Comments
 (0)