Skip to content

Commit aa44f9c

Browse files
committed
refactored
1 parent 831d619 commit aa44f9c

File tree

5 files changed

+126
-80
lines changed

5 files changed

+126
-80
lines changed

dlrover/python/common/grpc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,9 @@ class NodeCheckpointState(Message):
489489

490490

491491
@dataclass
492-
class WorkerDiagnosisData(Message):
493-
type: str = ""
494-
timestamp: int = 0
495-
content: str = ""
492+
class DiagnosisReportData(Message):
493+
data_cls: str = ""
494+
data_content: str = ""
496495
node_rank: int = -1
497496

498497

dlrover/python/diagnosis/common/diagnosis_data.py

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import json
1415
from abc import ABCMeta
1516
from datetime import datetime
1617
from typing import List
@@ -20,18 +21,36 @@
2021

2122

2223
class DiagnosisData(metaclass=ABCMeta):
24+
"""
25+
Basic definition of diagnosis data.
26+
27+
Args:
28+
timestamp (datetime): Timestamp of diagnosis data.
29+
data_type (str): Type of metric. Defaults to "GENERIC".
30+
data_content (str): Content of the metric. Defaults to "".
31+
node_id (int): Node ID. Defaults to -1.
32+
node_type (str): Node type. Defaults to "".
33+
node_rank (int): Node rank. Defaults to -1.
34+
"""
35+
2336
def __init__(
2437
self,
2538
timestamp: int = 0,
2639
data_type: str = DiagnosisDataType.GENERIC,
2740
data_content: str = "",
41+
node_id: int = -1,
42+
node_type: str = "",
43+
node_rank: int = -1,
2844
):
2945
if timestamp == 0:
3046
self._timestamp = int(round(datetime.now().timestamp()))
3147
else:
3248
self._timestamp = timestamp
3349
self._data_type = data_type
3450
self._data_content = data_content
51+
self._node_id = node_id
52+
self._node_type = node_type
53+
self._node_rank = node_rank
3554

3655
@property
3756
def data_type(self) -> str:
@@ -45,33 +64,6 @@ def timestamp(self) -> int:
4564
def data_content(self) -> str:
4665
return self._data_content
4766

48-
49-
class WorkerDiagnosisData(DiagnosisData):
50-
def __init__(
51-
self,
52-
timestamp: int = 0,
53-
data_type: str = DiagnosisDataType.GENERIC,
54-
data_content: str = "",
55-
node_id: int = -1,
56-
node_type: str = "",
57-
node_rank: int = -1,
58-
):
59-
"""
60-
General metric
61-
62-
Args:
63-
data_type (str): Type of metric. Defaults to "GENERIC".
64-
data_content (str): Content of the metric. Defaults to "".
65-
node_id (int): Node ID. Defaults to -1.
66-
node_type (str): Node type. Defaults to "".
67-
node_rank (int): Node rank. Defaults to -1.
68-
"""
69-
70-
super().__init__(timestamp, data_type, data_content)
71-
self._node_id = node_id
72-
self._node_type = node_type
73-
self._node_rank = node_rank
74-
7567
@property
7668
def node_id(self):
7769
return self._node_id
@@ -84,35 +76,47 @@ def node_type(self):
8476
def node_rank(self):
8577
return self._node_rank
8678

79+
def to_json(self):
80+
data = {k.lstrip("_"): v for k, v in self.__dict__.items()}
81+
return json.dumps(data)
82+
83+
@classmethod
84+
def from_json(cls, json_data):
85+
return cls(**json.loads(json_data))
86+
87+
def is_from_worker(self):
88+
return self._node_id != -1
89+
90+
91+
class WorkerTrainingMetric(DiagnosisData):
92+
"""
93+
Diagnosis data for worker training metric.
94+
95+
Args:
96+
timestamp (datetime): Timestamp of diagnosis data.
97+
data_type (str): Type of metric. Defaults to "GENERIC".
98+
data_content (str): Content of the metric. Defaults to "".
99+
node_id (int): Node ID. Defaults to -1.
100+
node_type (str): Node type. Defaults to "".
101+
node_rank (int): Node rank. Defaults to -1.
102+
is_final_result (bool, optional): Whether the metric is final result.
103+
Defaults to False.
104+
need_report (bool, optional): Whether the metric needs report.
105+
Defaults to False.
106+
"""
87107

