|
| 1 | +from delira._debug_mode import get_current_debug_mode, switch_debug_mode, \ |
| 2 | + set_debug_mode |
| 3 | +from delira._backends import get_backends |
| 4 | + |
1 | 5 | from ._version import get_versions |
2 | 6 | import json |
3 | 7 | import os |
4 | 8 | import warnings |
5 | 9 | warnings.simplefilter('default', DeprecationWarning) |
6 | 10 | warnings.simplefilter('ignore', ImportWarning) |
7 | 11 |
|
8 | | -# to register new possible backends, they have to be added to this list. |
9 | | -# each backend should consist of a tuple of length 2 with the first entry |
10 | | -# being the package import name and the second being the backend abbreviation. |
11 | | -# E.g. TensorFlow's package is named 'tensorflow' but if the package is found, |
12 | | -# it will be considered as 'tf' later on |
13 | | -__POSSIBLE_BACKENDS = [("torch", "torch"), ("tensorflow", "tf")] |
14 | | -__BACKENDS = [] |
15 | | - |
16 | | -__DEBUG_MODE = False |
17 | | - |
18 | | - |
19 | | -def _determine_backends(): |
20 | | - |
21 | | - _config_file = __file__.replace("__init__.py", ".delira") |
22 | | - # look for config file to determine backend |
23 | | - # if file exists: load config into environment variables |
24 | | - |
25 | | - if not os.path.isfile(_config_file): |
26 | | - _backends = {} |
27 | | - # try to import all possible backends to determine valid backends |
28 | | - |
29 | | - import importlib |
30 | | - for curr_backend in __POSSIBLE_BACKENDS: |
31 | | - try: |
32 | | - assert len(curr_backend) == 2 |
33 | | - assert all([isinstance(_tmp, str) for _tmp in curr_backend]), \ |
34 | | - "All entries in current backend must be strings" |
35 | | - |
36 | | - # check if backend can be imported |
37 | | - bcknd = importlib.util.find_spec(curr_backend[0]) |
38 | | - |
39 | | - if bcknd is not None: |
40 | | - _backends[curr_backend[1]] = True |
41 | | - else: |
42 | | - _backends[curr_backend[1]] = False |
43 | | - del bcknd |
44 | | - |
45 | | - except ValueError: |
46 | | - _backends[curr_backend[1]] = False |
47 | | - |
48 | | - with open(_config_file, "w") as f: |
49 | | - json.dump({"version": __version__, "backend": _backends}, |
50 | | - f, sort_keys=True, indent=4) |
51 | | - |
52 | | - del _backends |
53 | | - |
54 | | - # set values from config file to variable |
55 | | - with open(_config_file) as f: |
56 | | - _config_dict = json.load(f) |
57 | | - for key, val in _config_dict.pop("backend").items(): |
58 | | - if val: |
59 | | - __BACKENDS.append(key.upper()) |
60 | | - del _config_dict |
61 | | - |
62 | | - del _config_file |
63 | | - |
64 | | - |
65 | | -def get_backends(): |
66 | | - """ |
67 | | - Return List of currently available backends |
68 | | -
|
69 | | - Returns |
70 | | - ------- |
71 | | - list |
72 | | - list of strings containing the currently installed backends |
73 | | -
|
74 | | - """ |
75 | | - |
76 | | - if not __BACKENDS: |
77 | | - _determine_backends() |
78 | | - return __BACKENDS |
79 | | - |
80 | | - |
81 | | -# Functions to get and set the internal __DEBUG_MODE variable. This variable |
82 | | -# currently only defines whether to use multiprocessing or not. At the moment |
83 | | -# this is only used inside the BaseDataManager, which either returns a |
84 | | -# MultiThreadedAugmenter or a SingleThreadedAugmenter depending on the current |
85 | | -# debug mode. |
86 | | -# All other functions using multiprocessing should be aware of this and |
87 | | -# implement a functionality without multiprocessing |
88 | | -# (even if this slows down things a lot!). |
89 | | - |
90 | | -def get_current_debug_mode(): |
91 | | - """ |
92 | | - Getter function for the current debug mode |
93 | | -
|
94 | | - Returns |
95 | | - ------- |
96 | | - bool |
97 | | - current debug mode |
98 | | -
|
99 | | - """ |
100 | | - return __DEBUG_MODE |
101 | | - |
102 | | - |
103 | | -def switch_debug_mode(): |
104 | | - """ |
105 | | - Alternates the current debug mode |
106 | | -
|
107 | | - """ |
108 | | - set_debug_mode(not get_current_debug_mode()) |
109 | | - |
110 | | - |
111 | | -def set_debug_mode(mode: bool): |
112 | | - """ |
113 | | - Sets a new debug mode |
114 | | -
|
115 | | - Parameters |
116 | | - ---------- |
117 | | - mode : bool |
118 | | - the new debug mode |
119 | | -
|
120 | | - """ |
121 | | - global __DEBUG_MODE |
122 | | - __DEBUG_MODE = mode |
123 | | - |
124 | | - |
125 | 12 | __version__ = get_versions()['version'] |
126 | 13 | del get_versions |
0 commit comments