Skip to content

Commit c5d5663

Browse files
hertschuh and tensorflower-gardener
authored and committed
Reformat code that doesn't follow the standard.
PiperOrigin-RevId: 800095985
1 parent 49ea494 commit c5d5663

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

tf_keras/metrics/confusion_metrics.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,15 +1624,15 @@ def result(self):
16241624
# ROC or PR values. In particular
16251625
# 1) Both measures diverge when there are no negative examples;
16261626
# 2) Both measures diverge when there are no true positives;
1627-
# 3) Recall gain becomes negative when the recall is lower than the label
1628-
# average (i.e. when more negative examples are classified positive
1629-
# than real positives).
1627+
# 3) Recall gain becomes negative when the recall is lower than the
1628+
# label average (i.e. when more negative examples are classified
1629+
# positive than real positives).
16301630
#
16311631
# We ignore case 1 as it is easily communicated. For case 2 we set
16321632
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
16331633
# recall_gain to 0. These fixes will result in an overestimation of
1634-
# the AUC for estimators that are anti-correlated with the label (at
1635-
# some thresholds).
1634+
# the AUC for estimators that are anti-correlated with the label
1635+
# (at some thresholds).
16361636
#
16371637
# The scaling factor $\frac{P}{N}$ that is used to form both
16381638
# gain values.
@@ -1641,13 +1641,27 @@ def result(self):
16411641
tf.math.add(self.false_positives, self.true_negatives),
16421642
)
16431643

1644-
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_negatives, self.true_positives)
1645-
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_positives, self.true_positives)
1644+
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1645+
self.false_negatives, self.true_positives
1646+
)
1647+
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1648+
self.false_positives, self.true_positives
1649+
)
16461650
# Handle case 2.
1647-
recall_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.zeros_like(recall_gain), recall_gain)
1648-
precision_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.ones_like(precision_gain), precision_gain)
1651+
recall_gain = tf.where(
1652+
tf.equal(self.true_positives, 0.0),
1653+
tf.zeros_like(recall_gain),
1654+
recall_gain
1655+
)
1656+
precision_gain = tf.where(
1657+
tf.equal(self.true_positives, 0.0),
1658+
tf.ones_like(precision_gain),
1659+
precision_gain
1660+
)
16491661
# Handle case 3.
1650-
recall_gain = tf.math.maximum(recall_gain, tf.zeros_like(recall_gain))
1662+
recall_gain = tf.math.maximum(
1663+
recall_gain, tf.zeros_like(recall_gain)
1664+
)
16511665

16521666
x = recall_gain
16531667
y = precision_gain

tf_keras/metrics/confusion_metrics_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,7 +1623,7 @@ def test_weighted_pr_gain_majoring(self):
16231623
# precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
16241624
# heights = [max(0, 1), max(1, 1)] = [1, 1]
16251625
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1626-
expected_result = 1 * 1 + 0 * 1
1626+
expected_result = 1 * 1 + 0 * 1
16271627
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
16281628

16291629
def test_weighted_pr_gain_minoring(self):
@@ -1649,7 +1649,9 @@ def test_weighted_pr_gain_minoring(self):
16491649

16501650
def test_weighted_pr_gain_interpolation(self):
16511651
self.setup()
1652-
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR_GAIN")
1652+
auc_obj = metrics.AUC(
1653+
num_thresholds=self.num_thresholds, curve="PR_GAIN"
1654+
)
16531655
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
16541656
result = auc_obj(
16551657
self.y_true, self.y_pred, sample_weight=self.sample_weight
@@ -1666,11 +1668,13 @@ def test_weighted_pr_gain_interpolation(self):
16661668

16671669
def test_pr_gain_interpolation(self):
16681670
self.setup()
1669-
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR_GAIN")
1671+
auc_obj = metrics.AUC(
1672+
num_thresholds=self.num_thresholds, curve="PR_GAIN"
1673+
)
16701674
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
16711675
y_true = tf.constant([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])
16721676
y_pred = tf.constant([0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9])
1673-
result = auc_obj( y_true, y_pred)
1677+
result = auc_obj(y_true, y_pred)
16741678

16751679
# tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]
16761680
# scaling_factor (P/N) = 5/5 = 1

0 commit comments

Comments
 (0)