Skip to content

Commit d0c0eb7

Browse files
committed
fixing mindspore.norm and paddle-openssl compatibility
1 parent 8557cc9 commit d0c0eb7

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

.github/workflows/python-package.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ jobs:
3030
python -m pip install flake8 pytest-cov
3131
if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
3232
if [ "${{ matrix.python-version }}" != "3.10" ]; then pip install mindspore==2.0.0; fi
33+
wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
34+
dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
3335
- name: Generate astar.so
3436
run: |
3537
python pygmtools/c_astar_src/build_c_astar.py
@@ -46,7 +48,7 @@ jobs:
4648
pytest --cov=pygmtools --cov-report=xml --cov-append
4749
- name: Upload to codecov
4850
uses: codecov/codecov-action@v3
49-
if: matrix.python-version == 3.8
51+
if: matrix.python-version == 3.9
5052

5153
macos:
5254

pygmtools/mindspore_backend.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ def _ms_max(*args, keep_dims=False, **kwargs):
3030
'working with a new Mindspore version which breaks backward compatibility. Please report your '
3131
'Mindspore version to GitHub issues.')
3232

33+
_norm_signature = inspect.signature(mindspore.ops.norm)
34+
if 'axis' in _norm_signature.parameters:
35+
def _ms_norm(*args, axis=None, **kwargs):
36+
return mindspore.ops.norm(*args, axis=axis, **kwargs)
37+
elif 'dim' in _norm_signature.parameters:
38+
def _ms_norm(*args, axis=None, **kwargs):
39+
return mindspore.ops.norm(*args, dim=axis, **kwargs)
40+
else:
41+
raise ValueError('Mindspore function mindspore.ops.norm has unsupported signature. It is likely you are '
42+
'working with a new Mindspore version which breaks backward compatibility. Please report your '
43+
'Mindspore version to GitHub issues.')
44+
45+
3346
#############################################
3447
# Linear Assignment Problem Solvers #
3548
#############################################
@@ -261,15 +274,15 @@ def rrwm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
261274
# random walk
262275
v = mindspore.ops.BatchMatMul()(K, v)
263276
last_v = v
264-
n = mindspore.ops.norm(v, axis=1, p=1, keep_dims=True)
277+
n = _ms_norm(v, axis=1, p=1, keep_dims=True)
265278
v = v / n
266279

267280
# reweighted jump
268281
s = v.view(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
269282
s = beta * s / s.max(axis=1, keepdims=True).max(axis=2, keepdims=True)
270283
v = alpha * sinkhorn(s, n1, n2, max_iter=sk_iter, batched_operation=True).swapaxes(1, 2).reshape(batch_num, n1n2, 1) + \
271284
(1 - alpha) * v
272-
n = mindspore.ops.norm(v, axis=1, p=1, keep_dims=True)
285+
n = _ms_norm(v, axis=1, p=1, keep_dims=True)
273286
v = mindspore.ops.matmul(v, 1 / n)
274287

275288
if (v - last_v).sum().sqrt() < 1e-5:
@@ -287,7 +300,7 @@ def sm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max, n
287300
v = vlast = v0
288301
for i in range(max_iter):
289302
v = mindspore.ops.BatchMatMul()(K, v)
290-
n = mindspore.ops.norm(v, axis=1, p=2)
303+
n = _ms_norm(v, axis=1, p=2)
291304
v = mindspore.ops.matmul(v, (1 / n).view(batch_num, 1, 1))
292305
if (v - vlast).sum().sqrt() < 1e-5:
293306
break

0 commit comments

Comments
 (0)