18
18
from ..tag import Tag
19
19
from ..store import Store
20
20
from ..store import StoreItem
21
+ from ..utils import validate_labels
22
+ from ..utils import validate_metadata
21
23
from ...exceptions import NotFound
22
24
from ...exceptions import BentoMLException
23
25
from ..configuration import BENTOML_VERSION
@@ -42,8 +44,7 @@ class Model(StoreItem):
42
44
_info : "ModelInfo"
43
45
_custom_objects : t .Optional [t .Dict [str , t .Any ]] = None
44
46
45
- _info_flushed = False
46
- _custom_objects_flushed = False
47
+ _flushed : bool = False
47
48
48
49
@staticmethod
49
50
def _export_ext () -> str :
@@ -59,24 +60,20 @@ def _fs(self) -> "FS":
59
60
60
61
@property
61
62
def info (self ) -> "ModelInfo" :
62
- self ._info_flushed = False
63
63
return self ._info
64
64
65
- @info .setter
66
- def info (self , new_info : "ModelInfo" ):
67
- self ._info_flushed = False
68
- self ._info = new_info
69
-
70
65
@property
71
66
def custom_objects (self ) -> t .Dict [str , t .Any ]:
72
- self ._custom_objects_flushed = False
73
-
74
67
if self ._custom_objects is None :
75
68
if self ._fs .isfile (CUSTOM_OBJECTS_FILENAME ):
76
69
with self ._fs .open (CUSTOM_OBJECTS_FILENAME , "rb" ) as cofile :
77
- self ._custom_objects = cloudpickle .load (cofile )
70
+ self ._custom_objects : "t.Optional[t.Dict[str, t.Any]]" = (
71
+ cloudpickle .load (cofile )
72
+ )
73
+ if not isinstance (self ._custom_objects , dict ):
74
+ raise ValueError ("Invalid custom objects found." )
78
75
else :
79
- self ._custom_objects = {}
76
+ self ._custom_objects : "t.Optional[t.Dict[str, t.Any]]" = {}
80
77
81
78
return self ._custom_objects
82
79
@@ -128,6 +125,9 @@ def create(
128
125
context ["bentoml_version" ] = BENTOML_VERSION
129
126
context ["python_version" ] = PYTHON_VERSION
130
127
128
+ validate_labels (labels )
129
+ validate_metadata (metadata )
130
+
131
131
model_fs = fs .open_fs (f"temp://bentoml_model_{ name } " )
132
132
133
133
res = Model (
@@ -155,8 +155,7 @@ def save(
155
155
return self
156
156
157
157
def _save (self , model_store : "ModelStore" ) -> "Model" :
158
- self .flush_info ()
159
- self .flush_custom_objects ()
158
+ self .flush ()
160
159
161
160
if not self .validate ():
162
161
logger .warning (f"Failed to create Model for { self .tag } , not saving." )
@@ -187,8 +186,7 @@ def from_fs(cls, item_fs: FS) -> "Model":
187
186
f"Failed to load bento model because it contains an invalid '{ MODEL_YAML_FILENAME } '"
188
187
)
189
188
190
- res ._info_flushed = True
191
- res ._custom_objects_flushed = True
189
+ res ._flushed = True
192
190
return res
193
191
194
192
@property
@@ -198,26 +196,22 @@ def path(self) -> str:
198
196
def path_of (self , item : str ) -> str :
199
197
return self ._fs .getsyspath (item )
200
198
201
- def flush_info (self ):
202
- if self ._info_flushed :
203
- return
199
+ def flush (self ):
200
+ if not self ._flushed :
201
+ self ._flush_info ()
202
+ self ._flush_custom_objects ()
203
+ self ._flushed = True
204
204
205
+ def _flush_info (self ):
205
206
with self ._fs .open (MODEL_YAML_FILENAME , "w" ) as model_yaml :
206
207
self .info .dump (model_yaml )
207
208
208
- self ._info_flushed = True
209
-
210
- def flush_custom_objects (self ):
211
- if self ._custom_objects_flushed :
212
- return
213
-
209
+ def _flush_custom_objects (self ):
214
210
# pickle custom_objects if it is not None and not empty
215
211
if self .custom_objects :
216
212
with self ._fs .open (CUSTOM_OBJECTS_FILENAME , "wb" ) as cofile :
217
213
cloudpickle .dump (self .custom_objects , cofile )
218
214
219
- self ._custom_objects_flushed = True
220
-
221
215
@property
222
216
def creation_time (self ) -> datetime :
223
217
return self .info .creation_time
@@ -233,8 +227,8 @@ def export(
233
227
params : t .Optional [t .Dict [str , str ]] = None ,
234
228
subpath : t .Optional [str ] = None ,
235
229
) -> str :
236
- self .flush_info ()
237
- self . flush_custom_objects ()
230
+ self .flush ()
231
+
238
232
return super ().export (
239
233
path ,
240
234
output_format ,
@@ -264,14 +258,30 @@ def __init__(self, base_path: "t.Union[PathType, FS]"):
264
258
class ModelInfo :
265
259
tag : Tag
266
260
module : str
267
- labels : t .Dict [str , t . Any ]
261
+ labels : t .Dict [str , str ]
268
262
options : t .Dict [str , t .Any ]
269
263
metadata : t .Dict [str , t .Any ]
270
264
context : t .Dict [str , t .Any ]
271
265
bentoml_version : str = BENTOML_VERSION
272
266
api_version : str = "v1"
273
267
creation_time : datetime = attr .field (factory = lambda : datetime .now (timezone .utc ))
274
268
269
+ def __eq__ (self , other : "t.Union[ModelInfo, FrozenModelInfo]" ):
270
+ if not isinstance (other , (ModelInfo , FrozenModelInfo )):
271
+ return False
272
+
273
+ return (
274
+ self .tag == other .tag
275
+ and self .module == other .module
276
+ and self .labels == other .labels
277
+ and self .options == other .options
278
+ and self .metadata == other .metadata
279
+ and self .context == other .context
280
+ and self .bentoml_version == other .bentoml_version
281
+ and self .api_version == other .api_version
282
+ and self .creation_time == other .creation_time
283
+ )
284
+
275
285
def __attrs_post_init__ (self ):
276
286
self .validate ()
277
287
@@ -292,8 +302,8 @@ def to_dict(self) -> t.Dict[str, t.Any]:
292
302
def dump (self , stream : t .IO [t .Any ]):
293
303
return yaml .dump (self , stream , sort_keys = False )
294
304
295
- @classmethod
296
- def from_yaml_file (cls , stream : t .IO [t .Any ]):
305
+ @staticmethod
306
+ def from_yaml_file (stream : t .IO [t .Any ]):
297
307
try :
298
308
yaml_content = yaml .safe_load (stream )
299
309
except yaml .YAMLError as exc : # pragma: no cover - simple error handling
@@ -311,7 +321,7 @@ def from_yaml_file(cls, stream: t.IO[t.Any]):
311
321
del yaml_content ["version" ]
312
322
313
323
try :
314
- model_info = cls (** yaml_content ) # type: ignore
324
+ model_info = FrozenModelInfo (** yaml_content ) # type: ignore
315
325
except TypeError : # pragma: no cover - simple error handling
316
326
raise BentoMLException (f"unexpected field in { MODEL_YAML_FILENAME } " )
317
327
return model_info
@@ -321,6 +331,15 @@ def validate(self):
321
331
# add tests when implemented
322
332
...
323
333
334
+ def freeze (self ) -> "ModelInfo" :
335
+ self .__class__ = FrozenModelInfo
336
+ return self
337
+
338
+
339
+ @attr .define (repr = False , frozen = True ) # type: ignore (pyright doesn't allow for a frozen subclass)
340
+ class FrozenModelInfo (ModelInfo ):
341
+ pass
342
+
324
343
325
344
def copy_model (
326
345
model_tag : t .Union [Tag , str ],
@@ -346,3 +365,4 @@ def _ModelInfo_dumper(dumper: yaml.Dumper, info: ModelInfo) -> yaml.Node:
346
365
347
366
348
367
yaml .add_representer (ModelInfo , _ModelInfo_dumper )
368
+ yaml .add_representer (FrozenModelInfo , _ModelInfo_dumper )
0 commit comments