Skip to content

Commit 7c3b27f

Browse files
committed
Add type annotations and the ruff check for ANN
1 parent d40ed5c commit 7c3b27f

File tree

3 files changed

+61
-54
lines changed

3 files changed

+61
-54
lines changed

dbldatagen/spark_singleton.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SparkSingleton:
1919
"""A singleton class which returns one Spark session instance"""
2020

2121
@classmethod
22-
def getInstance(cls) -> SparkSession:
22+
def getInstance(cls: type["SparkSingleton"]) -> SparkSession:
2323
"""Creates a `SparkSession` instance for Datalib.
2424
2525
:returns: A Spark instance
@@ -28,7 +28,7 @@ def getInstance(cls) -> SparkSession:
2828
return SparkSession.builder.getOrCreate()
2929

3030
@classmethod
31-
def getLocalInstance(cls, appName: str = "new Spark session", useAllCores: bool = True) -> SparkSession:
31+
def getLocalInstance(cls: type["SparkSingleton"], appName: str = "new Spark session", useAllCores: bool = True) -> SparkSession:
3232
"""Creates a machine local `SparkSession` instance for Datalib.
3333
By default, it uses `n-1` cores of the available cores for the spark session,
3434
where `n` is total cores available.

dbldatagen/utils.py

Lines changed: 58 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
import re
1414
import time
1515
import warnings
16+
from collections.abc import Callable
1617
from datetime import timedelta
18+
from typing import Any, Optional, Union
1719

1820
import jmespath
1921

2022

21-
def deprecated(message=""):
23+
def deprecated(message: str = "") -> Callable[[Callable[..., Any]], Callable[..., Any]]:
2224
"""
2325
Define a deprecated decorator without dependencies on 3rd party libraries
2426
@@ -27,9 +29,9 @@ def deprecated(message=""):
2729
"""
2830

2931
# create closure around function that follows use of the decorator
30-
def deprecated_decorator(func):
32+
def deprecated_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
3133
@functools.wraps(func)
32-
def deprecated_func(*args, **kwargs):
34+
def deprecated_func(*args: object, **kwargs: object) -> object:
3335
warnings.warn(f"`{func.__name__}` is a deprecated function or method. \n{message}",
3436
category=DeprecationWarning, stacklevel=1)
3537
warnings.simplefilter("default", DeprecationWarning)
@@ -47,21 +49,21 @@ class DataGenError(Exception):
4749
:param baseException: underlying exception, if any that caused the issue
4850
"""
4951

50-
def __init__(self, msg, baseException=None):
52+
def __init__(self: "DataGenError", msg: str, baseException: Optional[Exception] = None) -> None:
5153
""" constructor
5254
"""
5355
super().__init__(msg)
54-
self._underlyingException = baseException
55-
self._msg = msg
56+
self._underlyingException: Optional[Exception] = baseException
57+
self._msg: str = msg
5658

57-
def __repr__(self):
59+
def __repr__(self: "DataGenError") -> str:
5860
return f"DataGenError(msg='{self._msg}', baseException={self._underlyingException})"
5961

60-
def __str__(self):
62+
def __str__(self: "DataGenError") -> str:
6163
return f"DataGenError(msg='{self._msg}', baseException={self._underlyingException})"
6264

6365

64-
def coalesce_values(*args):
66+
def coalesce_values(*args: object) -> Optional[object]:
6567
"""For a supplied list of arguments, returns the first argument that does not have the value `None`
6668
6769
:param args: variable list of arguments which are evaluated
@@ -73,7 +75,7 @@ def coalesce_values(*args):
7375
return None
7476

7577

76-
def ensure(cond, msg="condition does not hold true"):
78+
def ensure(cond: bool, msg: str = "condition does not hold true") -> None:
7779
"""ensure(cond, s) => throws Exception(s) if c is not true
7880
7981
:param cond: condition to test
@@ -82,14 +84,14 @@ def ensure(cond, msg="condition does not hold true"):
8284
:returns: Does not return anything but raises exception if condition does not hold
8385
"""
8486

85-
def strip_margin(text):
87+
def strip_margin(text: str) -> str:
8688
return re.sub(r"\n[ \t]*\|", "\n", text)
8789

8890
if not cond:
8991
raise DataGenError(strip_margin(msg))
9092

9193

