Skip to content

Commit 4d7bf72

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[BE][Easy] fix ruff rule needless-bool (SIM103) (pytorch#130206)
Pull Request resolved: pytorch#130206 Approved by: https://github.com/malfet
1 parent fa5f572 commit 4d7bf72

35 files changed

+64
-177
lines changed

benchmarks/dynamo/training_loss.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,7 @@ def check_loss(ref_loss, res_loss):
9999
assert len(ref_loss) == len(res_loss)
100100
length = len(ref_loss)
101101
x = min(length, 10)
102-
if sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 1e-1:
103-
return True
104-
else:
105-
return False
102+
return sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 0.1
106103

107104

108105
def parse_args():

benchmarks/operator_benchmark/benchmark_core.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -339,16 +339,12 @@ def _check_keep(self, test_flag, cmd_flag):
339339
return cmd_flag is None or test_flag == cmd_flag
340340

341341
def _check_operator_first_char(self, test_flag, cmd_flag):
342-
if cmd_flag is None or test_flag[:1].lower() in cmd_flag:
343-
return True
344-
return False
342+
return cmd_flag is None or test_flag[:1].lower() in cmd_flag
345343

346344
def _check_keep_list(self, test_flag, cmd_flag_list):
347-
if cmd_flag_list is None or any(
345+
return cmd_flag_list is None or any(
348346
test_flag == cmd_flag for cmd_flag in cmd_flag_list
349-
):
350-
return True
351-
return False
347+
)
352348

353349
def _keep_test(self, test_case):
354350
# TODO: consider regex matching for test filtering.
@@ -362,7 +358,7 @@ def _keep_test(self, test_case):
362358
)
363359

364360
# Filter framework, operator, test_name, tag, forward_only
365-
if (
361+
return (
366362
self._check_keep(op_test_config.test_name, self.args.test_name)
367363
and self._check_keep_list(test_case.op_bench.module_name(), operators)
368364
and self._check_operator_first_char(
@@ -381,10 +377,7 @@ def _keep_test(self, test_case):
381377
or "device" not in test_case.test_config.input_config
382378
or self.args.device in op_test_config.test_name
383379
)
384-
):
385-
return True
386-
387-
return False
380+
)
388381

389382
def _print_test_case_info(self, test_case):
390383
# Print out the test name and skip the real execution

scripts/compile_tests/common.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ def find(testcase, condition):
4343

4444
def skipped_test(testcase):
4545
def condition(children):
46-
tags = [child.tag for child in children]
47-
if "skipped" in tags:
48-
return True
49-
return False
46+
return "skipped" in {child.tag for child in children}
5047

5148
return find(testcase, condition)
5249

@@ -55,12 +52,8 @@ def passed_test(testcase):
5552
def condition(children):
5653
if len(children) == 0:
5754
return True
58-
tags = [child.tag for child in children]
59-
if "skipped" in tags:
60-
return False
61-
if "failed" in tags:
62-
return False
63-
return True
55+
tags = {child.tag for child in children}
56+
return "skipped" not in tags and "failed" not in tags
6457

6558
return find(testcase, condition)
6659

scripts/compile_tests/passrate.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,7 @@ def should_exclude(key):
4141
if test_file == "UNKNOWN":
4242
return True
4343
# Policy: "pass rate" does not include inductor, export, or dynamo tests.
44-
if test_file.startswith("inductor/"):
45-
return True
46-
if test_file.startswith("export/"):
47-
return True
48-
if test_file.startswith("dynamo/"):
49-
return True
50-
return False
44+
return test_file.startswith(("inductor/", "export/", "dynamo/"))
5145

5246

5347
def compute_pass_rate(eager_dir, dynamo_dir):

test/dynamo/test_higher_order_ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@
3939

4040
def check_dynamic_shape_capture():
4141
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
42-
if not config.assume_static_by_default:
43-
return True
44-
return False
42+
return not config.assume_static_by_default
4543

4644

4745
def count_ops(gm, args, freq, op):

test/dynamo/test_structured_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def inductor_schedule_fn(a):
5656

5757
class StructuredTraceTestingFilter(logging.Filter):
5858
def filter(self, record):
59-
if "str" in record.metadata:
60-
return False
61-
return True
59+
return "str" not in record.metadata
6260

6361

6462
class StructuredTraceTestingFormatter(logging.Formatter):

test/fx/test_subgraph_rewriter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,9 +844,7 @@ def second_input_is_scalar(match, original_graph, pattern_graph):
844844
if input_idx == 1:
845845
num_node = node
846846
input_idx += 1
847-
if not isinstance(match.nodes_map[num_node], (int, float)):
848-
return False
849-
return True
847+
return isinstance(match.nodes_map[num_node], (int, float))
850848

851849
def check_replacement_nodes(self, traced, matches):
852850
replacement_nodes_in_graph = [

test/jit/test_data_parallel.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ def test_tensor_sharing(self):
114114

115115
def assert_share_data(t1, t2):
116116
# Only checks that they point to the same memory on the same device.
117-
if t1.device != t2.device:
118-
return False
119-
if t1.storage().data_ptr() != t2.storage().data_ptr():
120-
return False
121-
return True
117+
return (
118+
t1.device == t2.device
119+
and t1.storage().data_ptr() == t2.storage().data_ptr()
120+
)
122121

123122
for p1, p2 in zip(module.parameters(), replica[0].parameters()):
124123
self.assertTrue(assert_share_data(p1, p2))

test/run_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,7 @@ def maybe_set_hip_visible_devies():
104104

105105

106106
def strtobool(s):
107-
if s.lower() in ["", "0", "false", "off"]:
108-
return False
109-
return True
107+
return s.lower() not in {"", "0", "false", "off"}
110108

111109

112110
class TestChoices(list):

test/test_jit.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def doAutodiffCheck(testname):
167167
# these tests are disabled because BailOut nodes
168168
# inserted by ProfilingExecutor interfere with
169169
# subgraph slicing of Differentiable Graphs
170-
test_exceptions = [
170+
test_exceptions = (
171171
# functional
172172
'test_nn_dropout',
173173
'test_nn_log_softmax',
@@ -195,11 +195,9 @@ def doAutodiffCheck(testname):
195195
'test_split_with_sizes_dim_neg0',
196196
'test_split_with_sizes_size_0',
197197
'test_nn_max_pool2d_with_indices',
198-
]
198+
)
199199

200-
if testname in test_exceptions:
201-
return False
202-
return True
200+
return testname not in test_exceptions
203201

204202

205203
# TODO: enable TE in PE when all tests are fixed

0 commit comments

Comments
 (0)