diff --git a/nc_time_axis/__init__.py b/nc_time_axis/__init__.py index a5276ad..f3ef15f 100644 --- a/nc_time_axis/__init__.py +++ b/nc_time_axis/__init__.py @@ -453,6 +453,11 @@ def has_year_zero(year): ticks = [t for t in ticks if t.year != 0] return cftime.date2num(ticks, self.date_unit, calendar=self.calendar) + def set_params(self, **kwargs): + self._max_n_locator_days.set_params(**kwargs) + self._max_n_locator.set_params(**kwargs) + return + class NetCDFTimeConverter(mdates.DateConverter): """ diff --git a/nc_time_axis/tests/integration/test_plot.py b/nc_time_axis/tests/integration/test_plot.py index 7698e91..911f9b1 100644 --- a/nc_time_axis/tests/integration/test_plot.py +++ b/nc_time_axis/tests/integration/test_plot.py @@ -78,6 +78,16 @@ def test_fill_between(self): plt.fill_between(cdt, temperatures, 0) + def test_locator_params(self): + times = cftime.num2date( + np.arange(30), nc_time_axis._TIME_UNITS, calendar="360_day" + ) + plt.plot(times, np.arange(30)) + plt.locator_params(axis="x", min_n_ticks=15) + plt.draw() + ticks = plt.xticks()[0] + self.assertFalse(ticks.size < 15) + def setup_function(function): plt.close() diff --git a/nc_time_axis/tests/unit/test_NetCDFTimeDateLocator.py b/nc_time_axis/tests/unit/test_NetCDFTimeDateLocator.py index b29b7de..af632fc 100644 --- a/nc_time_axis/tests/unit/test_NetCDFTimeDateLocator.py +++ b/nc_time_axis/tests/unit/test_NetCDFTimeDateLocator.py @@ -185,5 +185,50 @@ def test_NetCDFTimeDateLocator_date_unit_warning(): NetCDFTimeDateLocator(5, "360_day", "days since 2000-01-01") +class Test_set_params(unittest.TestCase): + def setUp(self): + # list of maxs to trigger different resolutions + self.vmax_list = [0.0003, 0.02, 1, 30, 365, 5000] + self.params = {"nbins": 10, "min_n_ticks": 4} + + def check(self, max_n_ticks, **kwargs): + # Create an instance of your class + locator = NetCDFTimeDateLocator( + max_n_ticks=max_n_ticks, calendar="gregorian" + ) + # Call the set_params method + locator.set_params(**kwargs) + return locator + + def test_set_params(self): + for key, value in self.params.items(): + for vmax in self.vmax_list: + locator = self.check(3, **{key: value}) + ticks = locator.tick_values(0, vmax) + # Assert that the expected values are set and returned + if key == "nbins": + # not more than max + 1 + self.assertFalse(ticks.size > value + 1) + # make sure number of ticks increased from initial + self.assertTrue(ticks.size > 4) + elif key == "min_n_ticks": + # not less than min + self.assertFalse(ticks.size < value) + else: + raise ValueError( + "tests on parameters other than nbins and min_n_ticks are not" + " yet implemented" + ) + + # Add more assertions to test the behavior of the method + # For example, you can assert that the internal state of the locator is updated correctly + self.assertEqual( + getattr(locator._max_n_locator_days, f"_{key}"), value + ) + self.assertEqual( + getattr(locator._max_n_locator, f"_{key}"), value + ) + + if __name__ == "__main__": unittest.main()