Skip to content

Commit 5b677d4

Browse files
committed
Fix annotations
1 parent cd980df commit 5b677d4

File tree

8 files changed

+122
-88
lines changed

8 files changed

+122
-88
lines changed

examples/cloud/serverless_training/pytorch/data_ingestion.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
# Your S3 bucket
1616
S3_BUCKET = "your_s3_bucket"
1717

18-
IMAGES_URI = f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_images"
19-
LABELS_URI = f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_labels"
18+
IMAGES_URI = (
19+
f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_images"
20+
)
21+
LABELS_URI = (
22+
f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_labels"
23+
)
2024

2125

2226
# Let's define an ingestion function
@@ -63,7 +67,9 @@ def mnist_ingest(ingestion_func: Any) -> None:
6367
ingestion_func(data=labels, batch_size=64, uri=LABELS_URI)
6468

6569

66-
tiledb.client.configure(username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE)
70+
tiledb.client.configure(
71+
username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE
72+
)
6773
tiledb.client.login()
6874

6975
tiledb.client.udf.exec(mnist_ingest, ingestion_func=ingest_in_tiledb)

examples/cloud/serverless_training/pytorch/model_load_and_predict.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
# Your S3 bucket
1515
S3_BUCKET = "your_s3_bucket"
1616

17-
IMAGES_URI = f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_images"
18-
LABELS_URI = f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_labels"
19-
MODEL_URI = f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_model"
17+
IMAGES_URI = (
18+
f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_images"
19+
)
20+
LABELS_URI = (
21+
f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_labels"
22+
)
23+
MODEL_URI = (
24+
f"tiledb://{TILEDB_WORKSPACE}/{TILEDB_TEAMSPACE}/s3://{S3_BUCKET}/mnist_model"
25+
)
2026

2127
IO_BATCH_SIZE = 20000
2228

@@ -71,7 +77,9 @@ def forward(self, x: torch.Tensor) -> Any:
7177
return [np.argmax(pred) for pred in output.numpy()]
7278

7379

74-
tiledb.client.configure(username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE)
80+
tiledb.client.configure(
81+
username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE
82+
)
7583
tiledb.client.login()
7684

7785
predictions = tiledb.client.udf.exec(predict)

examples/cloud/serverless_training/pytorch/model_training.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
import tiledb.client
4+
45
from tiledb.ml.readers.types import ArrayParams
56

67
# Your TileDB username and password, exported as environmental variables
@@ -51,10 +52,10 @@ def forward(self, x: torch.Tensor) -> Any:
5152
logits = self.linear_relu_stack(x)
5253
return logits
5354

54-
def do_random_noise(img, mag=0.1):
55-
noise = np.random.uniform(-1, 1,img.shape)*mag
55+
def do_random_noise(img: np.ndarray, mag: float = 0.1) -> np.ndarray:
56+
noise = np.random.uniform(-1, 1, img.shape) * mag
5657
img = img + noise
57-
img = np.clip(img,0,1)
58+
img = np.clip(img, 0, 1)
5859
return img
5960

6061
with tiledb.open(IMAGES_URI) as x, tiledb.open(LABELS_URI) as y:
@@ -118,7 +119,9 @@ def do_random_noise(img, mag=0.1):
118119
model.save()
119120

120121

121-
tiledb.client.configure(username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE)
122+
tiledb.client.configure(
123+
username=TILEDB_USER_NAME, password=TILEDB_PASSWD, workspace=TILEDB_WORKSPACE
124+
)
122125
tiledb.client.login()
123126

124127
tiledb.client.udf.exec(train)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ test=pytest
4141
[flake8]
4242
statistics = true
4343
exclude = .git
44-
ignore = E203, E501, W503, B950
44+
ignore = E203, E231, E501, E713, W503, B950
4545
select = B,C,E,F,W,T4,B9
4646
per-file-ignores =
4747
__init__.py: F401, F403

tests/models/test_cloud_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import pytest
1+
from tiledb.ml.models._cloud_utils import update_file_properties
22

3-
from tiledb.ml.models._cloud_utils import (
4-
update_file_properties,
5-
)
63

