@@ -72,7 +72,6 @@ def __init__(
7272 ):
7373 del kwargs
7474 super ().__init__ ()
75- self .register_buffer ("seq" , torch .arange (max (len_h , len_w , len_t ), dtype = torch .float , device = device ))
7675 self .base_fps = base_fps
7776 self .max_h = len_h
7877 self .max_w = len_w
@@ -134,21 +133,19 @@ def generate_embeddings(
134133 temporal_freqs = 1.0 / (t_theta ** self .dim_temporal_range .to (device = device ))
135134
136135 B , T , H , W , _ = B_T_H_W_C
136+ seq = torch .arange (max (H , W , T ), dtype = torch .float , device = device )
137137 uniform_fps = (fps is None ) or isinstance (fps , (int , float )) or (fps .min () == fps .max ())
138138 assert (
139139 uniform_fps or B == 1 or T == 1
140140 ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
141- assert (
142- H <= self .max_h and W <= self .max_w
143- ), f"Input dimensions (H={ H } , W={ W } ) exceed the maximum dimensions (max_h={ self .max_h } , max_w={ self .max_w } )"
144- half_emb_h = torch .outer (self .seq [:H ].to (device = device ), h_spatial_freqs )
145- half_emb_w = torch .outer (self .seq [:W ].to (device = device ), w_spatial_freqs )
141+ half_emb_h = torch .outer (seq [:H ].to (device = device ), h_spatial_freqs )
142+ half_emb_w = torch .outer (seq [:W ].to (device = device ), w_spatial_freqs )
146143
147144 # apply sequence scaling in temporal dimension
148145 if fps is None or self .enable_fps_modulation is False : # image case
149- half_emb_t = torch .outer (self . seq [:T ].to (device = device ), temporal_freqs )
146+ half_emb_t = torch .outer (seq [:T ].to (device = device ), temporal_freqs )
150147 else :
151- half_emb_t = torch .outer (self . seq [:T ].to (device = device ) / fps * self .base_fps , temporal_freqs )
148+ half_emb_t = torch .outer (seq [:T ].to (device = device ) / fps * self .base_fps , temporal_freqs )
152149
153150 half_emb_h = torch .stack ([torch .cos (half_emb_h ), - torch .sin (half_emb_h ), torch .sin (half_emb_h ), torch .cos (half_emb_h )], dim = - 1 )
154151 half_emb_w = torch .stack ([torch .cos (half_emb_w ), - torch .sin (half_emb_w ), torch .sin (half_emb_w ), torch .cos (half_emb_w )], dim = - 1 )
0 commit comments