-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdistributed.py
More file actions
127 lines (108 loc) · 5.35 KB
/
distributed.py
File metadata and controls
127 lines (108 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Distributed Execution Module for GNN
Provides Ray and Dask-based parallel dispatching for script execution and parameter sweeps.
Includes robust retry semantics for node failure in external cloud instances.
"""
import logging
from typing import Any, Callable, Dict, List, Literal, Optional
logger = logging.getLogger(__name__)
class Dispatcher:
    """
    Dispatcher for distributed parameter sweeps and script execution.

    Supports both Ray and Dask backends with retry semantics for transient
    task/node failures. When no cluster backend can be initialized, all
    public run methods fall back to sequential in-process execution, so
    callers always get a result list.
    """

    # Class-level logger so the dispatcher is self-contained;
    # logging.getLogger returns the same module logger instance.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        backend: Literal["ray", "dask"] = "ray",
        address: Optional[str] = None,
        num_cpus: Optional[int] = None,
        max_retries: int = 3,
    ):
        """Store connection settings; no cluster contact happens here.

        Args:
            backend: Execution backend, "ray" or "dask".
            address: Cluster address; None means start/attach locally.
            num_cpus: CPU hint for the backend; for Dask, None falls back
                to 4 local workers.
            max_retries: Per-task retry count on worker failure.
        """
        self.backend = backend
        self.address = address
        self.num_cpus = num_cpus
        self.max_retries = max_retries
        self._initialized = False
        self.client = None  # Dask Client; stays None for the Ray backend.

    @staticmethod
    def _invoke(fn: Callable, args: tuple, kwargs: Dict[str, Any]) -> Any:
        """Positional trampoline for Dask submission.

        Dask's Client.submit consumes keyword arguments such as
        ``retries``/``workers``/``priority`` itself; shipping user kwargs
        through this wrapper guarantees they reach ``fn`` untouched and
        can never collide with submit's own options.
        """
        return fn(*args, **kwargs)

    def connect_to_cluster(self) -> bool:
        """Connect to the configured cluster.

        Returns:
            True on success; False when the backend package is missing,
            initialization fails, or the backend name is unrecognized.
        """
        if self.backend == "ray":
            try:
                import ray
                if not ray.is_initialized():
                    ray.init(address=self.address, num_cpus=self.num_cpus, ignore_reinit_error=True)
                self._initialized = True
                self._log.info(f"Successfully connected to Ray cluster (Active Nodes: {len(ray.nodes())})")
                return True
            except ImportError:
                self._log.warning("Ray is not installed. Run: pip install ray")
                return False
            except Exception as e:
                self._log.error(f"Failed to initialize Ray: {e}")
                return False
        elif self.backend == "dask":
            try:
                from dask.distributed import Client, LocalCluster
                if self.address:
                    self.client = Client(self.address)
                else:
                    # No address: spin up an in-process local cluster.
                    cluster = LocalCluster(n_workers=self.num_cpus if self.num_cpus else 4)
                    self.client = Client(cluster)
                self._initialized = True
                self._log.info(f"Successfully connected to Dask cluster: {self.client}")
                return True
            except ImportError:
                self._log.warning("Dask is not installed. Run: pip install dask distributed")
                return False
            except Exception as e:
                self._log.error(f"Failed to initialize Dask: {e}")
                return False
        return False

    def shutdown(self):
        """Shut down the cluster connection, if one was established."""
        if self._initialized:
            try:
                if self.backend == "ray":
                    import ray
                    ray.shutdown()
                elif self.backend == "dask" and self.client:
                    self.client.close()
            except ImportError as e:
                self._log.debug("ray/dask not installed during shutdown: %s", e)
            finally:
                # Always clear the flag so a failed close does not leave
                # the dispatcher believing it is still connected.
                self._initialized = False

    def run_scripts_parallel(self, script_infos: List[Dict[str, Any]], execute_fn: Callable, **kwargs) -> List[Dict[str, Any]]:
        """
        Execute multiple scripts in parallel across workers with robust retries.

        Args:
            script_infos: One dict of script metadata per task.
            execute_fn: Callable invoked as ``execute_fn(info, **kwargs)``.
            **kwargs: Extra keyword arguments forwarded to ``execute_fn``.

        Returns:
            One result per entry of ``script_infos``, in order. Falls back
            to sequential execution if no cluster is available; returns []
            for an unrecognized backend.
        """
        if not self._initialized and not self.connect_to_cluster():
            self._log.warning("Falling back to sequential execution due to initialization failure.")
            return [execute_fn(info, **kwargs) for info in script_infos]
        self._log.info(f"Dispatching {len(script_infos)} scripts to {self.backend.capitalize()} cluster...")
        if self.backend == "ray":
            import ray

            # Remote wrapper with Ray-native retry semantics for node failure.
            @ray.remote(max_retries=self.max_retries, retry_exceptions=True)
            def _remote_execute(script_info, kwargs_dict):
                return execute_fn(script_info, **kwargs_dict)

            futures = [_remote_execute.remote(info, kwargs) for info in script_infos]
            return ray.get(futures)
        elif self.backend == "dask":
            # retries= is consumed by Client.submit itself; user kwargs travel
            # inside the _invoke trampoline so they cannot collide with it.
            futures = [
                self.client.submit(Dispatcher._invoke, execute_fn, (info,), kwargs, retries=self.max_retries)
                for info in script_infos
            ]
            return self.client.gather(futures)
        return []

    def parameter_sweep(self, model_fn: Callable, param_grid: List[Dict[str, Any]]) -> List[Any]:
        """
        Execute a parameter sweep with built-in retry semantics.

        Args:
            model_fn: Callable invoked as ``model_fn(**params)``.
            param_grid: One keyword-argument dict per evaluation.

        Returns:
            One result per parameter combination, in order. Falls back to
            sequential execution if no cluster is available; returns []
            for an unrecognized backend.
        """
        if not self._initialized and not self.connect_to_cluster():
            self._log.warning("Falling back to sequential parameter sweep.")
            return [model_fn(**params) for params in param_grid]
        self._log.info(f"Dispatching {len(param_grid)} parameter combinations for sweep using {self.backend.capitalize()}...")
        if self.backend == "ray":
            import ray

            # Remote wrapper with Ray-native retry semantics for node failure.
            @ray.remote(max_retries=self.max_retries, retry_exceptions=True)
            def _remote_eval(params):
                return model_fn(**params)

            futures = [_remote_eval.remote(p) for p in param_grid]
            return ray.get(futures)
        elif self.backend == "dask":
            # Trampoline keeps sweep params (e.g. a key named "retries")
            # from being swallowed by Client.submit's own options.
            futures = [
                self.client.submit(Dispatcher._invoke, model_fn, (), p, retries=self.max_retries)
                for p in param_grid
            ]
            return self.client.gather(futures)
        # Unknown backend: mirror run_scripts_parallel instead of returning None.
        return []
# Backward-compatibility alias: older callers imported RayDispatcher before
# the class gained Dask support and was renamed to the backend-neutral Dispatcher.
RayDispatcher = Dispatcher