Skip to content

Commit 9b736c7

Browse files
zsolpytorchmergebot
authored andcommitted
[Codemod][python/main_function] caffe2: (pytorch#113357)
Differential Revision: D51149464 Pull Request resolved: pytorch#113357 Approved by: https://github.com/huydhn
1 parent 87aeb24 commit 9b736c7

File tree

4 files changed

+116
-66
lines changed

4 files changed

+116
-66
lines changed

tools/gen_vulkan_spv.py

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
import os
88
import re
99
import sys
10+
1011
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
1112
import subprocess
1213
import textwrap
13-
import yaml
1414
from collections import OrderedDict
15-
from torchgen.code_template import CodeTemplate
1615
from dataclasses import dataclass
17-
from typing import Any, Dict, List, Tuple, Optional
16+
from typing import Any, Dict, List, Optional, Tuple
17+
18+
import yaml
19+
from torchgen.code_template import CodeTemplate
1820
from yaml.constructor import ConstructorError
1921
from yaml.nodes import MappingNode
2022

@@ -128,51 +130,63 @@ class ShaderInfo:
128130
bias_storage_type: str = ""
129131
register_for: Optional[Tuple[str, List[str]]] = None
130132

133+
131134
def getName(filePath: str) -> str:
132135
return os.path.basename(filePath).replace("/", "_").replace(".", "_")
133136

137+
134138
def isDescriptorLine(lineStr: str) -> bool:
135139
descriptorLineId = r"^layout\(set"
136140
return re.search(descriptorLineId, lineStr) is not None
137141

142+
138143
def isTileSizeLine(lineStr: str) -> bool:
139144
tile_size_id = r"^ \* TILE_SIZE = \("
140145
return re.search(tile_size_id, lineStr) is not None
141146

147+
142148
def findTileSizes(lineStr: str) -> List[int]:
143149
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
144150
matches = re.search(tile_size_id, lineStr)
145151
if matches is None:
146152
raise AssertionError("matches is None in findTileSizes")
147153
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
148154

155+
149156
def isWeightStorageTypeLine(lineStr: str) -> bool:
150157
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
151158
return re.search(weight_storage_id, lineStr) is not None
152159

160+
153161
def getWeightStorageType(lineStr: str) -> str:
154162
weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
155163
matches = re.search(weight_storage_id, lineStr)
156164
if matches is None:
157165
raise AssertionError("matches is None in getWeightStorageType")
158166
return matches.group(1)
159167

168+
160169
def isBiasStorageTypeLine(lineStr: str) -> bool:
161170
weight_storage_id = r"^ \* BIAS_STORAGE = "
162171
return re.search(weight_storage_id, lineStr) is not None
163172

173+
164174
def getBiasStorageType(lineStr: str) -> str:
165175
weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
166176
matches = re.search(weight_storage_id, lineStr)
167177
if matches is None:
168178
raise AssertionError("matches is None in getBiasStorageType")
169179
return matches.group(1)
170180

181+
171182
def isRegisterForLine(lineStr: str) -> bool:
172183
# Check for Shader Name and a list of at least one Registry Key
173-
register_for_id = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
184+
register_for_id = (
185+
r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
186+
)
174187
return re.search(register_for_id, lineStr) is not None
175188

189+
176190
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
177191
register_for_pattern = r"'([A-Za-z0-9_]+)'"
178192
matches = re.findall(register_for_pattern, lineStr)
@@ -181,6 +195,7 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
181195
matches_list = list(matches)
182196
return (matches_list[0], matches_list[1:])
183197

198+
184199
typeIdMapping = {
185200
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
186201
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
@@ -189,12 +204,13 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
189204
}
190205

191206
storageTypeToEnum = {
192-
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
193-
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
194-
"BUFFER" : "api::StorageType::BUFFER",
207+
"TEXTURE_2D": "api::StorageType::TEXTURE_2D",
208+
"TEXTURE_3D": "api::StorageType::TEXTURE_3D",
209+
"BUFFER": "api::StorageType::BUFFER",
195210
"": "api::StorageType::UNKNOWN",
196211
}
197212

213+
198214
def determineDescriptorType(lineStr: str) -> str:
199215
for identifier, typeNum in typeIdMapping.items():
200216
if re.search(identifier, lineStr):
@@ -203,6 +219,7 @@ def determineDescriptorType(lineStr: str) -> str:
203219
"No matching descriptor type for " + lineStr + " in determineDescriptorType"
204220
)
205221

222+
206223
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
207224
shader_info = ShaderInfo([], [], "")
208225
with open(srcFilePath) as srcFile:
@@ -220,9 +237,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
220237

221238
return shader_info
222239