88-
class WorkerTrainingMetric(WorkerDiagnosisData):
89108
def __init__(
90109
self,
91110
timestamp: int = 0,
92111
data_type: str = DiagnosisDataType.GENERIC,
93112
data_content: str = "",
94-
node_id: int = -1,
95-
node_type: str = "",
96-
node_rank: int = -1,
113+
node_id=env_utils.get_node_id(),
114+
node_type=env_utils.get_node_type(),
115+
node_rank=env_utils.get_node_rank(),
97116
is_final_result=False,
98117
need_report=False,
99118
):
100-
"""
101-
General metric
102-
103-
Args:
104-
data_type (str): Type of metric. Defaults to "GENERIC".
105-
data_content (str): Content of the metric. Defaults to "".
106-
is_final_result (bool, optional): Whether the metric is final
107-
result or not. Defaults to False.
108-
need_report (bool, optional): Whether the metric needs
109-
report(to Brain). Defaults to False.
110-
node_id (int): Node ID. Defaults to -1.
111-
node_type (str): Node type. Defaults to "".
112-
node_rank (int): Node rank. Defaults to -1.
113-
"""
114-
115-
super().__init__(
119+
super(WorkerTrainingMetric, self).__init__(
116120
timestamp, data_type, data_content, node_id, node_type, node_rank
117121
)
118122
self._is_final_result = is_final_result
@@ -133,8 +137,26 @@ def is_resolvable(self):
133137
return False
134138

135139

136-
class TrainingLog(WorkerDiagnosisData):
137-
def __init__(self, timestamp: int = 0, logs: List[str] = None):
140+
class TrainingLog(DiagnosisData):
141+
"""
142+
Worker's training log.
143+
144+
Args:
145+
timestamp (datetime): Timestamp of diagnosis data.
146+
logs (list): Log content in list format.
147+
node_id (int): Node ID. Defaults to -1.
148+
node_type (str): Node type. Defaults to "".
149+
node_rank (int): Node rank. Defaults to -1.
150+
"""
151+
152+
def __init__(
153+
self,
154+
timestamp: int = 0,
155+
logs: List[str] = None,
156+
node_id=env_utils.get_node_id(),
157+
node_type=env_utils.get_node_type(),
158+
node_rank=env_utils.get_node_rank(),
159+
):
138160
if logs is None:
139161
data_content = ""
140162
else:
@@ -144,9 +166,9 @@ def __init__(self, timestamp: int = 0, logs: List[str] = None):
144166
timestamp,
145167
DiagnosisDataType.TRAINING_LOG,
146168
data_content,
147-
env_utils.get_node_id(),
148-
env_utils.get_node_type(),
149-
env_utils.get_node_rank(),
169+
node_id,
170+
node_type,
171+
node_rank,
150172
)
151173

152174
@property

dlrover/python/elastic_agent/master_client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from dlrover.python.common.constants import NetworkFailureReason, NodeEnv
2424
from dlrover.python.common.log import default_logger as logger
2525
from dlrover.python.common.singleton import Singleton
26-
from dlrover.python.diagnosis.common.diagnosis_data import WorkerDiagnosisData
26+
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
2727

2828

2929
def retry_grpc_request(func):
@@ -382,11 +382,10 @@ def report_failures(self, error_data, restart_count=-1, level=""):
382382
def report_paral_config(self, config: grpc.ParallelConfig):
383383
self._report(config)
384384

