Skip to content

Commit f10ba6c

Browse files
authored
Merge pull request #1287 from samplise/diagnosis-agent-observe
Refactor diagnosis agent
2 parents 2a1a3f5 + cf40a3c commit f10ba6c

File tree

9 files changed

+312
-96
lines changed

9 files changed

+312
-96
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2024 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from typing import List
15+
16+
17+
class DiagnoseAction:
18+
def __init__(self):
19+
self._actions: List[str] = []
20+
21+
def add_action(self, action: str):
22+
self._actions.append(action)

dlrover/python/diagnosis/common/inference_chain.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,20 @@ class InferenceName:
2020
END = "end"
2121
TRAINING = "training"
2222
NODE = "node"
23+
WORKER = "worker"
2324

2425

2526
class InferenceAttribute:
2627
ISORNOT = "is_or_not"
2728
IS = "is"
2829
NOT = "not"
30+
COLLECT = "collect"
2931

3032

3133
class InferenceDescription:
3234
HANG = "hang"
3335
FAILURE = "failure"
36+
METRICS = "metrics"
3437

3538

3639
@dataclass
@@ -92,12 +95,7 @@ def combine_inferences(
9295
) -> List[Inference]:
9396
inferences = []
9497
for inference2 in inferences2:
95-
is_duplicate = False
96-
for inference1 in inferences1:
97-
if is_same_inference(inference1, inference2):
98-
is_duplicate = True
99-
break
100-
if not is_duplicate:
98+
if not is_inference_included(inferences1, inference2):
10199
inferences.append(inference2)
102100

103101
for inference1 in inferences1:
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2024 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from typing import List
15+
16+
from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
17+
from dlrover.python.diagnosis.common.inference_chain import Inference
18+
19+
20+
def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction:
21+
return DiagnoseAction()
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 The DLRover Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from typing import List
15+
16+
from dlrover.python.common import env_utils
17+
from dlrover.python.diagnosis.common.constants import DiagnosisDataType
18+
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
19+
from dlrover.python.diagnosis.common.inference_chain import (
20+
Inference,
21+
InferenceAttribute,
22+
InferenceDescription,
23+
InferenceName,
24+
InferenceOperator,
25+
)
26+
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
27+
XpuTimerMetricsCollector,
28+
)
29+
from dlrover.python.elastic_agent.master_client import MasterClient
30+
31+
32+
class MetricsCollectionOperator(InferenceOperator):
33+
"""
34+
MetricsCollectionOperator is the operator to collect
35+
worker diagnosis metrics.
36+
"""
37+
38+
def __init__(self):
39+
super().__init__(None)
40+
self._xpu_timer_collector = XpuTimerMetricsCollector()
41+
self._client = MasterClient.singleton_instance()
42+
43+
def is_compatible(self, inference: Inference) -> bool:
44+
if (
45+
inference.name == InferenceName.WORKER
46+
and inference.attribution == InferenceAttribute.COLLECT
47+
and inference.description == InferenceDescription.METRICS
48+
):
49+
return True
50+
else:
51+
return False
52+
53+
def infer(self, inferences: List[Inference]) -> List[Inference]:
54+
xpu_timer_metric = self._xpu_timer_collector.collect_data()
55+
if xpu_timer_metric:
56+
agent_xpu_metric = WorkerTrainingMetric(
57+
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
58+
data_content=xpu_timer_metric,
59+
node_id=env_utils.get_node_id(),
60+
node_type=env_utils.get_node_type(),
61+
node_rank=env_utils.get_node_rank(),
62+
)
63+
self._client.report_diagnosis_agent_metrics(agent_xpu_metric)
64+
65+
return []

dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_failure_node_operator import ( # noqa: E501
1515
CheckFailureNodeOperator,
1616
)
17+
from dlrover.python.diagnosis.inferencechain.inferenceoperator.metrics_collection_operator import ( # noqa: E501
18+
MetricsCollectionOperator,
19+
)
1720

1821

1922
def get_training_failure_operators():
2023
return [CheckFailureNodeOperator()]
24+
25+
26+
def get_worker_observe_operators():
27+
return [MetricsCollectionOperator()]
28+
29+
30+
def get_worker_diagnosis_operators():
31+
return []

dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
import threading
1616
import time
1717
from datetime import datetime
18-
from typing import Dict
18+
from typing import Dict, List
1919

