@@ -467,3 +467,171 @@ def step_fn(verts, feats):
            outputs.append(y_t)

        return torch.stack(outputs, dim=0)  # [T, N, 3]


class FIGConvUNetOneStepRollout(FIGConvUNet):
    """
    FIGConvUNet with one-step rollout for crash simulation.

    - Training: teacher forcing (uses GT positions at each step)
    - Inference: autoregressive (uses predictions)
    """

    def __init__(self, *args, **kwargs):
        self.dt: float = kwargs.pop("dt", 5e-3)
        self.initial_vel: torch.Tensor = kwargs.pop("initial_vel")
        self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
        super().__init__(*args, **kwargs)

    def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
        """
        Args:
            sample: SimSample containing node_features and node_target
            data_stats: dict containing normalization stats
        Returns:
            [T, N, 3] rollout of predicted positions
        """
        inputs = sample.node_features
        x0 = inputs["coords"]  # initial pos [N, 3]
        features = inputs.get("features", x0.new_zeros((x0.size(0), 0)))  # [N, F]

        # Ground truth sequence [T, N, 3]
        N = x0.size(0)
        gt_seq = torch.cat(
            [x0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
            dim=0,
        )

        outputs: list[torch.Tensor] = []
        # First step: backstep to create y_{-1}
        y_t0 = gt_seq[0] - self.initial_vel * self.dt
        y_t1 = gt_seq[0]

        for t in range(self.rollout_steps):
            # In training mode (except first step), use ground truth positions
            if self.training and t > 0:
                y_t0, y_t1 = gt_seq[t - 1], gt_seq[t]

            # Prepare vertices for FIGConvUNet: [1, N, 3]
            vertices = y_t1.unsqueeze(0)  # [1, N, 3]

            vel = (y_t1 - y_t0) / self.dt
            vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
                data_stats["node"]["norm_vel_std"] + EPS
            )

            # [1, N, 3 + F]
            fx_t = torch.cat([vel_norm, features], dim=-1).unsqueeze(0)

            def step_fn(verts, feats):
                out, _ = super(FIGConvUNetOneStepRollout, self).forward(
                    vertices=verts, features=feats
                )
                return out

            # Gradient checkpointing during training to reduce activation memory
            if self.training:
                outf = ckpt(
                    step_fn,
                    vertices,
                    fx_t,
                    use_reentrant=False,
                ).squeeze(0)  # [N, 3]
            else:
                outf = step_fn(vertices, fx_t).squeeze(0)  # [N, 3]

            # De-normalize the predicted acceleration
            acc = (
                outf * data_stats["node"]["norm_acc_std"]
                + data_stats["node"]["norm_acc_mean"]
            )
            # Semi-implicit Euler update: acceleration -> velocity -> position
            vel_pred = self.dt * acc + vel
            y_t2_pred = self.dt * vel_pred + y_t1

            outputs.append(y_t2_pred)

            if not self.training:
                # autoregressive update for inference
                y_t0, y_t1 = y_t1, y_t2_pred

        return torch.stack(outputs, dim=0)  # [T, N, 3]


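# Usage sketch (illustrative only, not a definitive API): a single teacher-forced
# optimization step with the one-step-rollout wrapper. `model`, `optimizer`,
# `sample`, and the exact layout of `data_stats` are assumptions; the class only
# requires that data_stats["node"] holds the four normalization tensors used
# above and that sample.node_target reshapes to [N, (num_time_steps - 1) * 3].
def _example_one_step_train(model, optimizer, sample, data_stats):
    model.train()                                     # teacher-forced rollout
    pred = model(sample, data_stats)                  # [rollout_steps, N, 3]
    n = sample.node_features["coords"].size(0)
    target = sample.node_target.view(n, -1, 3).transpose(0, 1)  # [rollout_steps, N, 3]
    loss = torch.nn.functional.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()

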
class FIGConvUNetAutoregressiveRolloutTraining(FIGConvUNet):
    """
    FIGConvUNet with autoregressive rollout training for crash simulation.

    Predicts the sequence by autoregressively updating velocity and position
    using predicted accelerations. Supports gradient checkpointing during training.
    """

    def __init__(self, *args, **kwargs):
        self.dt: float = kwargs.pop("dt")
        self.initial_vel: torch.Tensor = kwargs.pop("initial_vel")
        self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
        super().__init__(*args, **kwargs)

    def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
        """
        Args:
            sample: SimSample containing node_features and node_target
            data_stats: dict containing normalization stats
        Returns:
            [T, N, 3] rollout of predicted positions
        """
        inputs = sample.node_features
        coords = inputs["coords"]  # [N, 3]
        features = inputs.get("features", coords.new_zeros((coords.size(0), 0)))
        N = coords.size(0)
        device = coords.device

        # Initial states
        y_t1 = coords  # [N, 3]
        y_t0 = y_t1 - self.initial_vel * self.dt  # backstep using initial velocity

        outputs: list[torch.Tensor] = []
        for t in range(self.rollout_steps):
            # Normalized time in [0, 1], used as an extra scalar feature
            time_t = 0.0 if self.rollout_steps <= 1 else t / (self.rollout_steps - 1)
            time_t = torch.tensor([time_t], device=device, dtype=torch.float32)

            # Velocity normalization
            vel = (y_t1 - y_t0) / self.dt
            vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
                data_stats["node"]["norm_vel_std"] + EPS
            )

            # Prepare vertices for FIGConvUNet: [1, N, 3]
            vertices = y_t1.unsqueeze(0)  # [1, N, 3]

            # Prepare features: vel_norm + features + time [N, 3+F+1]
            fx_t = torch.cat(
                [vel_norm, features, time_t.expand(N, 1)], dim=-1
            )  # [N, 3+F+1]
            fx_t = fx_t.unsqueeze(0)  # [1, N, 3+F+1]

            def step_fn(verts, feats):
                out, _ = super(FIGConvUNetAutoregressiveRolloutTraining, self).forward(
                    vertices=verts, features=feats
                )
                return out

            if self.training:
                outf = ckpt(
                    step_fn,
                    vertices,
                    fx_t,
                    use_reentrant=False,
                ).squeeze(0)  # [N, 3]
            else:
                outf = step_fn(vertices, fx_t).squeeze(0)  # [N, 3]

            # De-normalize acceleration
            acc = (
                outf * data_stats["node"]["norm_acc_std"]
                + data_stats["node"]["norm_acc_mean"]
            )
            # Integrate acceleration -> velocity -> position (autoregressive update)
            vel = self.dt * acc + vel
            y_t2 = self.dt * vel + y_t1

            outputs.append(y_t2)
            y_t1, y_t0 = y_t2, y_t1

        return torch.stack(outputs, dim=0)  # [T, N, 3]
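

# Usage sketch (illustrative only): fully autoregressive inference with the
# rollout-training wrapper. Only `sample.node_features` is consumed here; the
# model integrates its own predictions forward for num_time_steps - 1 steps,
# and the initial coordinates are prepended to form a full trajectory.
# `model`, `sample`, and `data_stats` are assumed to be set up as above.
@torch.no_grad()
def _example_autoregressive_rollout(model, sample, data_stats):
    model.eval()                                      # autoregressive updates
    pred = model(sample, data_stats)                  # [rollout_steps, N, 3]
    x0 = sample.node_features["coords"].unsqueeze(0)  # [1, N, 3]
    return torch.cat([x0, pred], dim=0)               # [rollout_steps + 1, N, 3]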