Skip to content

Commit 87cbc68

Browse files
committed
not allowing naked erm be combined with fbopt
1 parent 3eed766 commit 87cbc68

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

domainlab/algos/trainers/fbopt_mu_controller.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def __init__(self, trainer, **kwargs):
4545
self.mu_min = trainer.aconf.mu_min
4646
self.mu_clip = trainer.aconf.mu_clip
4747

48+
if not kwargs:
49+
raise RuntimeError("feedback scheduler requires **kwargs, the set \
50+
of multipliers non-empty")
4851
self.mmu = kwargs
4952
# force initial value of mu
5053
self.mmu = {key: self.init_mu for key, val in self.mmu.items()}

tests/test_fbopt.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
unit and end-end test for deep all, mldg
33
"""
4+
import pytest
45
from tests.utils_test import utils_test_algo
56

67

@@ -27,13 +28,24 @@ def test_diva_fbopt():
2728
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3"
2829
utils_test_algo(args)
2930

31+
3032
def test_erm_fbopt():
3133
"""
3234
erm
3335
"""
3436
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3" # pylint: disable=line-too-long
37+
with pytest.raises(RuntimeError):
38+
utils_test_algo(args)
39+
40+
41+
def test_irm_fbopt():
42+
"""
43+
irm
44+
"""
45+
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3" # pylint: disable=line-too-long
3546
utils_test_algo(args)
3647

48+
3749
def test_forcesetpoint_fbopt():
3850
"""
3951
diva

0 commit comments

Comments
 (0)