2020
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
2121

22-
from dlrover.python.common import env_utils
2322
from dlrover.python.common.constants import TrainingExceptionLevel
2423
from dlrover.python.common.error import ProcessError
2524
from dlrover.python.common.log import default_logger as logger
@@ -28,25 +27,28 @@
2827
from dlrover.python.diagnosis.common.constants import (
2928
DiagnosisAction,
3029
DiagnosisConstant,
31-
DiagnosisDataType,
3230
InferenceConfigKey,
3331
)
32+
from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
3433
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
3534
from dlrover.python.diagnosis.common.inference_chain import (
3635
Inference,
3736
InferenceAttribute,
3837
InferenceDescription,
3938
InferenceName,
39+
combine_inferences,
4040
is_inference_included,
4141
)
42-
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
43-
XpuTimerMetricsCollector,
42+
from dlrover.python.diagnosis.inferencechain.coordinator import (
43+
coordinate_inferences,
4444
)
4545
from dlrover.python.diagnosis.inferencechain.inference_chain import (
4646
InferenceChain,
4747
)
4848
from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( # noqa: E501
4949
get_training_failure_operators,
50+
get_worker_diagnosis_operators,
51+
get_worker_observe_operators,
5052
)
5153
from dlrover.python.elastic_agent.master_client import MasterClient
5254

@@ -56,8 +58,16 @@ def __init__(self, training_log_file: str, errors: str):
5658
self._client = MasterClient.singleton_instance()
5759
self._training_log_file = training_log_file
5860
self._errors = errors
59-
self._xpu_timer_metric_collector = XpuTimerMetricsCollector()
6061
self._stopped = False
62+
self._observe_problems: List[Inference] = [
63+
Inference(
64+
name=InferenceName.WORKER,
65+
attribution=InferenceAttribute.COLLECT,
66+
description=InferenceDescription.METRICS,
67+
),
68+
]
69+
self._observe_operators = get_worker_observe_operators()
70+
self._diagnosis_operators = get_worker_diagnosis_operators()
6171

6272
self.start()
6373

@@ -81,23 +91,43 @@ def start(self):
8191
def stop(self):
8292
self._stopped = True
8393

94+
def _observe(self) -> List[Inference]:
95+
observations: List[Inference] = []
96+
for problem in self._observe_problems:
97+
ic = InferenceChain([problem], self._observe_operators)
98+
try:
99+
infs = ic.infer()
100+
if len(infs) > 0:
101+
observations = combine_inferences(observations, infs)
102+
except Exception as e:
103+
logger.error(f"fail to observe problem {problem}: {e}")
104+
return observations
105+
106+
def _diagnose_observations(
107+
self, observations: List[Inference]
108+
) -> DiagnoseAction:
109+
conclusions: List[Inference] = []
110+
for ob in observations:
111+
ic = InferenceChain([ob], self._diagnosis_operators)
112+
try:
113+
infs = ic.infer()
114+
if len(infs) > 0:
115+
conclusions = combine_inferences(conclusions, infs)
116+
except Exception as e:
117+
logger.error(f"fail to diagnose observation {ob}: {e}")
118+
return coordinate_inferences(conclusions)
119+
84120
def _periodically_diagnosis(self):
85121
logger.info("Start periodically diagnosis...")
86122
while True:
87123
if self._stopped:
88124
logger.info("Stop periodically diagnosis.")
89125
break
90126

91-
xpu_timer_metric = self._xpu_timer_metric_collector.collect_data()
92-
if xpu_timer_metric:
93-
agent_xpu_metric = WorkerTrainingMetric(
94-
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
95-
data_content=xpu_timer_metric,
96-
node_id=env_utils.get_node_id(),
97-
node_type=env_utils.get_node_type(),
98-
node_rank=env_utils.get_node_rank(),
99-
)
100-
self._report_metric_to_master(agent_xpu_metric)
127+
observations = self._observe()
128+
if len(observations) > 0:
129+
logger.info(f"Observed problems: {observations}")
130+
self._diagnose_observations(observations)
101131

