Skip to content

ENH: Added Crop and Clip Methods to Function Class #817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Attention: The newest changes should be on top -->

### Added

- ENH: Added Crop and Clip Methods to Function Class [#817](https://github.com/RocketPy-Team/RocketPy/pull/817)
- ENH: Parallel mode for monte-carlo simulations 2 [#768](https://github.com/RocketPy-Team/RocketPy/pull/768)
- DOC: ASTRA Flight Example [#770](https://github.com/RocketPy-Team/RocketPy/pull/770)
- ENH: Add Eccentricity to Stochastic Simulations [#792](https://github.com/RocketPy-Team/RocketPy/pull/792)
Expand Down
185 changes: 179 additions & 6 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
self.__extrapolation__ = extrapolation
self.title = title
self.__img_dim__ = 1 # always 1, here for backwards compatibility
self.__interval__ = None # x interval for function if cropped

# args must be passed from self.
self.set_source(self.source)
Expand Down Expand Up @@ -627,8 +628,8 @@ def __get_value_opt_nd(self, *args):

def set_discrete(
self,
lower=0,
upper=10,
lower=None,
upper=None,
Comment on lines -630 to +632
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We gotta change the default value in the docstring as well. "Default is None" should be used.

samples=200,
interpolation="spline",
extrapolation="constant",
Expand Down Expand Up @@ -689,13 +690,33 @@ def set_discrete(
func = deepcopy(self) if not mutate_self else self

if func.__dom_dim__ == 1:
# Determine boundaries
domain = [0, 10]
if self.__interval__ is not None:
if self.__interval__[0] > domain[0]:
domain[0] = self.__interval__[0]
if self.__interval__[1] > domain[1]:
domain[1] = self.__interval__[1]
lower = domain[0] if lower is None else lower
upper = domain[1] if upper is None else upper
xs = np.linspace(lower, upper, samples)
ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs)
func.__interpolation__ = interpolation
func.__extrapolation__ = extrapolation
func.set_source(np.column_stack((xs, ys)))
elif func.__dom_dim__ == 2:
# Determine boundaries
domain = [[0, 10], [0, 10]]
if self.__interval__ is not None:
for i in range(0, 2):
if self.__interval__[i] is not None:
if self.__interval__[i][0] > domain[i][0]:
domain[i][0] = self.__interval__[i][0]
if self.__interval__[i][1] < domain[i][1]:
domain[i][1] = self.__interval__[i][1]
lower = [domain[0][0], domain[1][0]] if lower is None else lower
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what you did is not wrong, just be aware there's a different way of telling the same thing:

Suggested change
lower = [domain[0][0], domain[1][0]] if lower is None else lower
lower = lower or [domain[0][0], domain[1][0]]

lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower
upper = [domain[0][1], domain[1][1]] if upper is None else upper
upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper
sam = 2 * [samples] if isinstance(samples, NUMERICAL_TYPES) else samples
# Create nodes to evaluate function
Expand Down Expand Up @@ -897,6 +918,144 @@ def reset(

return self

def crop(self, x_lim):
"""This method allows the user to limit the input values of the Function
to a certain range and delete all set of input and output pairs outside
the specified range of values.

Parameters
----------
x_lim : list of values,
Range of values with lower and upper limits for input values to be
cropped within.

Returns
-------
self : Function

See also
--------
Function.clip

Examples
--------
>>> from rocketpy import Function
>>> f1 = Function(lambda x1, x2: np.sin(x1)*np.cos(x2), inputs=['x1', 'x2'], outputs='y')
>>> f1
Function from R2 to R1 : (x1, x2) → (y)
>>> f1.crop([(-1, 1), (-2, 2)])
>>> f1.plot()
>>> f2 = Function(lambda x1, x2: np.cos(x1)*np.sin(x2), inputs=['x1', 'x2'], outputs='y')
>>> f2
Function from R2 to R1 : (x1, x2) → (y)
>>> f2.crop([None, (-2, 2)])
Copy link
Preview

Copilot AI May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider updating the crop() docstring to clarify that passing None for a dimension means no cropping constraint on that axis, as illustrated by the example.

Copilot uses AI. Check for mistakes.

>>> f2.plot()
"""
if not isinstance(x_lim, list):
raise TypeError("x_lim must be a list of tuples.")
if len(x_lim) > self.__dom_dim__:
raise ValueError(
"x_lim must not exceed the length of the domain dimension."
)
if isinstance(self.source, np.ndarray):
if self.__dom_dim__ == 1:
self.source = self.source[
(self.source[:, 0] >= x_lim[0][0])
& (self.source[:, 0] <= x_lim[0][1])
]
elif self.__dom_dim__ == 2:
self.source = self.source[
(self.source[:, 0] >= x_lim[0][0])
& (self.source[:, 0] <= x_lim[0][1])
& (self.source[:, 1] >= x_lim[1][0])
& (self.source[:, 1] <= x_lim[1][1])
]
if self.__dom_dim__ == 1:
if x_lim[0][0] < x_lim[0][1]:
self.__interval__ = x_lim[0]
elif self.__dom_dim__ == 2:
if len(x_lim) != 0:
if x_lim[0] is not None and x_lim[0][0] < x_lim[0][1]:
self.__interval__ = [x_lim[0]]
else:
self.__interval__ = [None]
if (
len(x_lim) >= 2
and x_lim[1] is not None
and x_lim[1][0] < x_lim[1][1]
):
self.__interval__.append(x_lim[1])
else:
self.__interval__.append(None)
else:
raise IndexError("x_lim must be of index 2 for 2-D function")
self.set_source(self.source)

return self

def clip(self, y_lim):
"""This method allows the user to limit the output values of the Function
to a certain range and delete all set of input and output pairs outside
the specified range of values.

Parameters
----------
y_lim : list of values,
Range of values with lower and upper limits for output values to be
clipped within.

Returns
-------
self : Function

See also
--------
Function.crop

Examples
--------
>>> from rocketpy import Function
>>> f = Function(lambda x: x**2, inputs='x', outputs='y')
>>> f
Function from R1 to R1 : (x) → (y)
>>> f.clip([(-5, 5)])
"""
if not isinstance(y_lim, list):
raise TypeError("y_lim must be a list of tuples.")
if len(y_lim) != len(self.__outputs__):
raise ValueError(
"y_lim must have the same length as the output dimensions."
)

if isinstance(self.source, np.ndarray):
self.source = self.source[
(self.source[:, self.__dom_dim__] >= y_lim[0][0])
& (self.source[:, self.__dom_dim__] <= y_lim[0][1])
]
elif callable(self.source):
if isinstance(self.source, NUMERICAL_TYPES):
try:
if self.source < y_lim[0][0]:
raise ArithmeticError("Constant function outside range")
if self.source > y_lim[0][1]:
raise ArithmeticError("Constant function outside range")
except TypeError as e:
raise TypeError("y_lim must be same type as function source") from e
else:
f = self.source
self.source = lambda x: max(y_lim[0][0], min(y_lim[0][1], f(x)))
Comment on lines +1035 to +1046
Copy link
Preview

Copilot AI May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Review the constant-function logic in the clip() method; the branch checking for NUMERICAL_TYPES on a callable source may be ambiguous. Consider refactoring this logic to clearly distinguish constant functions from other callables.

Suggested change
elif callable(self.source):
if isinstance(self.source, NUMERICAL_TYPES):
try:
if self.source < y_lim[0][0]:
raise ArithmeticError("Constant function outside range")
if self.source > y_lim[0][1]:
raise ArithmeticError("Constant function outside range")
except TypeError as e:
raise TypeError("y_lim must be same type as function source") from e
else:
f = self.source
self.source = lambda x: max(y_lim[0][0], min(y_lim[0][1], f(x)))
elif isinstance(self.source, NUMERICAL_TYPES):
try:
if self.source < y_lim[0][0]:
raise ArithmeticError("Constant function outside range")
if self.source > y_lim[0][1]:
raise ArithmeticError("Constant function outside range")
except TypeError as e:
raise TypeError("y_lim must be the same type as the function source") from e
elif callable(self.source):
f = self.source
self.source = lambda x: max(y_lim[0][0], min(y_lim[0][1], f(x)))

Copilot uses AI. Check for mistakes.

try:
self.set_source(self.source)
except ValueError as e:
raise ValueError(
"Cannot clip function as function reduces to "
f"{len(self.source)} points (too few data points to define"
" a domain). Number of rows must be equal to number of "
"columns after applying clipping function."
) from e

return self

# Define all get methods
def get_inputs(self):
"Return tuple of inputs of the function."
Expand Down Expand Up @@ -1525,8 +1684,14 @@ def plot_1d( # pylint: disable=too-many-statements
ax = fig.axes
if self._source_type is SourceType.CALLABLE:
# Determine boundaries
lower = 0 if lower is None else lower
upper = 10 if upper is None else upper
domain = [0, 10]
if self.__interval__ is not None:
if self.__interval__[0] > domain[0]:
domain[0] = self.__interval__[0]
if self.__interval__[1] > domain[1]:
domain[1] = self.__interval__[1]
lower = domain[0] if lower is None else lower
upper = domain[1] if upper is None else upper
else:
# Determine boundaries
x_data = self.x_array
Expand Down Expand Up @@ -1637,9 +1802,17 @@ def plot_2d( # pylint: disable=too-many-statements
# Define a mesh and f values at mesh nodes for plotting
if self._source_type is SourceType.CALLABLE:
# Determine boundaries
lower = [0, 0] if lower is None else lower
domain = [[0, 10], [0, 10]]
if self.__interval__ is not None:
for i in range(0, 2):
if self.__interval__[i] is not None:
if self.__interval__[i][0] > domain[i][0]:
domain[i][0] = self.__interval__[i][0]
if self.__interval__[i][1] < domain[i][1]:
domain[i][1] = self.__interval__[i][1]
lower = [domain[0][0], domain[1][0]] if lower is None else lower
lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower
upper = [10, 10] if upper is None else upper
upper = [domain[0][1], domain[1][1]] if upper is None else upper
upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper
else:
# Determine boundaries
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,113 @@ def test_set_discrete_based_on_model_non_mutator(linear_func):
assert callable(func.source)


source_array = np.array(
[
[-2, -4, -6],
[-0.75, -1.5, -2.25],
[0, 0, 0],
[0, 1, 1],
[0.5, 1, 1.5],
[1.5, 1, 2.5],
[2, 4, 6],
]
)
cropped_array = np.array([[-0.75, -1.5, -2.25], [0, 0, 0], [0, 1, 1], [0.5, 1, 1.5]])
clipped_array = np.array([[0, 0, 0], [0, 1, 1], [0.5, 1, 1.5]])


@pytest.mark.parametrize(
"array3dsource, array3dcropped",
[
(source_array, cropped_array),
],
)
def test_crop_ndarray(array3dsource, array3dcropped): # pylint: disable=unused-argument
"""Tests the functionality of crop method of the Function class.
The source is initialized as a ndarray before cropping.
"""
func = Function(array3dsource, inputs=["x1", "x2"], outputs="y")
cropped_func = func.crop([(-1, 1), (-2, 2)]) # pylint: disable=unused-argument

assert isinstance(func, Function)
assert isinstance(cropped_func, Function)
assert np.array_equal(cropped_func.source, array3dcropped)
assert isinstance(cropped_func.source, type(func.source))


def test_crop_function():
"""Tests the functionality of crop method of the Function class.
The source is initialized as a function before cropping.
"""
func = Function(
lambda x1, x2: np.sin(x1) * np.cos(x2), inputs=["x1", "x2"], outputs="y"
)
cropped_func = func.crop([(-1, 1), (-2, 2)])

assert isinstance(func, Function)
assert isinstance(cropped_func, Function)
assert callable(func.source)
assert callable(cropped_func.source)


def test_crop_constant():
"""Tests the functionality of crop method of the Function class.
The source is initialized as a single integer constant before cropping.
"""
func = Function(13)
cropped_func = func.crop([(-1, 1)])

assert isinstance(func, Function)
assert isinstance(cropped_func, Function)
assert callable(func.source)
assert callable(cropped_func.source)


@pytest.mark.parametrize(
"array3dsource, array3dclipped",
[
(source_array, clipped_array),
],
)
def test_clip_ndarray(array3dsource, array3dclipped): # pylint: disable=unused-argument
"""Tests the functionality of clip method of the Function class.
The source is initialized as a ndarray before cropping.
"""
func = Function(array3dsource, inputs=["x1", "x2"], outputs="y")
clipped_func = func.clip([(-2, 2)]) # pylint: disable=unused-argument

assert isinstance(func, Function)
assert isinstance(clipped_func, Function)
assert np.array_equal(clipped_func.source, array3dclipped)
assert isinstance(clipped_func.source, type(func.source))


def test_clip_function():
"""Tests the functionality of clip method of the Function class.
The source is initialized as a function before clipping.
"""
func = Function(lambda x: x**2, inputs="x", outputs="y")
clipped_func = func.clip([(-1, 1)])

assert isinstance(func, Function)
assert isinstance(clipped_func, Function)
assert callable(func.source)
assert callable(clipped_func.source)


def test_clip_constant():
"""Tests the functionality of clip method of the Function class.
The source is initialized as a single integer constant before clipping.
"""
func = Function(1)
clipped_func = func.clip([(-2, 2)])

assert isinstance(func, Function)
assert isinstance(clipped_func, Function)
assert callable(func.source)
assert callable(clipped_func.source)


@pytest.mark.parametrize(
"x, y, expected_x, expected_y",
[
Expand Down
Loading