385-
def report_diagnosis_agent_metrics(self, data: WorkerDiagnosisData):
386-
message = grpc.WorkerDiagnosisData(
387-
data.data_type,
388-
data.timestamp,
389-
data.data_content,
385+
def report_diagnosis_agent_metrics(self, data: DiagnosisData):
386+
message = grpc.DiagnosisReportData(
387+
data.__class__.__name__,
388+
data.to_json(),
390389
data.node_rank,
391390
)
392391
self._report(message)

dlrover/python/master/servicer.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import threading
1515
import time
1616
from concurrent import futures
17-
from typing import Dict, List
17+
from typing import Dict, List, Optional
1818

1919
import grpc as grpc_lib
2020

@@ -32,7 +32,7 @@
3232
)
3333
from dlrover.python.common.global_context import Context
3434
from dlrover.python.common.log import default_logger as logger
35-
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
35+
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
3636
from dlrover.python.master.diagnosis.diagnosis import DiagnosisManager
3737
from dlrover.python.master.elastic_training.kv_store_service import (
3838
KVStoreService,
@@ -354,10 +354,8 @@ def report(self, request, _):
354354
success = self._report_heartbeat(node_type, node_id, message)
355355
elif isinstance(message, grpc.NodeCheckpointState):
356356
success = self._sync_checkpoint(node_type, node_id, message)
357-
elif isinstance(message, grpc.WorkerDiagnosisData):
358-
success = self._report_worker_diagnosis_data(
359-
node_type, node_id, message
360-
)
357+
elif isinstance(message, grpc.DiagnosisReportData):
358+
success = self._report_worker_diagnosis_data(message)
361359

362360
response.success = success
363361
return response
@@ -613,19 +611,17 @@ def _sync_checkpoint(
613611
rdzv_manager = self._rdzv_managers[RendezvousName.ELASTIC_TRAINING]
614612
return rdzv_manager.sync_ckpt_nodes(node_id, message.step)
615613

616-
def _report_worker_diagnosis_data(
617-
self, node_type, node_id, message: grpc.WorkerDiagnosisData
618-
):
614+
def _report_worker_diagnosis_data(self, message: grpc.DiagnosisReportData):
619615
if self._diagnosis_manager:
620-
data = WorkerTrainingMetric(
621-
timestamp=message.timestamp,
622-
data_type=message.type,
623-
data_content=message.content,
624-
node_id=node_id,
625-
node_type=node_type,
626-
node_rank=message.node_rank,
627-
)
628-
self._diagnosis_manager.collect_diagnosis_data(data)
616+
data_cls: Optional[DiagnosisData] = globals().get(message.data_cls)
617+
if data_cls is None:
618+
logger.warning(
619+
"Invalid diagnosis report "
620+
f"data type: {message.data_cls}"
621+
)
622+
return False
623+
data_obj = data_cls.from_json(message.data_content)
624+
self._diagnosis_manager.collect_diagnosis_data(data_obj)
629625
return True
630626

631627
def _sync_training_ports(

dlrover/python/tests/test_diagnosis_agent.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,36 @@ def test_xpu_timer_metric_collect(self):
166166
self.assertEqual(agent_xpu_metric.node_rank, 1)
167167
self.assertTrue(agent_xpu_metric.timestamp > 0)
168168

169+
def test_worker_training_metric(self):
170+
test = WorkerTrainingMetric(
171+
data_content="test123",
172+
node_id=env_utils.get_node_id(),
173+
node_type=env_utils.get_node_type(),
174+
node_rank=env_utils.get_node_rank(),
175+
is_final_result=True,
176+
)
177+
178+
test_str = test.to_json()
179+
self.assertTrue('"data_content": "test123"' in test_str)
180+
181+
test_new = WorkerTrainingMetric.from_json(test_str)
182+
self.assertEqual(test_new.timestamp, test.timestamp)
183+
self.assertEqual(test_new.data_content, test.data_content)
184+
self.assertEqual(test_new.data_type, test.data_type)
185+
self.assertEqual(test_new.is_final_result, test.is_final_result)
186+
187+
test_new = globals().get("WorkerTrainingMetric").from_json(test_str)
188+
self.assertEqual(test_new.timestamp, test.timestamp)
189+
self.assertEqual(test_new.data_content, test.data_content)
190+
self.assertEqual(test_new.data_type, test.data_type)
191+
self.assertEqual(test_new.is_final_result, test.is_final_result)
192+
193+
test_new = globals().get(test.__class__.__name__).from_json(test_str)
194+
self.assertEqual(test_new.timestamp, test.timestamp)
195+
self.assertEqual(test_new.data_content, test.data_content)
196+
self.assertEqual(test_new.data_type, test.data_type)
197+
self.assertEqual(test_new.is_final_result, test.is_final_result)
198+
169199

170200
if __name__ == "__main__":
171201
unittest.main()

0 commit comments

Comments
 (0)