Skip to content

Commit cb2ce52

Browse files
committed
add support of old style "experiment_name" argument
1 parent a804240 commit cb2ce52

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/lightning/pytorch/loggers/comet.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class CometLogger(Logger):
6767
workspace="COMET_WORKSPACE", # Optional
6868
project="default_project", # Optional
6969
experiment_key="COMET_EXPERIMENT_KEY", # Optional
70-
experiment_name="lightning_logs", # Optional
70+
name="lightning_logs", # Optional
7171
)
7272
trainer = Trainer(logger=comet_logger)
7373
@@ -81,7 +81,7 @@ class CometLogger(Logger):
8181
comet_logger = CometLogger(
8282
workspace="COMET_WORKSPACE", # Optional
8383
project="default_project", # Optional
84-
experiment_name="lightning_logs", # Optional
84+
name="lightning_logs", # Optional
8585
online=False
8686
)
8787
trainer = Trainer(logger=comet_logger)
@@ -186,7 +186,7 @@ def __init__(self, *args, **kwarg):
186186
locally in an offline experiment. Default is ``True``.
187187
prefix (str, optional): The prefix to add to names of the logged metrics.
188188
example: prefix=`exp1`, then metric name will be `exp1_metric_name`
189-
**kwargs: Additional arguments like `experiment_name`, `log_code`, `offline_directory` etc. used by
189+
**kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by
190190
:class:`CometExperiment` can be passed as keyword arguments in this logger.
191191
192192
Raises:
@@ -215,7 +215,20 @@ def __init__(
215215
##################################################
216216
# HANDLE PASSED OLD TYPE PARAMS
217217

218-
# handle old "project name" param
218+
# handle old "experiment_name" param
219+
if "experiment_name" in kwargs:
220+
log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.")
221+
experiment_name = kwargs.pop("experiment_name")
222+
223+
if "name" not in kwargs:
224+
kwargs["name"] = experiment_name
225+
else:
226+
log.warning(
227+
"You specified both `experiment_name` and `name` parameters, "
228+
"please use `name` only"
229+
)
230+
231+
# handle old "project_name" param
219232
if "project_name" in kwargs:
220233
log.warning("The parameter `project_name` is deprecated, please use `project` instead.")
221234
if project is None:

tests/tests_pytorch/loggers/test_comet.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def test_comet_logger_experiment_name(comet_mock):
100100

101101
comet_start = comet_mock.start
102102

103+
# here we use old style arg "experiment_name" (new one is "name")
103104
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
104105
comet_start.assert_called_once_with(
105106
api_key=api_key,
@@ -110,11 +111,13 @@ def test_comet_logger_experiment_name(comet_mock):
110111
online=None,
111112
experiment_config=comet_mock.ExperimentConfig(),
112113
)
113-
# check that we saved "experiment name" in kwargs
114-
assert logger._kwargs["experiment_name"] == experiment_name
114+
# check that we saved "experiment name" in kwargs as new "name" arg
115+
assert logger._kwargs["name"] == experiment_name
116+
assert "experiment_name" not in logger._kwargs
115117

116-
# check that "experiment name" was passed to experiment config
117-
assert call(experiment_name=experiment_name) in comet_mock.ExperimentConfig.call_args_list
118+
# check that "experiment name" was passed to experiment config correctly
119+
assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list
120+
assert call(name=experiment_name) in comet_mock.ExperimentConfig.call_args_list
118121

119122

120123
@mock.patch.dict(os.environ, {})
@@ -123,7 +126,7 @@ def test_comet_version(comet_mock):
123126
api_key = "key"
124127
experiment_name = "My Name"
125128

126-
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
129+
logger = CometLogger(api_key=api_key, name=experiment_name)
127130
assert logger._experiment is not None
128131
_ = logger.version
129132

0 commit comments

Comments
 (0)