240+
223241
def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
224242
template_dir_path = os.path.join(src_dir_path, "templates")
225-
vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
243+
vexs = glob.glob(os.path.join(template_dir_path, "**", "*.yaml"), recursive=True)
226244
parameter_yaml_files = []
227245
for f in vexs:
228246
if len(f) > 1:
@@ -231,7 +249,7 @@ def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
231249
for params_yaml in parameter_yaml_files:
232250
generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call]
233251

234-
vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True)
252+
vexs = glob.glob(os.path.join(src_dir_path, "**", "*.glslt"), recursive=True)
235253
templateSrcPaths = []
236254
for f in vexs:
237255
if len(f) > 1:
@@ -258,7 +276,7 @@ def genCppH(
258276
templateSrcPaths = []
259277

260278
for srcDirPath in srcDirPaths:
261-
vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
279+
vexs = glob.glob(os.path.join(srcDirPath, "**", "*.glsl"), recursive=True)
262280
for f in vexs:
263281
if len(f) > 1:
264282
templateSrcPaths.append(f)
@@ -267,7 +285,7 @@ def genCppH(
267285
# Now add glsl files that are generated from templates
268286
genGLSLFromGLSLT(srcDirPath, tmpDirPath)
269287

270-
vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
288+
vexs = glob.glob(os.path.join(tmpDirPath, "**", "*.glsl"), recursive=True)
271289
for f in vexs:
272290
if len(f) > 1:
273291
templateSrcPaths.append(f)
@@ -283,17 +301,20 @@ def genCppH(
283301
codeTemplate = CodeTemplate.from_file(templateSrcPath)
284302
srcPath = tmpDirPath + "/" + name + ".glsl"
285303
content = codeTemplate.substitute(env)
286-
with open(srcPath, 'w') as fw:
304+
with open(srcPath, "w") as fw:
287305
fw.write(content)
288306

289307
spvPath = tmpDirPath + "/" + name + ".spv"
290308
print(f"spvPath {spvPath}")
291309

292310
cmd = [
293-
glslcPath, "-fshader-stage=compute",
294-
srcPath, "-o", spvPath,
311+
glslcPath,
312+
"-fshader-stage=compute",
313+
srcPath,
314+
"-o",
315+
spvPath,
295316
"--target-env=vulkan1.0",
296-
"-Werror"
317+
"-Werror",
297318
] + [arg for srcDirPath in srcDirPaths for arg in ["-I", srcDirPath]]
298319

299320
print("\nglslc cmd:", cmd)
@@ -323,7 +344,9 @@ def genCppH(
323344
h += "extern const ShaderListing shader_infos;\n"
324345
h += "extern ShaderRegistry shader_registry;\n"
325346
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
326-
h += "inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
347+
h += (
348+
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
349+
)
327350

328351
h += nsend
329352

@@ -341,8 +364,8 @@ def genCppH(
341364
name = getName(spvPath).replace("_spv", "")
342365

343366
print(f"spvPath:{spvPath}")
344-
with open(spvPath, 'rb') as fr:
345-
next_bin = array.array('I', fr.read())
367+
with open(spvPath, "rb") as fr:
368+
next_bin = array.array("I", fr.read())
346369
sizeBytes = 4 * len(next_bin)
347370
shader_info_bin_code.append(
348371
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
@@ -362,7 +385,7 @@ def genCppH(
362385
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
363386

364387
shader_info_args = [
365-
f"\"vulkan.{name}\"",
388+
f'"vulkan.{name}"',
366389
f"{name}_bin",
367390
str(sizeBytes),
368391
shader_info_layouts,
@@ -373,7 +396,7 @@ def genCppH(
373396

374397
shader_info_cpp_code.append(
375398
textwrap.indent(
376-
"{{\"{}\",\n api::ShaderInfo(\n{})}}".format(
399+
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
377400
name,
378401
textwrap.indent(",\n".join(shader_info_args), " "),
379402
),
@@ -386,7 +409,7 @@ def genCppH(
386409
for registry_key in registry_keys:
387410
shader_info_registry_code.append(
388411
textwrap.indent(
389-
f"{{\"{op_name}\", {{{{\"{registry_key}\", \"{name}\"}}}}}}",
412+
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
390413
" ",
391414
),
392415
)
@@ -421,34 +444,20 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
421444

422445

423446
def main(argv: List[str]) -> int:
424-
parser = argparse.ArgumentParser(description='')
447+
parser = argparse.ArgumentParser(description="")
425448
parser.add_argument(
426-
'-i',
427-
'--glsl-paths',
428-
nargs='+',
449+
"-i",
450+
"--glsl-paths",
451+
nargs="+",
429452
help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
430-
default=['.'],
453+
default=["."],
431454
)
455+
parser.add_argument("-c", "--glslc-path", required=True, help="")
456+
parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
457+
parser.add_argument("-o", "--output-path", required=True, help="")
432458
parser.add_argument(
433-
'-c',
434-
'--glslc-path',
435-
required=True,
436-
help='')
437-
parser.add_argument(
438-
'-t',
439-
'--tmp-dir-path',
440-
required=True,
441-
help='/tmp')
442-
parser.add_argument(
443-
'-o',
444-
'--output-path',
445-
required=True,
446-
help='')
447-
parser.add_argument(
448-
"--env",
449-
metavar="KEY=VALUE",
450-
nargs='*',
451-
help="Set a number of key-value pairs")
459+
"--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
460+
)
452461
options = parser.parse_args()
453462
env = DEFAULT_ENV
454463
for key, value in parse_arg_env(options.env).items():
@@ -466,9 +475,15 @@ def main(argv: List[str]) -> int:
466475
srcDirPaths=options.glsl_paths,
467476
glslcPath=options.glslc_path,
468477
tmpDirPath=options.tmp_dir_path,
469-
env=env)
478+
env=env,
479+
)
470480

471481
return 0
472482

473-
if __name__ == '__main__':
483+
484+
def invoke_main() -> None:
474485
sys.exit(main(sys.argv))
486+
487+
488+
if __name__ == "__main__":
489+
invoke_main() # pragma: no cover

tools/substitute.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os.path
44

55

6-
if __name__ == "__main__":
6+
def main() -> None:
77
parser = argparse.ArgumentParser()
88
parser.add_argument("--input-file")
99
parser.add_argument("--output-file")
@@ -22,3 +22,7 @@
2222

2323
with open(output_file, "w") as f:
2424
f.write(contents)
25+
26+
27+
if __name__ == "__main__":
28+
main() # pragma: no cover

torch/utils/_freeze.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
import itertools
2727
import marshal
2828
import os
29+
import types
2930
from dataclasses import dataclass
3031
from pathlib import Path
3132
from typing import List
32-
import types
3333

3434

3535
PATH_MARKER = "<Generated by torch::deploy>"
@@ -121,10 +121,10 @@ def write_bytecode(self, install_root):
121121
122122
Shared frozen modules evenly across the files.
123123
"""
124-
bytecode_file_names = [
125-
f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)
124+
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
125+
bytecode_files = [
126+
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
126127
]
127-
bytecode_files = [open(os.path.join(install_root, name), "w") for name in bytecode_file_names]
128128
it = itertools.cycle(bytecode_files)
129129
for m in self.frozen_modules:
130130
self.write_frozen(m, next(it))
@@ -202,7 +202,6 @@ def get_module_qualname(self, file_path: Path, top_package_path: Path) -> List[s
202202
module_parent = normalized_path.parent.parts
203203
return list(module_parent) + [module_basename]
204204

205-
206205
def compile_string(self, file_content: str) -> types.CodeType:
207206
# instead of passing in the real build time path to 'compile', we
208207
# pass in a marker instead. This prevents the build time path being
@@ -239,19 +238,26 @@ def compile_file(self, path: Path, top_package_path: Path):
239238

240239
bytecode = marshal.dumps(co)
241240
size = len(bytecode)
242-
if path.name == '__init__.py':
241+
if path.name == "__init__.py":
243242
# Python packages are signified by negative size.
244243
size = -size
245244
self.frozen_modules.append(
246245
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
247246
)
248247

249-
if __name__ == "__main__":
248+
249+
def main() -> None:
250250
parser = argparse.ArgumentParser(description="Compile py source")
251251
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
252252
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
253-
parser.add_argument("--install-dir", "--install_dir", help="Root directory for all output files")
254-
parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules")
253+
parser.add_argument(
254+
"--install-dir", "--install_dir", help="Root directory for all output files"
255+
)
256+
parser.add_argument(
257+
"--oss",
258+
action="store_true",
259+
help="If it's OSS build, add a fake _PyImport_FrozenModules",
260+
)
255261
parser.add_argument(
256262
"--symbol-name",
257263
"--symbol_name",
@@ -265,7 +271,7 @@ def compile_file(self, path: Path, top_package_path: Path):
265271

266272
for p in args.paths:
267273
path = Path(p)
268-
if path.is_dir() and not Path.exists(path / '__init__.py'):
274+
if path.is_dir() and not Path.exists(path / "__init__.py"):
269275
# this 'top level path p' is a standard directory containing modules,
270276
# not a module itself
271277
# each 'mod' could be a dir containing __init__.py or .py file
@@ -277,3 +283,7 @@ def compile_file(self, path: Path, top_package_path: Path):
277283

278284
f.write_bytecode(args.install_dir)
279285
f.write_main(args.install_dir, args.oss, args.symbol_name)
286+
287+
288+
if __name__ == "__main__":
289+
main() # pragma: no cover

0 commit comments

Comments
 (0)