74
class TestCloudUtils:
8-
95
def test_update_file_properties(self, mocker):
106
mock_tiledb_cloud_update_file_properties = mocker.patch(
117
"tiledb.client.array.update_file_properties"

tests/models/test_tensorflow_keras_models.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
import platform
88
import shutil
99

10+
import keras
1011
import numpy as np
1112
import pytest
1213
import tensorflow as tf
13-
import keras
1414

1515
import tiledb
1616
from tiledb.ml import __version__ as tiledb_ml_version
1717
from tiledb.ml.models import SHORT_PREVIEW_LIMIT
1818
from tiledb.ml.models.tensorflow_keras import TensorflowKerasTileDBModel
1919

2020
# Detect Keras version for conditional test behavior
21-
keras_version = tuple(int(x) for x in keras.__version__.split('.')[:2])
21+
keras_version = tuple(int(x) for x in keras.__version__.split(".")[:2])
2222
KERAS_3_OR_HIGHER = keras_version[0] >= 3
2323

2424
try:
@@ -35,10 +35,7 @@
3535
except ImportError:
3636
try:
3737
# For Keras 3.x (TensorFlow >=2.16), use the public API
38-
from keras.testing import (
39-
get_small_functional_mlp,
40-
get_small_sequential_mlp,
41-
)
38+
from keras.testing import get_small_functional_mlp, get_small_sequential_mlp
4239
except ImportError:
4340
try:
4441
# Fallback for older versions with keras.src
@@ -50,21 +47,26 @@
5047
# Final fallback: implement these functions ourselves for Keras 3.x
5148
def get_small_sequential_mlp(num_hidden, num_classes, input_dim):
5249
"""Create a small sequential MLP for testing."""
53-
model = tf.keras.Sequential([
54-
tf.keras.layers.Input(shape=(input_dim,)),
55-
tf.keras.layers.Dense(num_hidden, activation='relu'),
56-
tf.keras.layers.Dense(num_classes, activation='softmax')
57-
])
50+
model = tf.keras.Sequential(
51+
[
52+
tf.keras.layers.Input(shape=(input_dim,)),
53+
tf.keras.layers.Dense(num_hidden, activation="relu"),
54+
tf.keras.layers.Dense(num_classes, activation="softmax"),
55+
]
56+
)
5857
return model
59-
58+
6059
def get_small_functional_mlp(num_hidden, num_classes, input_dim):
6160
"""Create a small functional MLP for testing."""
6261
inputs = tf.keras.Input(shape=(input_dim,))
63-
x = tf.keras.layers.Dense(num_hidden, activation='relu')(inputs)
64-
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
62+
x = tf.keras.layers.Dense(num_hidden, activation="relu")(inputs)
63+
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(
64+
x
65+
)
6566
model = tf.keras.Model(inputs=inputs, outputs=outputs)
6667
return model
6768

69+
6870
# Suppress all Tensorflow messages
6971
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7072
batch_get_value = tf.keras.backend.batch_get_value
@@ -372,8 +374,10 @@ def test_save_load_with_sequence_features(self, tmpdir, loss, optimizer, metrics
372374

373375
def test_functional_model_save_load_with_custom_loss_and_metric(self, tmpdir):
374376
if KERAS_3_OR_HIGHER:
375-
pytest.skip("Custom loss/metric with Lambda layers not fully supported in Keras 3.x serialization")
376-
377+
pytest.skip(
378+
"Custom loss/metric with Lambda layers not fully supported in Keras 3.x serialization"
379+
)
380+
377381
inputs = tf.keras.layers.Input(shape=(4,))
378382
x = tf.keras.layers.Dense(8, activation="relu")(inputs)
379383
outputs = tf.keras.layers.Dense(3, activation="softmax")(x)
@@ -441,7 +445,7 @@ def test_sequential_model_save_load_without_input_shape(self, tmpdir):
441445
model.add(tf.keras.layers.Dense(2))
442446
model.add(tf.keras.layers.RepeatVector(3))
443447
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
444-
448+
445449
# sample_weight_mode was removed in Keras 3.x
446450
if KERAS_3_OR_HIGHER:
447451
model.compile(

tiledb/ml/models/_cloud_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from typing import Mapping, Optional
32

43
try:
@@ -37,7 +36,7 @@
3736
# s3_path = profile.default_s3_path
3837
# else:
3938
# raise
40-
39+
4140
# return os.path.join(s3_path, CLOUD_MODELS) if s3_path is not None else None
4241

4342
# # if namespace is None:

0 commit comments

Comments
 (0)