Commit 9703109

Authored by jysohn23, dlibenzi, and JackCaoG

Final cherry-pick for r1.6 release (#2479)

Co-authored-by: Davide Libenzi <[email protected]>
Co-authored-by: JackCaoG <[email protected]>
1 parent: 06d564b

8 files changed: +40 −13 lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ xla_model
 .. autofunction:: xrt_world_size
 .. autofunction:: all_reduce
 .. autofunction:: all_gather
+.. autofunction:: all_to_all
 .. autofunction:: add_step_closure
 .. autofunction:: wait_device_ops
 .. autofunction:: optimizer_step

test/pytorch_test_base.py

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,8 @@
         'test_masked_select_mem_overlap',  # doesn't raise
         'test_scatter_mem_overlap',  # doesn't raise
         'test_index_mem_overlap',  # doesn't raise
+        'test_topk_nonfinite_xla_float32',  # TFXLA update HLO changed for 1.6
+        'test_topk_nonfinite_xla_float64',  # TFXLA update HLO changed for 1.6
     },
     'TestViewOpsXLA': {
         'test_contiguous_nonview',

test/test_operations.py

Lines changed: 14 additions & 0 deletions
@@ -663,6 +663,20 @@ def test_get_xla_tensor(self):
     self.assertEqual(tx, sx.data.cpu())
 
 
+class TestBinaryCrossEntropyLimitValue(XlaTestCase):
+
+  def test_cross_entropy_loss(self):
+
+    def test_fn(pred, target):
+      lossfn = nn.BCELoss()
+      return lossfn(pred, target)
+
+    pred = torch.tensor(1.0)
+    target = torch.tensor(1.0)
+    for offset in [1, 0, 1e-8, 1e-7]:
+      self.runAtenTest([pred - offset, target], test_fn)
+
+
 class TestDynamicShape(XlaTestCase):
 
   def test_nonzero_shape(self):
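For context on what this test pins down: at pred == target == 1.0 the textbook BCE formula evaluates 0 * log(0), which is NaN, while PyTorch's native BCELoss keeps the result finite by bounding the log terms. A minimal CPU-only sketch of that limit case (illustrative, not part of the commit; the -100 bound mirrors kLogBound from the reduction.cpp change below):

import torch

pred = torch.tensor(1.0)
target = torch.tensor(1.0)

# Textbook BCE: 0 * log(0) = 0 * (-inf) = nan at this limit point.
naive = -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))

# Flooring each log term at -100 keeps the product finite.
clamped = -(target * torch.clamp(torch.log(pred), min=-100) +
            (1 - target) * torch.clamp(torch.log(1 - pred), min=-100))

print(naive)    # tensor(nan)
print(clamped)  # tensor(-0.)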

third_party/tensorflow

Submodule tensorflow updated 5737 files

torch_xla/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -4,7 +4,7 @@
 import socket
 import time
 
-from .version import __version__ as version
+from .version import __version__
 
 
 def _maybe_select_tpu_version():
@@ -40,17 +40,18 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):
 
     import cloud_tpu_client
     client = cloud_tpu_client.Client(tpu_name)
-    client.configure_tpu_version(f'pytorch-{version}', restart_type='ifNeeded')
+    client.configure_tpu_version(
+        f'pytorch-{__version__}', restart_type='ifNeeded')
     # client.wait_for_healthy() API doesn't work as we dont have TPU API access
-    _wait_for_open(version)
+    _wait_for_open(__version__)
   except ImportError:
     logging.warning((
         'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
         'installed. Ignore if not running on Colab/Kaggle TPU.'))
   except Exception:
     # This path is hit, when we get throttled by the verison changer
     # when we import torch_xla from xmp.spawn-ed processes.
-    _wait_for_open(version, log=False)
+    _wait_for_open(__version__, log=False)
 
 
 def _setup_grpc():
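The substance of this change: the module previously imported `__version__ as version`, a name that the `version` parameter of `_wait_for_open` shadows, so the code now uses `__version__` directly. A hedged sketch of the resulting flow, using only the two cloud_tpu_client calls visible in this diff (the TPU name is a placeholder):

from torch_xla.version import __version__
import cloud_tpu_client

# Point the Colab/Kaggle TPU at the runtime matching the installed wheel,
# restarting it only if it is not already on that version.
client = cloud_tpu_client.Client('my-tpu')  # placeholder TPU name
client.configure_tpu_version(f'pytorch-{__version__}', restart_type='ifNeeded')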

torch_xla/core/xla_model.py

Lines changed: 0 additions & 3 deletions
@@ -478,9 +478,6 @@ def all_to_all(value,
                groups=None):
   """Performs an XLA `AllToAll()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#alltoall
 
   Args:
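With the reliability warning dropped (the underlying fix rode in with the tensorflow submodule bump above) and the function now listed in the docs index, a usage sketch may help. It is hedged: only `value` and `groups` appear in this diff; the remaining parameter names follow the torch_xla 1.6-era signature as I recall it and should be verified against your install. Run from an xmp.spawn-ed process:

import torch
import torch_xla.core.xla_model as xm

def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # One row per replica, filled with this replica's ordinal.
  value = torch.zeros(world_size, 4, device=device) + xm.get_ordinal()
  # Split the input into world_size chunks along dim 0 and exchange chunk i
  # with replica i, concatenating what is received back along dim 0.
  result = xm.all_to_all(
      value, split_dimension=0, concat_dimension=0, split_count=world_size)
  print(result.cpu())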

torch_xla/csrc/reduction.cpp

Lines changed: 11 additions & 3 deletions
@@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
 xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
                                    const absl::optional<xla::XlaOp>& weight,
                                    ReductionMode reduction) {
+  static const float kLogBound = -100;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = -xweight * (target * xla::Log(input) +
-                                  (one - target) * xla::Log(one - input));
+  xla::XlaOp log_bound = XlaHelpers::ScalarValue(
+      kLogBound, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      -xweight * (target * xla::Max(xla::Log(input), log_bound) +
+                  (one - target) * xla::Max(xla::Log(one - input), log_bound));
   if (reduction == ReductionMode::kNone) {
     return result;
   }
@@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
 xla::XlaOp BuildBinaryCrossEntropyBackward(
     xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
     const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
+  static const float kEpsilon = 1e-12;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = xweight * (input - target) / input / (one - input);
+  xla::XlaOp epsilon = XlaHelpers::ScalarValue(
+      kEpsilon, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      xweight * (input - target) / xla::Max(input * (one - input), epsilon);
   if (reduction == ReductionMode::kNone) {
     return result * grad_output;
   }
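The forward change floors each log term at kLogBound = -100, matching PyTorch's native BCELoss, and the backward change floors the denominator at kEpsilon = 1e-12 so the gradient w * (pred - target) / (pred * (1 - pred)) stays finite as pred approaches 0 or 1. A small Python sketch of the backward clamp (illustrative of the math, not the XLA lowering itself):

import torch

def bce_grad(pred, target, weight=1.0, eps=1e-12):
  # Mirrors xla::Max(input * (one - input), epsilon) from the C++ above.
  denom = torch.clamp(pred * (1 - pred), min=eps)
  return weight * (pred - target) / denom

# At pred -> 1 with target = 0 the unclamped denominator is 0 (inf/nan gradient);
# the clamped version returns a large but finite value instead.
print(bce_grad(torch.tensor(1.0), torch.tensor(0.0)))  # tensor(1.0000e+12)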

torch_xla/distributed/xla_multiprocessing.py

Lines changed: 6 additions & 2 deletions
@@ -226,9 +226,13 @@ def _start_fn(index, pf_cfg, fn, args):
   # Calling _setup_replication() will trigger XLA library initialization, so the
   # environment must be fully setup before doing so.
   _setup_replication()
+  fn(gindex, *args)
+
+
+def _mp_start_fn(index, pf_cfg, fn, args):
   exit_code = 0
   try:
-    fn(gindex, *args)
+    _start_fn(index, pf_cfg, fn, args)
   except Exception as e:
     print(
         'Exception in device={}: {}'.format(_get_multiprocessing_device(),
@@ -288,7 +292,7 @@ def spawn(fn,
     _start_fn(0, pf_cfg, fn, args)
   else:
     return torch.multiprocessing.start_processes(
-        _start_fn,
+        _mp_start_fn,
         args=(pf_cfg, fn, args),
         nprocs=pf_cfg.num_devices,
         join=join,
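The shape of this refactor: the worker body moves into _start_fn, and a new _mp_start_fn wraps it with the per-process exception handling and exit code, so the single-process path (`_start_fn(0, pf_cfg, fn, args)`) no longer routes through a handler meant only for spawned subprocesses. A condensed sketch of the pattern (function bodies and the exit code are stand-ins for the real ones):

import sys

def _start_fn(index, pf_cfg, fn, args):
  # The real code derives a global index from pf_cfg and sets up replication first.
  gindex = index
  fn(gindex, *args)

def _mp_start_fn(index, pf_cfg, fn, args):
  # Multiprocessing-only wrapper: report failures and exit with a nonzero code.
  exit_code = 0
  try:
    _start_fn(index, pf_cfg, fn, args)
  except Exception as e:
    print('Exception in process {}: {}'.format(index, e), file=sys.stderr)
    exit_code = 17
  sys.exit(exit_code)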
