Skip to content

Commit 7231272

Browse files
authored
Add version selection snippet for r1.7 (#2577)
* Add version selection snippet for r1.7 * VSCode artifact ignore
1 parent 3f8c5dd commit 7231272

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ torch_xla/csrc/aten_xla_type_default.cpp
1919
# Below files are not deleted by "setup.py clean".
2020

2121
third_party/tensorflow/
22+
23+
# Visual Studio Code files
24+
.vscode
25+
.vs

torch_xla/__init__.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,57 @@
1+
import logging
12
import os
23
import re
4+
import socket
5+
import time
6+
7+
from .version import __version__
8+
9+
10+
def _maybe_select_tpu_version():
    """Set up the matching TPU runtime version on Colab and Kaggle.

    Does nothing unless the ``TPU_NAME`` environment variable starts with
    ``grpc://``. Otherwise it asks ``cloud_tpu_client`` to configure the TPU
    with version ``pytorch-{__version__}`` and then blocks until the TPU
    address taken from ``TPU_NAME`` accepts TCP connections.

    Raises:
        RuntimeError: If the TPU address does not accept connections before
            the wait deadline expires.
    """

    def _is_open(ip, port):
        """Return True if a TCP connection to ``ip:port`` succeeds."""
        # This probe runs repeatedly from a polling loop, so the socket must
        # be closed on every call or each attempt leaks a file descriptor.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex((ip, int(port))) == 0

    def _wait_for_open(version, timeout=100, interval=10, log=True):
        """Poll the TPU address until it accepts connections.

        Args:
            version: Version string used only in the log messages.
            timeout: Seconds to wait before giving up.
            interval: Seconds to sleep between connection attempts.
            log: Whether to emit progress warnings.

        Raises:
            RuntimeError: If the address is still closed after ``timeout``.
        """
        tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1]
        # The address does not change while waiting, so split it only once.
        addr_parts = tpu_addr.split(':')
        # A monotonic clock keeps the deadline immune to wall-clock jumps.
        deadline = time.monotonic() + timeout

        while not _is_open(*addr_parts):
            if log:
                logging.warning(
                    f'Waiting for TPU to be start up with version pytorch-{version}...')
            if time.monotonic() > deadline:
                raise RuntimeError('Timed out waiting for TPU to start up')
            time.sleep(interval)

        if log:
            logging.warning(
                f'TPU has started up successfully with version pytorch-{version}')

    try:
        tpu_name = os.environ.get('TPU_NAME', '')
        if not tpu_name.startswith('grpc://'):
            # Not colab/kaggle
            return

        # Imported lazily so that a missing package only produces a warning.
        import cloud_tpu_client
        client = cloud_tpu_client.Client(tpu_name)
        client.configure_tpu_version(
            f'pytorch-{__version__}', restart_type='ifNeeded')
        # client.wait_for_healthy() API doesn't work as we don't have TPU API
        # access.
        _wait_for_open(__version__)
    except ImportError:
        logging.warning((
            'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
            'installed. Ignore if not running on Colab/Kaggle TPU.'))
    except Exception:
        # This path is hit when we get throttled by the version changer
        # when we import torch_xla from xmp.spawn-ed processes. Waiting
        # silently here is a deliberate best-effort fallback.
        _wait_for_open(__version__, log=False)
355

456

557
def _setup_grpc():
@@ -33,13 +85,13 @@ def _setup_xla_flags():
3385

3486

3587
# These need to be called before the _XLAC module is loaded.
88+
_maybe_select_tpu_version()
3689
_setup_grpc()
3790
_setup_xla_flags()
3891

3992
import atexit
4093
import torch
4194
from ._patched_functions import _apply_patches
42-
from .version import __version__
4395
import _XLAC
4496

4597

0 commit comments

Comments
 (0)