|
| 1 | +import logging |
1 | 2 | import os |
2 | 3 | import re |
| 4 | +import socket |
| 5 | +import time |
| 6 | + |
| 7 | +from .version import __version__ |
| 8 | + |
| 9 | + |
| 10 | +def _maybe_select_tpu_version(): |
| 11 | + # Setup correct TPU runtime version for Colab and Kaggle. |
| 12 | + |
| 13 | + def _is_open(ip, port): |
| 14 | + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 15 | + if s.connect_ex((ip, int(port))) == 0: |
| 16 | + return True |
| 17 | + return False |
| 18 | + |
| 19 | + def _wait_for_open(version, timeout=100, interval=10, log=True): |
| 20 | + tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1] |
| 21 | + deadline = time.time() + timeout |
| 22 | + |
| 23 | + while not _is_open(*tpu_addr.split(':')): |
| 24 | + if log: |
| 25 | + logging.warning( |
| 26 | + f'Waiting for TPU to be start up with version pytorch-{version}...') |
| 27 | + if time.time() > deadline: |
| 28 | + raise RuntimeError('Timed out waiting for TPU to start up') |
| 29 | + time.sleep(interval) |
| 30 | + |
| 31 | + if log: |
| 32 | + logging.warning( |
| 33 | + f'TPU has started up successfully with version pytorch-{version}') |
| 34 | + |
| 35 | + try: |
| 36 | + tpu_name = os.environ.get('TPU_NAME', '') |
| 37 | + if not tpu_name.startswith('grpc://'): |
| 38 | + # Not colab/kaggle |
| 39 | + return |
| 40 | + |
| 41 | + import cloud_tpu_client |
| 42 | + client = cloud_tpu_client.Client(tpu_name) |
| 43 | + client.configure_tpu_version( |
| 44 | + f'pytorch-{__version__}', restart_type='ifNeeded') |
| 45 | + # client.wait_for_healthy() API doesn't work as we dont have TPU API access |
| 46 | + _wait_for_open(__version__) |
| 47 | + except ImportError: |
| 48 | + logging.warning(( |
| 49 | + 'Not selecting corresponding TPU runtime since cloud_tpu_client is not ' |
| 50 | + 'installed. Ignore if not running on Colab/Kaggle TPU.')) |
| 51 | + except Exception: |
| 52 | + # This path is hit, when we get throttled by the verison changer |
| 53 | + # when we import torch_xla from xmp.spawn-ed processes. |
| 54 | + _wait_for_open(__version__, log=False) |
3 | 55 |
|
4 | 56 |
|
5 | 57 | def _setup_grpc(): |
@@ -33,13 +85,13 @@ def _setup_xla_flags(): |
33 | 85 |
|
34 | 86 |
|
35 | 87 | # These needs to be called before the _XLAC module is loaded. |
| 88 | +_maybe_select_tpu_version() |
36 | 89 | _setup_grpc() |
37 | 90 | _setup_xla_flags() |
38 | 91 |
|
39 | 92 | import atexit |
40 | 93 | import torch |
41 | 94 | from ._patched_functions import _apply_patches |
42 | | -from .version import __version__ |
43 | 95 | import _XLAC |
44 | 96 |
|
45 | 97 |
|
|
0 commit comments