102132
time.sleep(
103133
DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS

dlrover/python/tests/test_diagnosis_agent.py

Lines changed: 3 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,15 @@
1313

1414
import os
1515
import unittest
16-
from unittest.mock import patch
1716

1817
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
1918
from torch.distributed.launcher.api import LaunchConfig
2019

2120
from dlrover.python.common import env_utils
22-
from dlrover.python.common.constants import NodeEnv, NodeType, RendezvousName
21+
from dlrover.python.common.constants import RendezvousName
2322
from dlrover.python.common.worker import WorkerContext
24-
from dlrover.python.diagnosis.common.constants import (
25-
DiagnosisAction,
26-
DiagnosisDataType,
27-
EnvConfigKey,
28-
)
23+
from dlrover.python.diagnosis.common.constants import DiagnosisAction
2924
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
30-
from dlrover.python.diagnosis.datacollector.training_log_collector import (
31-
TrainingLogCollector,
32-
)
33-
from dlrover.python.diagnosis.datacollector.xpu_timer_metric_collector import (
34-
XpuTimerMetricsCollector,
35-
)
3625
from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import (
3726
DiagnosisAgent,
3827
)
@@ -49,7 +38,7 @@
4938

5039
class TestDiagnosisAgent(unittest.TestCase):
5140
def setUp(self):
52-
self.master_proc, self.addr = start_local_master()
41+
self._master, self.addr = start_local_master()
5342
MasterClient._instance = build_master_client(self.addr, 1)
5443
launch_config = LaunchConfig(
5544
min_nodes=1,
@@ -109,63 +98,6 @@ def test_diagnose_training(self):
10998
action = agent.diagnose_training_failure(wc)
11099
self.assertEqual(action, DiagnosisAction.RESTART_WORKER)
111100

112-
@patch(
113-
"dlrover.python.diagnosis.datacollector.training_log_collector"
114-
".read_last_n_lines"
115-
)
116-
def test_log_collect(self, mock_file_util):
117-
mock_file_util.return_value = [
118-
"test0",
119-
"DLRover agent started with:",
120-
"test1",
121-
]
122-
training_log_collector = TrainingLogCollector(
123-
log_file="test", n_line=3
124-
)
125-
self.assertTrue(training_log_collector.is_enabled())
126-
result = training_log_collector.collect_data()
127-
self.assertTrue("test0" not in result.logs)
128-
self.assertTrue("test1" in result.logs)
129-
130-
def test_xpu_timer_metric_collect(self):
131-
collector = XpuTimerMetricsCollector()
132-
self.assertFalse(collector.is_enabled())
133-
134-
env_utils.set_env(EnvConfigKey.XPU_TIMER_PORT, 18889)
135-
collector = XpuTimerMetricsCollector()
136-
self.assertTrue(collector.is_enabled())
137-
138-
self.assertEqual(collector.collect_data(), "")
139-
140-
file = "data/xpu_timer_metrics"
141-
file_path = os.path.join(os.path.dirname(__file__), file)
142-
with open(file_path, "r", encoding="utf-8") as file:
143-
test_metrics = file.read()
144-
result = collector._preprocess_metrics(test_metrics)
145-
self.assertTrue(result)
146-
if "#" in result or "exposer" in result:
147-
self.fail()
148-
149-
env_utils.set_env(NodeEnv.NODE_ID, 1)
150-
env_utils.set_env(NodeEnv.NODE_TYPE, NodeType.WORKER)
151-
env_utils.set_env(NodeEnv.NODE_RANK, 1)
152-
agent_xpu_metric = WorkerTrainingMetric(
153-
data_type=DiagnosisDataType.XPU_TIMER_METRIC,
154-
data_content=result,
155-
node_id=env_utils.get_node_id(),
156-
node_type=env_utils.get_node_type(),
157-
node_rank=env_utils.get_node_rank(),
158-
)
159-
self.assertEqual(
160-
agent_xpu_metric.data_type,
161-
DiagnosisDataType.XPU_TIMER_METRIC,
162-
)
163-
self.assertEqual(agent_xpu_metric.data_content, result)
164-
self.assertEqual(agent_xpu_metric.node_id, 1)
165-
self.assertEqual(agent_xpu_metric.node_type, NodeType.WORKER)
166-
self.assertEqual(agent_xpu_metric.node_rank, 1)
167-
self.assertTrue(agent_xpu_metric.timestamp > 0)
168-
169101
def test_worker_training_metric(self):
170102
test = WorkerTrainingMetric(
171103
data_content="test123",

0 commit comments

Comments
 (0)