@@ -54,8 +54,10 @@ def get_component_configs(model, cfg):
5454 """
5555 Get configurations for different components, including both compilation and weight loading info.
5656 """
57- batchsize = (1 , 2 , 4 )
58- num_obj = (1 , 2 , 4 )
57+ batch = tp .NamedDimension ("batch" , 1 , 2 , 4 )
58+ num_obj = tp .NamedDimension ("num_obj" , 1 , 2 , 4 )
59+ seq_len = tp .NamedDimension ("seq_len" , 4100 , 16400 , 28736 )
60+ mem_attention_batch = tp .NamedDimension ("mem_attention_batch" , 1 , 2 , 8 )
5961 model_precision = getattr (cfg ["model" ], "model_precision" , "float32" )
6062 return {
6163 "memory_attention" : {
@@ -64,19 +66,19 @@ def get_component_configs(model, cfg):
6466 "dtype" : model_precision ,
6567 "compile_args" : [
6668 tp .InputInfo (
67- (4096 , ( 1 , 2 , 8 ) , 256 ),
69+ (4096 , mem_attention_batch , 256 ),
6870 getattr (tp , model_precision ),
6971 ),
7072 tp .InputInfo (
71- (( 4100 , 16400 , 28736 ), ( 1 , 2 , 8 ) , 64 ),
73+ (seq_len , mem_attention_batch , 64 ),
7274 getattr (tp , model_precision ),
7375 ),
7476 tp .InputInfo (
75- (4096 , ( 1 , 2 , 8 ) , 256 ),
77+ (4096 , mem_attention_batch , 256 ),
7678 getattr (tp , model_precision ),
7779 ),
7880 tp .InputInfo (
79- (( 4100 , 16400 , 28736 ), ( 1 , 2 , 8 ) , 64 ),
81+ (seq_len , mem_attention_batch , 64 ),
8082 getattr (tp , model_precision ),
8183 ),
8284 # TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
@@ -124,29 +126,29 @@ def get_component_configs(model, cfg):
124126 "dtype" : model_precision ,
125127 "compile_args" : [
126128 tp .InputInfo (
127- (batchsize , 256 , 64 , 64 ),
129+ (batch , 256 , 64 , 64 ),
128130 dtype = getattr (tp , model_precision ),
129131 ), # image_embeddings
130132 tp .InputInfo (
131133 (1 , 256 , 64 , 64 ),
132134 dtype = getattr (tp , model_precision ),
133135 ), # image_pe
134136 tp .InputInfo (
135- (batchsize , (2 , 4 , 6 ), 256 ),
137+ (batch , (2 , 4 , 6 ), 256 ),
136138 dtype = getattr (tp , model_precision ),
137139 ), # sparse_prompt_embeddings
138140 tp .InputInfo (
139- (batchsize , 256 , 64 , 64 ),
141+ (batch , 256 , 64 , 64 ),
140142 dtype = getattr (tp , model_precision ),
141143 ), # dense_prompt_embeddings
142144 True , # multimask_output
143145 False , # repeat_image
144146 tp .InputInfo (
145- (batchsize , 32 , 256 , 256 ),
147+ (batch , 32 , 256 , 256 ),
146148 dtype = getattr (tp , model_precision ),
147149 ), # high_res_features_1
148150 tp .InputInfo (
149- (batchsize , 64 , 128 , 128 ),
151+ (batch , 64 , 128 , 128 ),
150152 dtype = getattr (tp , model_precision ),
151153 ), # high_res_features_2
152154 ],
@@ -159,7 +161,7 @@ def get_component_configs(model, cfg):
159161 "dtype" : model_precision ,
160162 "compile_args" : [
161163 tp .InputInfo (
162- (batchsize , 256 , 256 , 256 ),
164+ (batch , 256 , 256 , 256 ),
163165 dtype = getattr (tp , model_precision ),
164166 )
165167 ],
@@ -172,7 +174,7 @@ def get_component_configs(model, cfg):
172174 "dtype" : model_precision ,
173175 "compile_args" : [
174176 tp .InputInfo (
175- (batchsize , 256 , 128 , 128 ),
177+ (batch , 256 , 128 , 128 ),
176178 dtype = getattr (tp , model_precision ),
177179 )
178180 ],
@@ -184,8 +186,8 @@ def get_component_configs(model, cfg):
184186 "model" : model .memory_encoder ,
185187 "dtype" : model_precision ,
186188 "compile_args" : [
187- tp .InputInfo ((batchsize , 256 , 64 , 64 ), getattr (tp , model_precision )),
188- tp .InputInfo ((batchsize , num_obj , 1024 , 1024 ), getattr (tp , model_precision )),
189+ tp .InputInfo ((batch , 256 , 64 , 64 ), getattr (tp , model_precision )),
190+ tp .InputInfo ((batch , num_obj , 1024 , 1024 ), getattr (tp , model_precision )),
189191 True ,
190192 ],
191193 "skip_dtype_convert" : ["ln" , "norm" ]
@@ -196,8 +198,8 @@ def get_component_configs(model, cfg):
196198 "model" : model .sam_prompt_encoder ,
197199 "dtype" : "float32" ,
198200 "compile_args" : [
199- tp .InputInfo ((batchsize , num_obj , 2 ), dtype = tp .float32 ),
200- tp .InputInfo ((batchsize , num_obj ), dtype = tp .int32 ),
201+ tp .InputInfo ((batch , num_obj , 2 ), dtype = tp .float32 ),
202+ tp .InputInfo ((batch , num_obj ), dtype = tp .int32 ),
201203 None ,
202204 None ,
203205 ],
@@ -224,7 +226,7 @@ def get_component_configs(model, cfg):
224226 "dtype" : model_precision ,
225227 "compile_args" : [
226228 tp .InputInfo (
227- (batchsize , 3 , 1024 , 1024 ),
229+ (batch , 3 , 1024 , 1024 ),
228230 dtype = getattr (tp , model_precision ),
229231 ),
230232 ],
0 commit comments