@@ -58,7 +58,7 @@ def __init__(
5858 bounds : Bounds = None ,
5959 node_idx : list = None ,
6060 label : list = None ,
61- manually_compute_derivative : bool = False ,
61+ compute_derivative : bool = False ,
6262 ** parameters : Any ,
6363 ):
6464 """
@@ -86,7 +86,7 @@ def __init__(
8686 The node time to be plotted on the graphs
8787 label: list
8888 Label of the curve to plot (to be added to the legend)
89- manually_compute_derivative : bool
89+ compute_derivative : bool
9090 If the function should send the next node with x and u. Prevents from computing all at once (therefore a bit slower)
9191 """
9292
@@ -108,7 +108,7 @@ def __init__(
108108 self .bounds = bounds
109109 self .node_idx = node_idx
110110 self .label = label
111- self .manually_compute_derivative = manually_compute_derivative
111+ self .compute_derivative = compute_derivative
112112 self .parameters = parameters
113113
114114
@@ -614,21 +614,14 @@ def update_data(self, v: dict):
614614 else :
615615 control = np .concatenate ((control , data_controls [s ]))
616616
617- if nlp .control_type == ControlType .CONSTANT :
618- u_mod = 1
619- elif nlp .control_type == ControlType .LINEAR_CONTINUOUS :
620- u_mod = 2
621- else :
622- raise NotImplementedError (f"Plotting { nlp .control_type } is not implemented yet" )
623-
624617 for key in self .variable_sizes [i ]:
625618 if not self .plot_func [key ][i ]:
626619 continue
627- # Automatically find u_modifier if the function is a casadi function otherwise fallback to default
628- u_mod2 = (
629- self . plot_func [ key ][ i ]. function . size2_in ( 1 )
630- if hasattr ( self .plot_func [key ][i ].function , "size2_in" )
631- else u_mod
620+ x_mod = 1 if self . plot_func [ key ][ i ]. compute_derivative else 0
621+ u_mod = (
622+ 1
623+ if nlp . control_type == ControlType . LINEAR_CONTINUOUS or self .plot_func [key ][i ].compute_derivative
624+ else 0
632625 )
633626
634627 if self .plot_func [key ][i ].type == PlotType .INTEGRATED :
@@ -637,17 +630,21 @@ def update_data(self, v: dict):
637630 y_tp = np .empty ((self .variable_sizes [i ][key ], len (t )))
638631 y_tp .fill (np .nan )
639632
640- mod = 1 if self .plot_func [key ][i ].manually_compute_derivative else 0
641633 val = self .plot_func [key ][i ].function (
642634 idx ,
643- state [:, step_size * idx : step_size * (idx + 1 ) + mod ],
644- control [:, idx : idx + u_mod2 + 1 ],
635+ state [:, step_size * idx : step_size * (idx + 1 ) + x_mod ],
636+ control [:, idx : idx + u_mod + 1 ],
645637 data_params_in_dyn ,
646638 ** self .plot_func [key ][i ].parameters ,
647639 )
640+
641+ if self .plot_func [key ][i ].compute_derivative :
642+ # This is a special case since derivative is not properly integrated
643+ val = np .repeat (val , y_tp .shape [1 ])[np .newaxis , :]
644+
648645 if val .shape != y_tp .shape :
649646 raise RuntimeError (
650- f"Wrong dimensions for plot { key } . Got { val .shape } , but expected { y .shape } "
647+ f"Wrong dimensions for plot { key } . Got { val .shape } , but expected { y_tp .shape } "
651648 )
652649 y_tp [:, :] = val
653650 all_y .append (y_tp )
@@ -659,24 +656,25 @@ def update_data(self, v: dict):
659656 self .__append_to_ydata ([y_tp ])
660657
661658 elif self .plot_func [key ][i ].type == PlotType .POINT :
662- y = np .empty ((len (self .plot_func [key ][i ].node_idx ),))
663- y .fill (np .nan )
664- mod = 1 if self .plot_func [key ][i ].manually_compute_derivative else 0
665- for i_node , node_idx in enumerate (self .plot_func [key ][i ].node_idx ):
666- val = self .plot_func [key ][i ].function (
667- node_idx ,
668- state [:, node_idx * step_size : (node_idx + 1 ) * step_size + mod : step_size ],
669- control [:, node_idx : node_idx + 1 + mod ],
670- data_params_in_dyn ,
671- ** self .plot_func [key ][i ].parameters ,
672- )
673- y [i_node ] = val
674- self .ydata .append (y )
659+ for i_var in range (self .variable_sizes [i ][key ]):
660+ y = np .empty ((len (self .plot_func [key ][i ].node_idx ),))
661+ y .fill (np .nan )
662+ mod = 1 if self .plot_func [key ][i ].compute_derivative else 0
663+ for i_node , node_idx in enumerate (self .plot_func [key ][i ].node_idx ):
664+ val = self .plot_func [key ][i ].function (
665+ node_idx ,
666+ state [:, node_idx * step_size : (node_idx + 1 ) * step_size + mod : step_size ],
667+ control [:, node_idx : node_idx + 1 + mod ],
668+ data_params_in_dyn ,
669+ ** self .plot_func [key ][i ].parameters ,
670+ )
671+ y [i_node ] = val [i_var ]
672+ self .ydata .append (y )
675673
676674 else :
677675 y = np .empty ((self .variable_sizes [i ][key ], len (self .t [i ])))
678676 y .fill (np .nan )
679- if self .plot_func [key ][i ].manually_compute_derivative :
677+ if self .plot_func [key ][i ].compute_derivative :
680678 for i_node , node_idx in enumerate (self .plot_func [key ][i ].node_idx ):
681679 val = self .plot_func [key ][i ].function (
682680 node_idx ,
@@ -687,8 +685,18 @@ def update_data(self, v: dict):
687685 )
688686 y [:, i_node ] = val
689687 else :
688+ nodes = self .plot_func [key ][i ].node_idx
689+ if nodes and len (nodes ) > 1 and len (nodes ) == round (state .shape [1 ] / step_size ):
690+ # Assume we are integrating but did not specify plot as such.
691+ # Therefore the arrival point is missing
692+ nodes += [nodes [- 1 ] + 1 ]
693+
690694 val = self .plot_func [key ][i ].function (
691- i , state [:, ::step_size ], control , data_params_in_dyn , ** self .plot_func [key ][i ].parameters
695+ nodes ,
696+ state [:, ::step_size ],
697+ control ,
698+ data_params_in_dyn ,
699+ ** self .plot_func [key ][i ].parameters ,
692700 )
693701 if val .shape != y .shape :
694702 raise RuntimeError (
0 commit comments