92-
def mkBoundsList(x, default):
94+
def mkBoundsList(x: Optional[Union[int, list[int]]], default: Union[int, list[int]]) -> tuple[bool, list[int]]:
9395
""" make a bounds list from supplied parameter - otherwise use default
9496
9597
:param x: integer or list of 2 values that define bounds list
@@ -100,16 +102,20 @@ def mkBoundsList(x, default):
100102
retval = (True, [default, default]) if isinstance(default, int) else (True, list(default))
101103
return retval
102104
elif isinstance(x, int):
103-
bounds_list = [x, x]
105+
bounds_list: list[int] = [x, x]
104106
assert len(bounds_list) == 2, "bounds list must be of length 2"
105107
return False, bounds_list
106108
else:
107-
bounds_list = list(x)
109+
bounds_list: list[int] = list(x)
108110
assert len(bounds_list) == 2, "bounds list must be of length 2"
109111
return False, bounds_list
110112

111113

112-
def topologicalSort(sources, initial_columns=None, flatten=True):
114+
def topologicalSort(
115+
sources: list[tuple[str, set[str]]],
116+
initial_columns: Optional[list[str]] = None,
117+
flatten: bool = True
118+
) -> Union[list[str], list[list[str]]]:
113119
""" Perform a topological sort over sources
114120
115121
Used to compute the column test data generation order of the column generation dependencies.
@@ -129,16 +135,16 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
129135
Overall the effect is that the input build order should be retained unless there are forward references
130136
"""
131137
# generate a copy so that we can modify in place
132-
pending = [(name, set(deps)) for name, deps in sources]
133-
provided = [] if initial_columns is None else initial_columns[:]
134-
build_orders = [] if initial_columns is None else [initial_columns]
138+
pending: list[tuple[str, set[str]]] = [(name, set(deps)) for name, deps in sources]
139+
provided: list[str] = [] if initial_columns is None else initial_columns[:]
140+
build_orders: list[list[str]] = [] if initial_columns is None else [initial_columns]
135141

136142
while pending:
137-
next_pending = []
138-
gen = []
139-
value_emitted = False
140-
defer_emitted = False
141-
gen_provided = []
143+
next_pending: list[tuple[str, set[str]]] = []
144+
gen: list[str] = []
145+
value_emitted: bool = False
146+
defer_emitted: bool = False
147+
gen_provided: list[str] = []
142148
for entry in pending:
143149
name, deps = entry
144150
deps.difference_update(provided)
@@ -165,7 +171,7 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
165171
pending = next_pending
166172

167173
if flatten:
168-
flattened_list = [item for sublist in build_orders for item in sublist]
174+
flattened_list: list[str] = [item for sublist in build_orders for item in sublist]
169175
return flattened_list
170176
else:
171177
return build_orders
@@ -176,31 +182,31 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
176182
_WEEKS_PER_YEAR = 52
177183

178184

179-
def parse_time_interval(spec):
185+
def parse_time_interval(spec: str) -> timedelta:
180186
"""parse time interval from string"""
181-
hours = 0
182-
minutes = 0
183-
weeks = 0
184-
microseconds = 0
185-
milliseconds = 0
186-
seconds = 0
187-
years = 0
188-
days = 0
187+
hours: int = 0
188+
minutes: int = 0
189+
weeks: int = 0
190+
microseconds: int = 0
191+
milliseconds: int = 0
192+
seconds: int = 0
193+
years: int = 0
194+
days: int = 0
189195

190196
assert spec is not None, "Must have valid time interval specification"
191197

192198
# get time specs such as 12 days, etc. Supported timespans are years, days, hours, minutes, seconds
193-
timespecs = [x.strip() for x in spec.strip().split(",")]
199+
timespecs: list[str] = [x.strip() for x in spec.strip().split(",")]
194200

195201
for ts in timespecs:
196202
# allow both 'days=1' and '1 day' syntax
197-
timespec_parts = re.findall(PATTERN_NAME_EQUALS_VALUE, ts)
203+
timespec_parts: list[tuple[str, str]] = re.findall(PATTERN_NAME_EQUALS_VALUE, ts)
198204
# findall returns list of tuples
199205
if timespec_parts is not None and len(timespec_parts) > 0:
200-
num_parts = len(timespec_parts[0])
206+
num_parts: int = len(timespec_parts[0])
201207
assert num_parts >= 1, "must have numeric specification and time element such as `12 hours` or `hours=12`"
202-
time_value = int(timespec_parts[0][num_parts - 1])
203-
time_type = timespec_parts[0][0].lower()
208+
time_value: int = int(timespec_parts[0][num_parts - 1])
209+
time_type: str = timespec_parts[0][0].lower()
204210
else:
205211
timespec_parts = re.findall(PATTERN_VALUE_SPACE_NAME, ts)
206212
num_parts = len(timespec_parts[0])
@@ -225,7 +231,7 @@ def parse_time_interval(spec):
225231
elif time_type in ["milliseconds", "millisecond"]:
226232
milliseconds = time_value
227233

228-
delta = timedelta(
234+
delta: timedelta = timedelta(
229235
days=days,
230236
seconds=seconds,
231237
microseconds=microseconds,
@@ -238,7 +244,7 @@ def parse_time_interval(spec):
238244
return delta
239245

240246

241-
def strip_margins(s, marginChar):
247+
def strip_margins(s: str, marginChar: str) -> str:
242248
"""
243249
Python equivalent of Scala stripMargins method
244250
Takes a string (potentially multiline) and strips all chars up and including the first occurrence of `marginChar`.
@@ -258,20 +264,20 @@ def strip_margins(s, marginChar):
258264
assert s is not None and isinstance(s, str)
259265
assert marginChar is not None and isinstance(marginChar, str)
260266

