7
7
import os
8
8
import re
9
9
import sys
10
+
10
11
sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." )))
11
12
import subprocess
12
13
import textwrap
13
- import yaml
14
14
from collections import OrderedDict
15
- from torchgen .code_template import CodeTemplate
16
15
from dataclasses import dataclass
17
- from typing import Any , Dict , List , Tuple , Optional
16
+ from typing import Any , Dict , List , Optional , Tuple
17
+
18
+ import yaml
19
+ from torchgen .code_template import CodeTemplate
18
20
from yaml .constructor import ConstructorError
19
21
from yaml .nodes import MappingNode
20
22
@@ -128,51 +130,63 @@ class ShaderInfo:
128
130
bias_storage_type : str = ""
129
131
register_for : Optional [Tuple [str , List [str ]]] = None
130
132
133
+
131
134
def getName (filePath : str ) -> str :
132
135
return os .path .basename (filePath ).replace ("/" , "_" ).replace ("." , "_" )
133
136
137
+
134
138
def isDescriptorLine (lineStr : str ) -> bool :
135
139
descriptorLineId = r"^layout\(set"
136
140
return re .search (descriptorLineId , lineStr ) is not None
137
141
142
+
138
143
def isTileSizeLine (lineStr : str ) -> bool :
139
144
tile_size_id = r"^ \* TILE_SIZE = \("
140
145
return re .search (tile_size_id , lineStr ) is not None
141
146
147
+
142
148
def findTileSizes (lineStr : str ) -> List [int ]:
143
149
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
144
150
matches = re .search (tile_size_id , lineStr )
145
151
if matches is None :
146
152
raise AssertionError ("matches is None in findTileSizes" )
147
153
return [int (matches .group (1 )), int (matches .group (2 )), int (matches .group (3 ))]
148
154
155
+
149
156
def isWeightStorageTypeLine (lineStr : str ) -> bool :
150
157
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
151
158
return re .search (weight_storage_id , lineStr ) is not None
152
159
160
+
153
161
def getWeightStorageType (lineStr : str ) -> str :
154
162
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
155
163
matches = re .search (weight_storage_id , lineStr )
156
164
if matches is None :
157
165
raise AssertionError ("matches is None in getWeightStorageType" )
158
166
return matches .group (1 )
159
167
168
+
160
169
def isBiasStorageTypeLine (lineStr : str ) -> bool :
161
170
weight_storage_id = r"^ \* BIAS_STORAGE = "
162
171
return re .search (weight_storage_id , lineStr ) is not None
163
172
173
+
164
174
def getBiasStorageType (lineStr : str ) -> str :
165
175
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
166
176
matches = re .search (weight_storage_id , lineStr )
167
177
if matches is None :
168
178
raise AssertionError ("matches is None in getBiasStorageType" )
169
179
return matches .group (1 )
170
180
181
+
171
182
def isRegisterForLine (lineStr : str ) -> bool :
172
183
# Check for Shader Name and a list of at least one Registry Key
173
- register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
184
+ register_for_id = (
185
+ r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
186
+ )
174
187
return re .search (register_for_id , lineStr ) is not None
175
188
189
+
176
190
def findRegisterFor (lineStr : str ) -> Tuple [str , List [str ]]:
177
191
register_for_pattern = r"'([A-Za-z0-9_]+)'"
178
192
matches = re .findall (register_for_pattern , lineStr )
@@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
181
195
matches_list = list (matches )
182
196
return (matches_list [0 ], matches_list [1 :])
183
197
198
+
184
199
typeIdMapping = {
185
200
r"image[123]D\b" : "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" ,
186
201
r"sampler[123]D\b" : "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" ,
@@ -189,12 +204,13 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
189
204
}
190
205
191
206
storageTypeToEnum = {
192
- "TEXTURE_2D" : "api::StorageType::TEXTURE_2D" ,
193
- "TEXTURE_3D" : "api::StorageType::TEXTURE_3D" ,
194
- "BUFFER" : "api::StorageType::BUFFER" ,
207
+ "TEXTURE_2D" : "api::StorageType::TEXTURE_2D" ,
208
+ "TEXTURE_3D" : "api::StorageType::TEXTURE_3D" ,
209
+ "BUFFER" : "api::StorageType::BUFFER" ,
195
210
"" : "api::StorageType::UNKNOWN" ,
196
211
}
197
212
213
+
198
214
def determineDescriptorType (lineStr : str ) -> str :
199
215
for identifier , typeNum in typeIdMapping .items ():
200
216
if re .search (identifier , lineStr ):
@@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
203
219
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
204
220
)
205
221
222
+
206
223
def getShaderInfo (srcFilePath : str ) -> ShaderInfo :
207
224
shader_info = ShaderInfo ([], [], "" )
208
225
with open (srcFilePath ) as srcFile :
@@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
220
237
221
238
return shader_info
222
239
240
+
223
241
def genGLSLFromGLSLT (src_dir_path : str , tmp_dir_path : str ) -> None :
224
242
template_dir_path = os .path .join (src_dir_path , "templates" )
225
- vexs = glob .glob (os .path .join (template_dir_path , '**' , ' *.yaml' ), recursive = True )
243
+ vexs = glob .glob (os .path .join (template_dir_path , "**" , " *.yaml" ), recursive = True )
226
244
parameter_yaml_files = []
227
245
for f in vexs :
228
246
if len (f ) > 1 :
@@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
231
249
for params_yaml in parameter_yaml_files :
232
250
generator .add_params_yaml (params_yaml ) # type: ignore[no-untyped-call]
233
251
234
- vexs = glob .glob (os .path .join (src_dir_path , '**' , ' *.glslt' ), recursive = True )
252
+ vexs = glob .glob (os .path .join (src_dir_path , "**" , " *.glslt" ), recursive = True )
235
253
templateSrcPaths = []
236
254
for f in vexs :
237
255
if len (f ) > 1 :
@@ -258,7 +276,7 @@ def genCppH(
258
276
templateSrcPaths = []
259
277
260
278
for srcDirPath in srcDirPaths :
261
- vexs = glob .glob (os .path .join (srcDirPath , '**' , ' *.glsl' ), recursive = True )
279
+ vexs = glob .glob (os .path .join (srcDirPath , "**" , " *.glsl" ), recursive = True )
262
280
for f in vexs :
263
281
if len (f ) > 1 :
264
282
templateSrcPaths .append (f )
@@ -267,7 +285,7 @@ def genCppH(
267
285
# Now add glsl files that are generated from templates
268
286
genGLSLFromGLSLT (srcDirPath , tmpDirPath )
269
287
270
- vexs = glob .glob (os .path .join (tmpDirPath , '**' , ' *.glsl' ), recursive = True )
288
+ vexs = glob .glob (os .path .join (tmpDirPath , "**" , " *.glsl" ), recursive = True )
271
289
for f in vexs :
272
290
if len (f ) > 1 :
273
291
templateSrcPaths .append (f )
@@ -283,17 +301,20 @@ def genCppH(
283
301
codeTemplate = CodeTemplate .from_file (templateSrcPath )
284
302
srcPath = tmpDirPath + "/" + name + ".glsl"
285
303
content = codeTemplate .substitute (env )
286
- with open (srcPath , 'w' ) as fw :
304
+ with open (srcPath , "w" ) as fw :
287
305
fw .write (content )
288
306
289
307
spvPath = tmpDirPath + "/" + name + ".spv"
290
308
print (f"spvPath { spvPath } " )
291
309
292
310
cmd = [
293
- glslcPath , "-fshader-stage=compute" ,
294
- srcPath , "-o" , spvPath ,
311
+ glslcPath ,
312
+ "-fshader-stage=compute" ,
313
+ srcPath ,
314
+ "-o" ,
315
+ spvPath ,
295
316
"--target-env=vulkan1.0" ,
296
- "-Werror"
317
+ "-Werror" ,
297
318
] + [arg for srcDirPath in srcDirPaths for arg in ["-I" , srcDirPath ]]
298
319
299
320
print ("\n glslc cmd:" , cmd )
@@ -323,7 +344,9 @@ def genCppH(
323
344
h += "extern const ShaderListing shader_infos;\n "
324
345
h += "extern ShaderRegistry shader_registry;\n "
325
346
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n }\n "
326
- h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n }\n "
347
+ h += (
348
+ "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n }\n "
349
+ )
327
350
328
351
h += nsend
329
352
@@ -341,8 +364,8 @@ def genCppH(
341
364
name = getName (spvPath ).replace ("_spv" , "" )
342
365
343
366
print (f"spvPath:{ spvPath } " )
344
- with open (spvPath , 'rb' ) as fr :
345
- next_bin = array .array ('I' , fr .read ())
367
+ with open (spvPath , "rb" ) as fr :
368
+ next_bin = array .array ("I" , fr .read ())
346
369
sizeBytes = 4 * len (next_bin )
347
370
shader_info_bin_code .append (
348
371
"const uint32_t {}_bin[] = {{\n {}\n }};" .format (
@@ -362,7 +385,7 @@ def genCppH(
362
385
shader_info_layouts = "{{{}}}" .format (",\n " .join (shader_info .layouts ))
363
386
364
387
shader_info_args = [
365
- f" \" vulkan.{ name } \" " ,
388
+ f'" vulkan.{ name } "' ,
366
389
f"{ name } _bin" ,
367
390
str (sizeBytes ),
368
391
shader_info_layouts ,
@@ -373,7 +396,7 @@ def genCppH(
373
396
374
397
shader_info_cpp_code .append (
375
398
textwrap .indent (
376
- "{{ \ " {}\ " ,\n api::ShaderInfo(\n {})}}" .format (
399
+ '{{ "{}",\n api::ShaderInfo(\n {})}}' .format (
377
400
name ,
378
401
textwrap .indent (",\n " .join (shader_info_args ), " " ),
379
402
),
@@ -386,7 +409,7 @@ def genCppH(
386
409
for registry_key in registry_keys :
387
410
shader_info_registry_code .append (
388
411
textwrap .indent (
389
- f"{{ \ "{ op_name } \ " , {{{{\ "{ registry_key } \ " , \ "{ name } \ " }}}}}}" ,
412
+ f'{{ "{ op_name } ", {{{{"{ registry_key } ", "{ name } "}}}}}}' ,
390
413
" " ,
391
414
),
392
415
)
@@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
421
444
422
445
423
446
def main (argv : List [str ]) -> int :
424
- parser = argparse .ArgumentParser (description = '' )
447
+ parser = argparse .ArgumentParser (description = "" )
425
448
parser .add_argument (
426
- '-i' ,
427
- ' --glsl-paths' ,
428
- nargs = '+' ,
449
+ "-i" ,
450
+ " --glsl-paths" ,
451
+ nargs = "+" ,
429
452
help = 'List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"' ,
430
- default = ['.' ],
453
+ default = ["." ],
431
454
)
455
+ parser .add_argument ("-c" , "--glslc-path" , required = True , help = "" )
456
+ parser .add_argument ("-t" , "--tmp-dir-path" , required = True , help = "/tmp" )
457
+ parser .add_argument ("-o" , "--output-path" , required = True , help = "" )
432
458
parser .add_argument (
433
- '-c' ,
434
- '--glslc-path' ,
435
- required = True ,
436
- help = '' )
437
- parser .add_argument (
438
- '-t' ,
439
- '--tmp-dir-path' ,
440
- required = True ,
441
- help = '/tmp' )
442
- parser .add_argument (
443
- '-o' ,
444
- '--output-path' ,
445
- required = True ,
446
- help = '' )
447
- parser .add_argument (
448
- "--env" ,
449
- metavar = "KEY=VALUE" ,
450
- nargs = '*' ,
451
- help = "Set a number of key-value pairs" )
459
+ "--env" , metavar = "KEY=VALUE" , nargs = "*" , help = "Set a number of key-value pairs"
460
+ )
452
461
options = parser .parse_args ()
453
462
env = DEFAULT_ENV
454
463
for key , value in parse_arg_env (options .env ).items ():
@@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
466
475
srcDirPaths = options .glsl_paths ,
467
476
glslcPath = options .glslc_path ,
468
477
tmpDirPath = options .tmp_dir_path ,
469
- env = env )
478
+ env = env ,
479
+ )
470
480
471
481
return 0
472
482
473
- if __name__ == '__main__' :
483
+
484
+ def invoke_main () -> None :
474
485
sys .exit (main (sys .argv ))
486
+
487
+
488
+ if __name__ == "__main__" :
489
+ invoke_main () # pragma: no cover
0 commit comments