@@ -260,6 +260,44 @@ def test_multiple_features(self) -> None:
         )
         self._test_ebc([eb1_config, eb2_config], features)
 
+    def test_multiple_kernels_per_ebc_table(self) -> None:
+        class TestModule(torch.nn.Module):
+            def __init__(self, m: torch.nn.Module) -> None:
+                super().__init__()
+                self.m = m
+
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f2"],
+            use_virtual_table=True,
+        )
+        eb3_config = EmbeddingBagConfig(
+            name="t3", embedding_dim=16, num_embeddings=10, feature_names=["f3"]
+        )
+        ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config, eb3_config])
+        model = TestModule(ebc)
+        qebc = trec_infer.modules.quantize_embeddings(
+            model,
+            dtype=torch.int8,
+            inplace=True,
+            per_table_weight_dtype={"t1": torch.float16},
+        )
+        self.assertTrue(isinstance(qebc.m, QuantEmbeddingBagCollection))
+        # feature names should be consistent with the order of grouped embeddings
+        self.assertEqual(qebc.m._feature_names, ["f1", "f3", "f2"])
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.as_tensor([0, 1, 2]),
+            lengths=torch.as_tensor([1, 1, 1]),
+        )
+        self._test_ebc([eb1_config, eb2_config, eb3_config], features)
+
     # pyre-ignore
     @given(
         data_type=st.sampled_from(
@@ -742,6 +780,93 @@ def __init__(self, m: torch.nn.Module) -> None:
                 self.assertEqual(config.name, "t2")
                 self.assertEqual(config.data_type, DataType.INT8)
 
+    def test_multiple_kernels_per_ec_table(self) -> None:
+        class TestModule(torch.nn.Module):
+            def __init__(self, m: torch.nn.Module) -> None:
+                super().__init__()
+                self.m = m
+
+        eb1_config = EmbeddingConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingConfig(
+            name="t2",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f2"],
+            use_virtual_table=True,
+        )
+        eb3_config = EmbeddingConfig(
+            name="t3",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
+        ec = EmbeddingCollection(tables=[eb1_config, eb2_config, eb3_config])
+        model = TestModule(ec)
+        qconfig_spec_keys: List[Type[torch.nn.Module]] = [EmbeddingCollection]
+        quant_mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {
+            EmbeddingCollection: QuantEmbeddingCollection
+        }
+        qec = trec_infer.modules.quantize_embeddings(
+            model,
+            dtype=torch.int8,
+            additional_qconfig_spec_keys=qconfig_spec_keys,
+            additional_mapping=quant_mapping,
+            inplace=True,
+            per_table_weight_dtype={
+                "t1": torch.float16,
+                "t2": torch.float16,
+                "t3": torch.float16,
+            },
+        )
+        self.assertTrue(isinstance(qec.m, QuantEmbeddingCollection))
+        # feature names should be consistent with the order of grouped embeddings
+        self.assertEqual(qec.m._feature_names, ["f1", "f3", "f2"])
+
+        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+        configs = model.m.embedding_configs()
+        self.assertEqual(len(configs), 3)
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.as_tensor(
+                [
+                    5,
+                    1,
+                    0,
+                    0,
+                    4,
+                    3,
+                    4,
+                    9,
+                    2,
+                    2,
+                    3,
+                    3,
+                    1,
+                    5,
+                    0,
+                    7,
+                    5,
+                    0,
+                    9,
+                    9,
+                    3,
+                    5,
+                    6,
+                    6,
+                    9,
+                    3,
+                    7,
+                    8,
+                    7,
+                    7,
+                ]
+            ),
+            lengths=torch.as_tensor([9, 12, 9]),
+        )
+        self._test_ec(tables=[eb3_config, eb1_config, eb2_config], features=features)
+
     def test_different_quantization_dtype_per_ebc_table(self) -> None:
         class TestModule(torch.nn.Module):
             def __init__(self, m: torch.nn.Module) -> None: