Skip to content

Commit 160fcd3

Browse files
authored
freeze model info and validate metadata entries (bentoml#2363)
* fix(internal): prevent updating Bento.info
* feat(internal): add validate_labels and validate_metadata
* feat(internal): allow LazyType to take tuples of types
* feat(internal): freeze model info on save
* chore(internal): format code
* chore(internal): fix typing errors
1 parent a9904e7 commit 160fcd3

File tree

7 files changed: +311 -82 lines changed

bentoml/_internal/bento/bento.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,11 @@ class Bento(StoreItem):
9696
_model_store: ModelStore
9797
_doc: t.Optional[str] = None
9898

99-
_flushed: bool = False
100-
10199
@staticmethod
102100
def _export_ext() -> str:
103101
return "bento"
104102

105-
@__fs.validator # type: ignore
103+
@__fs.validator # type:ignore (attrs validators not supported by pyright)
106104
def check_fs(self, _attr: t.Any, new_fs: "FS"):
107105
try:
108106
new_fs.makedir("models", recreate=True)
@@ -116,7 +114,7 @@ def __init__(self, tag: Tag, bento_fs: "FS", info: "BentoInfo"):
116114
self._tag = tag
117115
self.__fs = bento_fs
118116
self.check_fs(None, bento_fs)
119-
self.info = info
117+
self._info = info
120118
self.validate()
121119

122120
@property
@@ -129,14 +127,8 @@ def _fs(self) -> "FS":
129127

130128
@property
131129
def info(self) -> "BentoInfo":
132-
self._flushed = False
133130
return self._info
134131

135-
@info.setter
136-
def info(self, new_info: "BentoInfo"):
137-
self._flushed = False
138-
self._info = new_info
139-
140132
@classmethod
141133
@inject
142134
def create(
@@ -300,7 +292,6 @@ def from_fs(cls, item_fs: "FS") -> "Bento":
300292
f"Failed to create bento because it contains an invalid '{BENTO_YAML_FILENAME}'"
301293
)
302294

303-
res._flushed = True
304295
return res
305296

306297
@property
@@ -311,14 +302,9 @@ def path_of(self, item: str) -> str:
311302
return self._fs.getsyspath(item)
312303

313304
def flush_info(self):
314-
if self._flushed:
315-
return
316-
317305
with self._fs.open(BENTO_YAML_FILENAME, "w") as bento_yaml:
318306
self.info.dump(bento_yaml)
319307

320-
self._flushed = True
321-
322308
@property
323309
def doc(self) -> str:
324310
if self._doc is not None:
@@ -337,8 +323,6 @@ def save(
337323
self,
338324
bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
339325
) -> "Bento":
340-
self.flush_info()
341-
342326
if not self.validate():
343327
logger.warning(f"Failed to create Bento for {self.tag}, not saving.")
344328
raise BentoMLException("Failed to save Bento because it was invalid.")
@@ -351,28 +335,6 @@ def save(
351335

352336
return self
353337

354-
def export(
355-
self,
356-
path: str,
357-
output_format: t.Optional[str] = None,
358-
*,
359-
protocol: t.Optional[str] = None,
360-
user: t.Optional[str] = None,
361-
passwd: t.Optional[str] = None,
362-
params: t.Optional[t.Dict[str, str]] = None,
363-
subpath: t.Optional[str] = None,
364-
) -> str:
365-
self.flush_info()
366-
return super().export(
367-
path,
368-
output_format,
369-
protocol=protocol,
370-
user=user,
371-
passwd=passwd,
372-
params=params,
373-
subpath=subpath,
374-
)
375-
376338
def validate(self):
377339
return self._fs.isfile(BENTO_YAML_FILENAME)
378340

@@ -445,8 +407,6 @@ class BentoInfo:
445407
runners: t.List[BentoRunnerInfo] = attr.field(factory=list)
446408
apis: t.List[BentoApiInfo] = attr.field(factory=list)
447409

448-
_flushed: bool = False
449-
450410
def __attrs_post_init__(self):
451411
# Direct set is not available when frozen=True
452412
object.__setattr__(self, "name", self.tag.name)
@@ -501,11 +461,8 @@ def validate(self):
501461

502462
bentoml_cattr.register_unstructure_hook(
503463
BentoInfo,
504-
# Ignore internal private state "_flushed"
505464
# Ignore tag, tag is saved via the name and version field
506-
make_dict_unstructure_fn(
507-
BentoInfo, bentoml_cattr, _flushed=override(omit=True), tag=override(omit=True)
508-
),
465+
make_dict_unstructure_fn(BentoInfo, bentoml_cattr, tag=override(omit=True)),
509466
)
510467

511468

bentoml/_internal/models/model.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from ..tag import Tag
1919
from ..store import Store
2020
from ..store import StoreItem
21+
from ..utils import validate_labels
22+
from ..utils import validate_metadata
2123
from ...exceptions import NotFound
2224
from ...exceptions import BentoMLException
2325
from ..configuration import BENTOML_VERSION
@@ -42,8 +44,7 @@ class Model(StoreItem):
4244
_info: "ModelInfo"
4345
_custom_objects: t.Optional[t.Dict[str, t.Any]] = None
4446

45-
_info_flushed = False
46-
_custom_objects_flushed = False
47+
_flushed: bool = False
4748

4849
@staticmethod
4950
def _export_ext() -> str:
@@ -59,24 +60,20 @@ def _fs(self) -> "FS":
5960

6061
@property
6162
def info(self) -> "ModelInfo":
62-
self._info_flushed = False
6363
return self._info
6464

65-
@info.setter
66-
def info(self, new_info: "ModelInfo"):
67-
self._info_flushed = False
68-
self._info = new_info
69-
7065
@property
7166
def custom_objects(self) -> t.Dict[str, t.Any]:
72-
self._custom_objects_flushed = False
73-
7467
if self._custom_objects is None:
7568
if self._fs.isfile(CUSTOM_OBJECTS_FILENAME):
7669
with self._fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
77-
self._custom_objects = cloudpickle.load(cofile)
70+
self._custom_objects: "t.Optional[t.Dict[str, t.Any]]" = (
71+
cloudpickle.load(cofile)
72+
)
73+
if not isinstance(self._custom_objects, dict):
74+
raise ValueError("Invalid custom objects found.")
7875
else:
79-
self._custom_objects = {}
76+
self._custom_objects: "t.Optional[t.Dict[str, t.Any]]" = {}
8077

8178
return self._custom_objects
8279

@@ -128,6 +125,9 @@ def create(
128125
context["bentoml_version"] = BENTOML_VERSION
129126
context["python_version"] = PYTHON_VERSION
130127

128+
validate_labels(labels)
129+
validate_metadata(metadata)
130+
131131
model_fs = fs.open_fs(f"temp://bentoml_model_{name}")
132132

133133
res = Model(
@@ -155,8 +155,7 @@ def save(
155155
return self
156156

157157
def _save(self, model_store: "ModelStore") -> "Model":
158-
self.flush_info()
159-
self.flush_custom_objects()
158+
self.flush()
160159

161160
if not self.validate():
162161
logger.warning(f"Failed to create Model for {self.tag}, not saving.")
@@ -187,8 +186,7 @@ def from_fs(cls, item_fs: FS) -> "Model":
187186
f"Failed to load bento model because it contains an invalid '{MODEL_YAML_FILENAME}'"
188187
)
189188

190-
res._info_flushed = True
191-
res._custom_objects_flushed = True
189+
res._flushed = True
192190
return res
193191

194192
@property
@@ -198,26 +196,22 @@ def path(self) -> str:
198196
def path_of(self, item: str) -> str:
199197
return self._fs.getsyspath(item)
200198

201-
def flush_info(self):
202-
if self._info_flushed:
203-
return
199+
def flush(self):
200+
if not self._flushed:
201+
self._flush_info()
202+
self._flush_custom_objects()
203+
self._flushed = True
204204

205+
def _flush_info(self):
205206
with self._fs.open(MODEL_YAML_FILENAME, "w") as model_yaml:
206207
self.info.dump(model_yaml)
207208

208-
self._info_flushed = True
209-
210-
def flush_custom_objects(self):
211-
if self._custom_objects_flushed:
212-
return
213-
209+
def _flush_custom_objects(self):
214210
# pickle custom_objects if it is not None and not empty
215211
if self.custom_objects:
216212
with self._fs.open(CUSTOM_OBJECTS_FILENAME, "wb") as cofile:
217213
cloudpickle.dump(self.custom_objects, cofile)
218214

219-
self._custom_objects_flushed = True
220-
221215
@property
222216
def creation_time(self) -> datetime:
223217
return self.info.creation_time
@@ -233,8 +227,8 @@ def export(
233227
params: t.Optional[t.Dict[str, str]] = None,
234228
subpath: t.Optional[str] = None,
235229
) -> str:
236-
self.flush_info()
237-
self.flush_custom_objects()
230+
self.flush()
231+
238232
return super().export(
239233
path,
240234
output_format,
@@ -264,14 +258,30 @@ def __init__(self, base_path: "t.Union[PathType, FS]"):
264258
class ModelInfo:
265259
tag: Tag
266260
module: str
267-
labels: t.Dict[str, t.Any]
261+
labels: t.Dict[str, str]
268262
options: t.Dict[str, t.Any]
269263
metadata: t.Dict[str, t.Any]
270264
context: t.Dict[str, t.Any]
271265
bentoml_version: str = BENTOML_VERSION
272266
api_version: str = "v1"
273267
creation_time: datetime = attr.field(factory=lambda: datetime.now(timezone.utc))
274268

269+
def __eq__(self, other: "t.Union[ModelInfo, FrozenModelInfo]"):
270+
if not isinstance(other, (ModelInfo, FrozenModelInfo)):
271+
return False
272+
273+
return (
274+
self.tag == other.tag
275+
and self.module == other.module
276+
and self.labels == other.labels
277+
and self.options == other.options
278+
and self.metadata == other.metadata
279+
and self.context == other.context
280+
and self.bentoml_version == other.bentoml_version
281+
and self.api_version == other.api_version
282+
and self.creation_time == other.creation_time
283+
)
284+
275285
def __attrs_post_init__(self):
276286
self.validate()
277287

@@ -292,8 +302,8 @@ def to_dict(self) -> t.Dict[str, t.Any]:
292302
def dump(self, stream: t.IO[t.Any]):
293303
return yaml.dump(self, stream, sort_keys=False)
294304

295-
@classmethod
296-
def from_yaml_file(cls, stream: t.IO[t.Any]):
305+
@staticmethod
306+
def from_yaml_file(stream: t.IO[t.Any]):
297307
try:
298308
yaml_content = yaml.safe_load(stream)
299309
except yaml.YAMLError as exc: # pragma: no cover - simple error handling
@@ -311,7 +321,7 @@ def from_yaml_file(cls, stream: t.IO[t.Any]):
311321
del yaml_content["version"]
312322

313323
try:
314-
model_info = cls(**yaml_content) # type: ignore
324+
model_info = FrozenModelInfo(**yaml_content) # type: ignore
315325
except TypeError: # pragma: no cover - simple error handling
316326
raise BentoMLException(f"unexpected field in {MODEL_YAML_FILENAME}")
317327
return model_info
@@ -321,6 +331,15 @@ def validate(self):
321331
# add tests when implemented
322332
...
323333

334+
def freeze(self) -> "ModelInfo":
335+
self.__class__ = FrozenModelInfo
336+
return self
337+
338+
339+
@attr.define(repr=False, frozen=True) # type: ignore (pyright doesn't allow for a frozen subclass)
340+
class FrozenModelInfo(ModelInfo):
341+
pass
342+
324343

325344
def copy_model(
326345
model_tag: t.Union[Tag, str],
@@ -346,3 +365,4 @@ def _ModelInfo_dumper(dumper: yaml.Dumper, info: ModelInfo) -> yaml.Node:
346365

347366

348367
yaml.add_representer(ModelInfo, _ModelInfo_dumper)
368+
yaml.add_representer(FrozenModelInfo, _ModelInfo_dumper)

bentoml/_internal/types.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import urllib.parse
88
import urllib.request
99
from typing import TYPE_CHECKING
10+
from datetime import date
11+
from datetime import time
12+
from datetime import datetime
13+
from datetime import timedelta
1014
from dataclasses import dataclass
1115

1216
from .utils.dataclasses import json_serializer
@@ -31,6 +35,24 @@
3135
else:
3236
PathType = t.Union[str, os.PathLike]
3337

38+
MetadataType = t.Union[
39+
str,
40+
bytes,
41+
bool,
42+
int,
43+
float,
44+
complex,
45+
datetime,
46+
date,
47+
time,
48+
timedelta,
49+
t.List["MetadataType"],
50+
t.Tuple["MetadataType"],
51+
t.Dict[str, "MetadataType"],
52+
]
53+
54+
MetadataDict = t.Dict[str, MetadataType]
55+
3456
JSONSerializable = t.NewType("JSONSerializable", object)
3557

3658

@@ -135,7 +157,9 @@ def __hash__(self) -> int:
135157
def __repr__(self) -> str:
136158
return f'LazyType("{self.module}", "{self.qualname}")'
137159

138-
def get_class(self, import_module: bool = True) -> "t.Type[T]":
160+
def get_class(
161+
self, import_module: bool = True
162+
) -> "t.Union[t.Type[T], t.Tuple[t.Type[T]]]":
139163
if self._runtime_class is None:
140164
try:
141165
m = sys.modules[self.module]
@@ -147,6 +171,10 @@ def get_class(self, import_module: bool = True) -> "t.Type[T]":
147171
else:
148172
raise ValueError(f"Module {self.module} not imported")
149173

174+
if isinstance(self.qualname, tuple):
175+
self._runtime_class = tuple(
176+
(t.cast("t.Type[T]", getattr(m, x)) for x in self.qualname)
177+
)
150178
self._runtime_class = t.cast("t.Type[T]", getattr(m, self.qualname))
151179

152180
return self._runtime_class

0 commit comments

Comments
 (0)