Skip to content

Commit 892b8e4

Browse files
authored
Merge pull request #373 from Ipuch/solParameters
Displaying optimized parameters in sol.print()
2 parents 2738197 + 59f5dd2 commit 892b8e4

File tree

7 files changed

+70
-50
lines changed

7 files changed

+70
-50
lines changed

bioptim/dynamics/configure_problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def configure_tau(nlp, as_states: bool, as_controls: bool, fatigue: FatigueList
502502
)
503503

504504
for p, params in enumerate(fatigue_suffix):
505-
name = f"tau_{params}_{tau_suffix}"
505+
name = f"tau_{tau_suffix}_{params}"
506506
ConfigureProblem._adjust_mapping(name, ["q"], nlp)
507507
ConfigureProblem.configure_new_variable(name, name_tau, nlp, True, False, skip_plot=True)
508508
nlp.plot[f"{name}_controls"] = CustomPlot(

bioptim/dynamics/dynamics_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __get_fatigable_tau(nlp: NonLinearProgram, states: MX, controls: MX, fatigue
159159
for suffix in tau_suffix:
160160
model = getattr(t.model, suffix)
161161
tau_tp += (
162-
DynamicsFunctions.get(nlp.states[f"tau_{model.dynamics_suffix()}_{suffix}"], states)[i]
162+
DynamicsFunctions.get(nlp.states[f"tau_{suffix}_{model.dynamics_suffix()}"], states)[i]
163163
* model.scale
164164
)
165165
tau = vertcat(tau, tau_tp)

bioptim/dynamics/xia_fatigue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ def _dynamics_per_suffix(self, dxdt, suffix, nlp, index, states, controls):
121121
var = getattr(self, suffix)
122122
target_load = self._get_target_load(var, suffix, nlp, controls, index)
123123
fatigue = [
124-
DynamicsFunctions.get(nlp.states[f"tau_{dyn_suffix}_{suffix}"], states)[index, :]
124+
DynamicsFunctions.get(nlp.states[f"tau_{suffix}_{dyn_suffix}"], states)[index, :]
125125
for dyn_suffix in var.suffix()
126126
]
127127
current_dxdt = var.apply_dynamics(target_load, *fatigue)
128128

129129
for i, dyn_suffix in enumerate(var.suffix()):
130-
dxdt[nlp.states[f"tau_{dyn_suffix}_{suffix}"].index[index], :] = current_dxdt[i]
130+
dxdt[nlp.states[f"tau_{suffix}_{dyn_suffix}"].index[index], :] = current_dxdt[i]
131131

132132
return dxdt
133133

bioptim/gui/plot.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
bounds: Bounds = None,
5959
node_idx: list = None,
6060
label: list = None,
61-
manually_compute_derivative: bool = False,
61+
compute_derivative: bool = False,
6262
**parameters: Any,
6363
):
6464
"""
@@ -86,7 +86,7 @@ def __init__(
8686
The node time to be plotted on the graphs
8787
label: list
8888
Label of the curve to plot (to be added to the legend)
89-
manually_compute_derivative: bool
89+
compute_derivative: bool
9090
If the function should send the next node with x and u. Prevents from computing all at once (therefore a bit slower)
9191
"""
9292

@@ -108,7 +108,7 @@ def __init__(
108108
self.bounds = bounds
109109
self.node_idx = node_idx
110110
self.label = label
111-
self.manually_compute_derivative = manually_compute_derivative
111+
self.compute_derivative = compute_derivative
112112
self.parameters = parameters
113113

114114

@@ -614,21 +614,14 @@ def update_data(self, v: dict):
614614
else:
615615
control = np.concatenate((control, data_controls[s]))
616616

617-
if nlp.control_type == ControlType.CONSTANT:
618-
u_mod = 1
619-
elif nlp.control_type == ControlType.LINEAR_CONTINUOUS:
620-
u_mod = 2
621-
else:
622-
raise NotImplementedError(f"Plotting {nlp.control_type} is not implemented yet")
623-
624617
for key in self.variable_sizes[i]:
625618
if not self.plot_func[key][i]:
626619
continue
627-
# Automatically find u_modifier if the function is a casadi function otherwise fallback to default
628-
u_mod2 = (
629-
self.plot_func[key][i].function.size2_in(1)
630-
if hasattr(self.plot_func[key][i].function, "size2_in")
631-
else u_mod
620+
x_mod = 1 if self.plot_func[key][i].compute_derivative else 0
621+
u_mod = (
622+
1
623+
if nlp.control_type == ControlType.LINEAR_CONTINUOUS or self.plot_func[key][i].compute_derivative
624+
else 0
632625
)
633626

634627
if self.plot_func[key][i].type == PlotType.INTEGRATED:
@@ -637,17 +630,21 @@ def update_data(self, v: dict):
637630
y_tp = np.empty((self.variable_sizes[i][key], len(t)))
638631
y_tp.fill(np.nan)
639632

640-
mod = 1 if self.plot_func[key][i].manually_compute_derivative else 0
641633
val = self.plot_func[key][i].function(
642634
idx,
643-
state[:, step_size * idx : step_size * (idx + 1) + mod],
644-
control[:, idx : idx + u_mod2 + 1],
635+
state[:, step_size * idx : step_size * (idx + 1) + x_mod],
636+
control[:, idx : idx + u_mod + 1],
645637
data_params_in_dyn,
646638
**self.plot_func[key][i].parameters,
647639
)
640+
641+
if self.plot_func[key][i].compute_derivative:
642+
# This is a special case since derivative is not properly integrated
643+
val = np.repeat(val, y_tp.shape[1])[np.newaxis, :]
644+
648645
if val.shape != y_tp.shape:
649646
raise RuntimeError(
650-
f"Wrong dimensions for plot {key}. Got {val.shape}, but expected {y.shape}"
647+
f"Wrong dimensions for plot {key}. Got {val.shape}, but expected {y_tp.shape}"
651648
)
652649
y_tp[:, :] = val
653650
all_y.append(y_tp)
@@ -659,24 +656,25 @@ def update_data(self, v: dict):
659656
self.__append_to_ydata([y_tp])
660657

661658
elif self.plot_func[key][i].type == PlotType.POINT:
662-
y = np.empty((len(self.plot_func[key][i].node_idx),))
663-
y.fill(np.nan)
664-
mod = 1 if self.plot_func[key][i].manually_compute_derivative else 0
665-
for i_node, node_idx in enumerate(self.plot_func[key][i].node_idx):
666-
val = self.plot_func[key][i].function(
667-
node_idx,
668-
state[:, node_idx * step_size : (node_idx + 1) * step_size + mod : step_size],
669-
control[:, node_idx : node_idx + 1 + mod],
670-
data_params_in_dyn,
671-
**self.plot_func[key][i].parameters,
672-
)
673-
y[i_node] = val
674-
self.ydata.append(y)
659+
for i_var in range(self.variable_sizes[i][key]):
660+
y = np.empty((len(self.plot_func[key][i].node_idx),))
661+
y.fill(np.nan)
662+
mod = 1 if self.plot_func[key][i].compute_derivative else 0
663+
for i_node, node_idx in enumerate(self.plot_func[key][i].node_idx):
664+
val = self.plot_func[key][i].function(
665+
node_idx,
666+
state[:, node_idx * step_size : (node_idx + 1) * step_size + mod : step_size],
667+
control[:, node_idx : node_idx + 1 + mod],
668+
data_params_in_dyn,
669+
**self.plot_func[key][i].parameters,
670+
)
671+
y[i_node] = val[i_var]
672+
self.ydata.append(y)
675673

676674
else:
677675
y = np.empty((self.variable_sizes[i][key], len(self.t[i])))
678676
y.fill(np.nan)
679-
if self.plot_func[key][i].manually_compute_derivative:
677+
if self.plot_func[key][i].compute_derivative:
680678
for i_node, node_idx in enumerate(self.plot_func[key][i].node_idx):
681679
val = self.plot_func[key][i].function(
682680
node_idx,
@@ -687,8 +685,18 @@ def update_data(self, v: dict):
687685
)
688686
y[:, i_node] = val
689687
else:
688+
nodes = self.plot_func[key][i].node_idx
689+
if nodes and len(nodes) > 1 and len(nodes) == round(state.shape[1] / step_size):
690+
# Assume we are integrating but did not specify plot as such.
691+
# Therefore the arrival point is missing
692+
nodes += [nodes[-1] + 1]
693+
690694
val = self.plot_func[key][i].function(
691-
i, state[:, ::step_size], control, data_params_in_dyn, **self.plot_func[key][i].parameters
695+
nodes,
696+
state[:, ::step_size],
697+
control,
698+
data_params_in_dyn,
699+
**self.plot_func[key][i].parameters,
692700
)
693701
if val.shape != y.shape:
694702
raise RuntimeError(

bioptim/limits/penalty_option.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def __init__(
161161
self.weight = weight
162162
self.function: Union[Function, None] = None
163163
self.weighted_function: Union[Function, None] = None
164+
self.weighted_function_non_threaded: Union[Function, None] = None
164165
self.derivative = derivative
165166
self.explicit_derivative = explicit_derivative
166167
self.integrate = integrate
@@ -337,6 +338,7 @@ def _set_penalty_function(self, all_pn: Union[PenaltyNodeList, list, tuple], fcn
337338
self.weighted_function = Function(
338339
name, [state_cx, control_cx, param_cx, weight_cx, target_cx, dt_cx], [modified_fcn]
339340
)
341+
self.weighted_function_non_threaded = self.weighted_function
340342

341343
if ocp.n_threads > 1 and self.multi_thread and len(self.node_idx) > 1:
342344
self.function = self.function.map(len(self.node_idx), "thread", ocp.n_threads)
@@ -380,14 +382,21 @@ def _finish_add_target_to_plot(self, all_pn: PenaltyNodeList):
380382
381383
"""
382384

385+
def plot_function(t, x, u, p):
386+
if isinstance(t, (list, tuple)):
387+
return self.target_to_plot[:, [self.node_idx.index(_t) for _t in t]]
388+
else:
389+
return self.target_to_plot[:, self.node_idx.index(t)]
390+
383391
if self.target_to_plot is not None:
384-
if self.target_to_plot.shape[0] > 1:
392+
if self.target_to_plot.shape[1] > 1:
385393
plot_type = PlotType.STEP
386394
else:
387395
plot_type = PlotType.POINT
396+
388397
all_pn.ocp.add_plot(
389398
self.target_plot_name,
390-
lambda t, x, u, p: self.target_to_plot,
399+
plot_function,
391400
color="tab:red",
392401
plot_type=plot_type,
393402
phase=all_pn.nlp.phase_idx,

bioptim/optimization/optimal_control_program.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,11 @@ def penalty_color():
647647
penalties_internal = nlp.g_internal
648648

649649
for penalty in penalties:
650-
if penalty is None:
650+
if not penalty:
651651
continue
652652
name_unique_objective.append(penalty.name)
653653
for penalty_internal in penalties_internal:
654-
if penalty_internal is None:
654+
if not penalty_internal:
655655
continue
656656
name_unique_objective.append(penalty_internal.name)
657657
color = {}
@@ -671,21 +671,24 @@ def compute_penalty_values(t, x, u, p, penalty, dt):
671671
if dt.shape[0] > 1:
672672
dt = dt[penalty.phase]
673673

674-
_target = penalty.target[..., t] if penalty.target is not None and isinstance(t, int) else []
674+
_target = (
675+
penalty.target[..., penalty.node_idx.index(t)]
676+
if penalty.target is not None and isinstance(t, int)
677+
else []
678+
)
675679

676680
out = []
677681
if penalty.transition:
678682
raise NotImplementedError("add_plot_penalty with phase transition is not implemented yet")
679683
elif penalty.derivative or penalty.explicit_derivative:
680-
out.append(penalty.weighted_function(x, u, p, penalty.weight, _target, dt))
684+
out.append(penalty.weighted_function_non_threaded(x[:, [0, -1]], u, p, penalty.weight, _target, dt))
681685
else:
682-
_u = u if penalty.weighted_function.sparsity_in(1).shape[1] > 1 else u[:, :-1]
683-
out.append(penalty.weighted_function(x[:, :-1], _u, p, penalty.weight, _target, dt))
686+
out.append(penalty.weighted_function_non_threaded(x, u, p, penalty.weight, _target, dt))
684687
return sum1(horzcat(*out))
685688

686689
def add_penalty(_penalties):
687690
for penalty in _penalties:
688-
if penalty is None:
691+
if not penalty:
689692
continue
690693

691694
dt = penalty.dt
@@ -707,7 +710,7 @@ def add_penalty(_penalties):
707710
"dt": dt,
708711
"color": color[penalty.name],
709712
"label": penalty.name,
710-
"manually_compute_derivative": True,
713+
"compute_derivative": penalty.derivative or penalty.explicit_derivative or penalty.integrate,
711714
}
712715
if (
713716
isinstance(penalty.type, ObjectiveFcn.Mayer)

tests/test_global_fatigue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def test_fatigable_torque():
102102
# Check some of the results
103103
states, controls = sol.states, sol.controls
104104
q, qdot = states["q"], states["qdot"]
105-
ma_minus, mr_minus, mf_minus = states["tau_ma_minus"], states["tau_mr_minus"], states["tau_mf_minus"]
106-
ma_plus, mr_plus, mf_plus = states["tau_ma_plus"], states["tau_mr_plus"], states["tau_mf_plus"]
105+
ma_minus, mr_minus, mf_minus = states["tau_minus_ma"], states["tau_minus_mr"], states["tau_minus_mf"]
106+
ma_plus, mr_plus, mf_plus = states["tau_plus_ma"], states["tau_plus_mr"], states["tau_plus_mf"]
107107
tau_minus, tau_plus = controls["tau_minus"], controls["tau_plus"]
108108

109109
# initial and final position

0 commit comments

Comments
 (0)