@@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
127127xla::XlaOp BuildBinaryCrossEntropy (xla::XlaOp input, xla::XlaOp target,
128128 const absl::optional<xla::XlaOp>& weight,
129129 ReductionMode reduction) {
130+ static const float kLogBound = -100 ;
130131 const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
131132 xla::XlaOp xweight;
132133 if (weight) {
@@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
137138 XlaHelpers::ScalarBroadcast<float >(1.0 , input_shape, target.builder ());
138139 }
139140 xla::XlaOp one = xla::One (input.builder (), input_shape.element_type ());
140- xla::XlaOp result = -xweight * (target * xla::Log (input) +
141- (one - target) * xla::Log (one - input));
141+ xla::XlaOp log_bound = XlaHelpers::ScalarValue (
142+ kLogBound , input_shape.element_type (), input.builder ());
143+ xla::XlaOp result =
144+ -xweight * (target * xla::Max (xla::Log (input), log_bound) +
145+ (one - target) * xla::Max (xla::Log (one - input), log_bound));
142146 if (reduction == ReductionMode::kNone ) {
143147 return result;
144148 }
@@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
154158xla::XlaOp BuildBinaryCrossEntropyBackward (
155159 xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
156160 const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
161+ static const float kEpsilon = 1e-12 ;
157162 const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
158163 xla::XlaOp xweight;
159164 if (weight) {
@@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
164169 XlaHelpers::ScalarBroadcast<float >(1.0 , input_shape, target.builder ());
165170 }
166171 xla::XlaOp one = xla::One (input.builder (), input_shape.element_type ());
167- xla::XlaOp result = xweight * (input - target) / input / (one - input);
172+ xla::XlaOp epsilon = XlaHelpers::ScalarValue (
173+ kEpsilon , input_shape.element_type (), input.builder ());
174+ xla::XlaOp result =
175+ xweight * (input - target) / xla::Max (input * (one - input), epsilon);
168176 if (reduction == ReductionMode::kNone ) {
169177 return result * grad_output;
170178 }
0 commit comments