@@ -158,3 +158,56 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     tt.return %res : tensor<8x16xf16>
   }
 }
+
+
+// -----
+
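+// Checks the DPAS layout verifier: with opsPerChannel = 2 the layout expects a
+// 16-bit element type, so the f32 tt.dot below is rejected (see expected-error).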
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
+#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                          %a: !ttg.memdesc<32x16xf32, #shared, #smem>, %b: !ttg.memdesc<16x32xf32, #shared, #smem>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
+
+    tt.return
+  }
+}
+
+// -----
+
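+// Checks the ttg.dot_op verifier: kWidth must match the parent DPAS layout's
+// opsPerChannel (here kWidth = 2 vs. opsPerChan = 1, see expected-error @below).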
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
+// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
+#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                          %a: !ttg.memdesc<32x16xf32, #shared, #smem>, %b: !ttg.memdesc<16x32xf32, #shared, #smem>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
+
+    tt.return
+  }
+}