Skip to content

Commit 4d38050

Browse files
committed
fix mindspore.ops.norm kw param
1 parent 877c003 commit 4d38050

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pygmtools/mindspore_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ def _ms_max(*args, keep_dims=False, **kwargs):
3131
'Mindspore version to GitHub issues.')
3232

3333
_norm_signature = inspect.signature(mindspore.ops.norm)
34-
if 'axis' in _norm_signature.parameters and 'p' in _norm_signature.parameters:
35-
def _ms_norm(*args, p=None, axis=None, **kwargs):
36-
return mindspore.ops.norm(*args, p=p, axis=axis, **kwargs)
37-
elif 'dim' in _norm_signature.parameters and 'ord' in _norm_signature.parameters:
38-
def _ms_norm(*args, p=None, axis=None, **kwargs):
39-
return mindspore.ops.norm(*args, ord=p, dim=axis, **kwargs)
34+
if 'axis' in _norm_signature.parameters and 'p' in _norm_signature.parameters and 'keep_dims' in _norm_signature.parameters:
35+
def _ms_norm(*args, p=None, axis=None, keep_dims=False, **kwargs):
36+
return mindspore.ops.norm(*args, p=p, axis=axis, keep_dims=keep_dims, **kwargs)
37+
elif 'dim' in _norm_signature.parameters and 'ord' in _norm_signature.parameters and 'keepdim' in _norm_signature.parameters:
38+
def _ms_norm(*args, p=None, axis=None, keepdim=False, **kwargs):
39+
return mindspore.ops.norm(*args, ord=p, dim=axis, keepdim=keep_dims, **kwargs)
4040
else:
4141
raise ValueError('Mindspore function mindspore.ops.norm has unsupported signature. It is likely you are '
4242
'working with a new Mindspore version which breaks backward compatibility. Please report your '

0 commit comments

Comments
 (0)