From 9c189adc9b11948ab8e605f75da17187ccf35f3a Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Thu, 5 Jun 2025 16:07:06 -0400
Subject: [PATCH 1/6] Update metrics.py - fix for ogbg pytorch

The problem affecting the PyTorch ogbg workloads (which appears only after they have run for some length of time) seems to be related to JAX/XLA CPU compilation of the metrics computation. Converting the JAX arrays to NumPy and computing the sigmoid with NumPy should hopefully avoid it. The next step is to test on Schedule-Free and Shampoo, which I hope to do very soon.
---
 algoperf/workloads/ogbg/metrics.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 55f83d905..c2db383b5 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -40,7 +40,7 @@ def compute(self):
 
     if USE_PYTORCH_DDP:
       # Sync labels, logits, and masks across devices.
-      all_values = [labels, logits, mask]
+      all_values = [np.array(labels), np.array(logits), np.array(mask)]
       for idx, array in enumerate(all_values):
         tensor = torch.as_tensor(array, device=DEVICE)
         # Assumes that the tensors on all devices have the same shape.
@@ -51,7 +51,7 @@ def compute(self):
 
     mask = mask.astype(bool)
 
-    probs = jax.nn.sigmoid(logits)
+    probs = 1 / (1 + np.exp(-logits))
     num_tasks = labels.shape[1]
     average_precisions = np.full(num_tasks, np.nan)
 

From fdc956bf62b80f84e634d6717ab9bc12aea33a9e Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Mon, 9 Jun 2025 10:27:45 -0400
Subject: [PATCH 2/6] Update metrics.py

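The problem with torchrun and JAX seems to be caused by jax.nn.sigmoid. Revert the np.array conversions and replace only the sigmoid with a NumPy implementation when USE_PYTORCH_DDP is set.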
---
 algoperf/workloads/ogbg/metrics.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index c2db383b5..982e1044e 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -37,10 +37,11 @@ def compute(self):
     labels = values['labels']
     logits = values['logits']
     mask = values['mask']
+    sigmoid = jax.nn.sigmoid
 
     if USE_PYTORCH_DDP:
       # Sync labels, logits, and masks across devices.
-      all_values = [np.array(labels), np.array(logits), np.array(mask)]
+      all_values = [labels, logits, mask]
       for idx, array in enumerate(all_values):
         tensor = torch.as_tensor(array, device=DEVICE)
         # Assumes that the tensors on all devices have the same shape.
@@ -48,10 +49,11 @@ def compute(self):
         dist.all_gather(all_tensors, tensor)
         all_values[idx] = torch.cat(all_tensors).cpu().numpy()
       labels, logits, mask = all_values
+      sigmoid = lambda x: 1 / (1 + np.exp(-x))
 
     mask = mask.astype(bool)
 
-    probs = 1 / (1 + np.exp(-logits))
+    probs = sigmoid(logits)
     num_tasks = labels.shape[1]
     average_precisions = np.full(num_tasks, np.nan)
 

From 6c888df9a365be98332575bae74295b23501a7ea Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Mon, 9 Jun 2025 10:43:29 -0400
Subject: [PATCH 3/6] Update metrics.py

Replace the lambda expression, which pylint flags, with a named function (sigmoid_np).
---
 algoperf/workloads/ogbg/metrics.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 982e1044e..5e1e7c4ef 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -31,6 +31,9 @@ class MeanAveragePrecision(
     metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
   """Computes the mean average precision (mAP) over different tasks."""
 
+  def sigmoid_np(x):
+    return 1 / (1 + np.exp(-x))
+
   def compute(self):
     # Matches the official OGB evaluation scheme for mean average precision.
     values = super().compute()
@@ -49,7 +52,7 @@ def compute(self):
         dist.all_gather(all_tensors, tensor)
         all_values[idx] = torch.cat(all_tensors).cpu().numpy()
       labels, logits, mask = all_values
-      sigmoid = lambda x: 1 / (1 + np.exp(-x))
+      sigmoid = sigmoid_np
 
     mask = mask.astype(bool)
 

From e4a55ab1db0a114a3a713c327ab334effcba9d53 Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Mon, 9 Jun 2025 17:19:11 -0400
Subject: [PATCH 4/6] Update metrics.py

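Define the NumPy sigmoid (sigmoid_np) inside the USE_PYTORCH_DDP branch of compute() instead of in the class body, where compute() cannot reference it by its bare name.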
---
 algoperf/workloads/ogbg/metrics.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 5e1e7c4ef..8f342c25d 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -31,9 +31,6 @@ class MeanAveragePrecision(
     metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
   """Computes the mean average precision (mAP) over different tasks."""
 
-  def sigmoid_np(x):
-    return 1 / (1 + np.exp(-x))
-
   def compute(self):
     # Matches the official OGB evaluation scheme for mean average precision.
     values = super().compute()
@@ -52,6 +49,8 @@ def compute(self):
         dist.all_gather(all_tensors, tensor)
         all_values[idx] = torch.cat(all_tensors).cpu().numpy()
       labels, logits, mask = all_values
+      def sigmoid_np(x):
+        return 1 / (1 + np.exp(-x))
       sigmoid = sigmoid_np
 
     mask = mask.astype(bool)

From 07f89a2b69d6f7667c83961e0fea4ae228681355 Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Mon, 9 Jun 2025 17:29:58 -0400
Subject: [PATCH 5/6] Update metrics.py

Add blank lines before and after the sigmoid_np definition.
---
 algoperf/workloads/ogbg/metrics.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 8f342c25d..19d43aae4 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -49,8 +49,10 @@ def compute(self):
         dist.all_gather(all_tensors, tensor)
         all_values[idx] = torch.cat(all_tensors).cpu().numpy()
       labels, logits, mask = all_values
+      
       def sigmoid_np(x):
         return 1 / (1 + np.exp(-x))
+        
       sigmoid = sigmoid_np
 
     mask = mask.astype(bool)

From 3e436c771f270be867a6ec973eb6c022a5c6ae69 Mon Sep 17 00:00:00 2001
From: David Tweedle <david.aig.tweedle@gmail.com>
Date: Mon, 9 Jun 2025 18:19:52 -0400
Subject: [PATCH 6/6] Update metrics.py

Remove trailing whitespace from the blank lines around sigmoid_np.
---
 algoperf/workloads/ogbg/metrics.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 19d43aae4..ea6041a6c 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -49,10 +49,10 @@ def compute(self):
         dist.all_gather(all_tensors, tensor)
         all_values[idx] = torch.cat(all_tensors).cpu().numpy()
       labels, logits, mask = all_values
-      
+
       def sigmoid_np(x):
         return 1 / (1 + np.exp(-x))
-        
+
       sigmoid = sigmoid_np
 
     mask = mask.astype(bool)