261-
lines = s.split("\n")
262-
revised_lines = []
267+
lines: list[str] = s.split("\n")
268+
revised_lines: list[str] = []
263269

264270
for line in lines:
265271
if marginChar in line:
266-
revised_line = line[line.index(marginChar) + 1:]
272+
revised_line: str = line[line.index(marginChar) + 1:]
267273
revised_lines.append(revised_line)
268274
else:
269275
revised_lines.append(line)
270276

271277
return "\n".join(revised_lines)
272278

273279

274-
def split_list_matching_condition(lst, cond):
280+
def split_list_matching_condition(lst: list[Any], cond: Callable[[Any], bool]) -> list[list[Any]]:
275281
"""
276282
Split a list on elements that match a condition
277283
@@ -293,9 +299,9 @@ def split_list_matching_condition(lst, cond):
293299
:arg cond: lambda function or function taking single argument and returning True or False
294300
:returns: list of sublists
295301
"""
296-
retval = []
302+
retval: list[list[Any]] = []
297303

298-
def match_condition(matchList, matchFn):
304+
def match_condition(matchList: list[Any], matchFn: Callable[[Any], bool]) -> int:
299305
"""Return first index of element of list matching condition"""
300306
if matchList is None or len(matchList) == 0:
301307
return -1
@@ -311,7 +317,7 @@ def match_condition(matchList, matchFn):
311317
elif len(lst) == 1:
312318
retval = [lst]
313319
else:
314-
ix = match_condition(lst, cond)
320+
ix: int = match_condition(lst, cond)
315321
if ix != -1:
316322
retval.extend(split_list_matching_condition(lst[0:ix], cond))
317323
retval.append(lst[ix:ix + 1])
@@ -323,7 +329,7 @@ def match_condition(matchList, matchFn):
323329
return [el for el in retval if el != []]
324330

325331

326-
def json_value_from_path(searchPath, jsonData, defaultValue):
332+
def json_value_from_path(searchPath: str, jsonData: str, defaultValue: object) -> object:
327333
""" Get JSON value from JSON data referenced by searchPath
328334
329335
searchPath should be a JSON path as supported by the `jmespath` package
@@ -337,20 +343,20 @@ def json_value_from_path(searchPath, jsonData, defaultValue):
337343
assert searchPath is not None and len(searchPath) > 0, "search path cannot be empty"
338344
assert jsonData is not None and len(jsonData) > 0, "JSON data cannot be empty"
339345

340-
jsonDict = json.loads(jsonData)
346+
jsonDict: dict = json.loads(jsonData)
341347

342-
jsonValue = jmespath.search(searchPath, jsonDict)
348+
jsonValue: Any = jmespath.search(searchPath, jsonDict)
343349

344350
if jsonValue is not None:
345351
return jsonValue
346352

347353
return defaultValue
348354

349355

350-
def system_time_millis():
356+
def system_time_millis() -> int:
351357
""" return system time as milliseconds since start of epoch
352358
353359
:return: system time millis as long
354360
"""
355-
curr_time = round(time.time() / 1000)
361+
curr_time: int = round(time.time() / 1000)
356362
return curr_time

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ fmt = ["ruff check . --fix",
116116
"mypy .",
117117
"pylint --output-format=colorized -j 0 dbldatagen tests"]
118118
verify = ["ruff check .",
119+
"ruff check . --select ANN",
119120
"mypy .",
120121
"pylint --output-format=colorized -j 0 dbldatagen tests"]
121122

0 commit comments

Comments
 (0)