33from typing import Any , Dict , Iterator , List , Optional , Tuple
44
55import numpy as np
6- import orjson as json
76from PIL import Image as PILImage
87
98from darwin .dataset .utils import get_classes , get_release_path , load_pil_image
@@ -64,20 +63,6 @@ def __init__(
6463 split_type : str = "random" ,
6564 release_name : Optional [str ] = None ,
6665 ):
67- assert dataset_path is not None
68- release_path = get_release_path (dataset_path , release_name )
69- annotations_dir = release_path / "annotations"
70- assert annotations_dir .exists ()
71- images_dir = dataset_path / "images"
72- assert images_dir .exists ()
73-
74- if partition not in ["train" , "val" , "test" , None ]:
75- raise ValueError ("partition should be either 'train', 'val', or 'test'" )
76- if split_type not in ["random" , "stratified" ]:
77- raise ValueError ("split_type should be either 'random', 'stratified'" )
78- if annotation_type not in ["tag" , "polygon" , "bounding_box" ]:
79- raise ValueError ("annotation_type should be either 'tag', 'bounding_box', or 'polygon'" )
80-
8166 self .dataset_path = dataset_path
8267 self .annotation_type = annotation_type
8368 self .images_path : List [Path ] = []
@@ -86,15 +71,64 @@ def __init__(
8671 self .original_images_path : Optional [List [Path ]] = None
8772 self .original_annotations_path : Optional [List [Path ]] = None
8873
74+ release_path , annotations_dir , images_dir = self ._initial_setup (
75+ dataset_path , release_name
76+ )
77+ self ._validate_inputs (partition , split_type , annotation_type )
8978 # Get the list of classes
79+
80+ annotation_types = [self .annotation_type ]
81+ # We fetch bounding_boxes annotations from selected polygons as well
82+ if self .annotation_type == "bounding_box" :
83+ annotation_types .append ("polygon" )
9084 self .classes = get_classes (
91- self .dataset_path , release_name , annotation_type = self .annotation_type , remove_background = True
85+ self .dataset_path ,
86+ release_name ,
87+ annotation_type = annotation_types ,
88+ remove_background = True ,
9289 )
9390 self .num_classes = len (self .classes )
91+ self ._setup_annotations_and_images (
92+ release_path ,
93+ annotations_dir ,
94+ images_dir ,
95+ annotation_type ,
96+ split ,
97+ partition ,
98+ split_type ,
99+ )
100+
101+ if len (self .images_path ) == 0 :
102+ raise ValueError (
103+ f"Could not find any { SUPPORTED_IMAGE_EXTENSIONS } file" ,
104+ f" in { images_dir } " ,
105+ )
106+
107+ assert len (self .images_path ) == len (self .annotations_path )
94108
95- stems = build_stems (release_path , annotations_dir , annotation_type , split , partition , split_type )
109+ def _validate_inputs (self , partition , split_type , annotation_type ):
110+ if partition not in ["train" , "val" , "test" , None ]:
111+ raise ValueError ("partition should be either 'train', 'val', or 'test'" )
112+ if split_type not in ["random" , "stratified" ]:
113+ raise ValueError ("split_type should be either 'random', 'stratified'" )
114+ if annotation_type not in ["tag" , "polygon" , "bounding_box" ]:
115+ raise ValueError (
116+ "annotation_type should be either 'tag', 'bounding_box', or 'polygon'"
117+ )
96118
97- # Find all the annotations and their corresponding images
119+ def _setup_annotations_and_images (
120+ self ,
121+ release_path ,
122+ annotations_dir ,
123+ images_dir ,
124+ annotation_type ,
125+ split ,
126+ partition ,
127+ split_type ,
128+ ):
129+ stems = build_stems (
130+ release_path , annotations_dir , annotation_type , split , partition , split_type
131+ )
98132 for stem in stems :
99133 annotation_path = annotations_dir / f"{ stem } .json"
100134 images = []
@@ -107,16 +141,24 @@ def __init__(
107141 if image_path .exists ():
108142 images .append (image_path )
109143 if len (images ) < 1 :
110- raise ValueError (f"Annotation ({ annotation_path } ) does not have a corresponding image" )
144+ raise ValueError (
145+ f"Annotation ({ annotation_path } ) does not have a corresponding image"
146+ )
111147 if len (images ) > 1 :
112- raise ValueError (f"Image ({ stem } ) is present with multiple extensions. This is forbidden." )
148+ raise ValueError (
149+ f"Image ({ stem } ) is present with multiple extensions. This is forbidden."
150+ )
113151 self .images_path .append (images [0 ])
114152 self .annotations_path .append (annotation_path )
115153
116- if len (self .images_path ) == 0 :
117- raise ValueError (f"Could not find any { SUPPORTED_IMAGE_EXTENSIONS } file" , f" in { images_dir } " )
118-
119- assert len (self .images_path ) == len (self .annotations_path )
154+ def _initial_setup (self , dataset_path , release_name ):
155+ assert dataset_path is not None
156+ release_path = get_release_path (dataset_path , release_name )
157+ annotations_dir = release_path / "annotations"
158+ assert annotations_dir .exists ()
159+ images_dir = dataset_path / "images"
160+ assert images_dir .exists ()
161+ return release_path , annotations_dir , images_dir
120162
121163 def get_img_info (self , index : int ) -> Dict [str , Any ]:
122164 """
@@ -166,7 +208,9 @@ def get_height_and_width(self, index: int) -> Tuple[float, float]:
166208 parsed = parse_darwin_json (self .annotations_path [index ], index )
167209 return parsed .image_height , parsed .image_width
168210
169- def extend (self , dataset : "LocalDataset" , extend_classes : bool = False ) -> "LocalDataset" :
211+ def extend (
212+ self , dataset : "LocalDataset" , extend_classes : bool = False
213+ ) -> "LocalDataset" :
170214 """
171215 Extends the current dataset with another one.
172216
@@ -261,7 +305,10 @@ def parse_json(self, index: int) -> Dict[str, Any]:
261305 # Filter out unused classes and annotations of a different type
262306 if self .classes is not None :
263307 annotations = [
264- a for a in annotations if a .annotation_class .name in self .classes and self .annotation_type_supported (a )
308+ a
309+ for a in annotations
310+ if a .annotation_class .name in self .classes
311+ and self .annotation_type_supported (a )
265312 ]
266313 return {
267314 "image_id" : index ,
@@ -278,15 +325,20 @@ def annotation_type_supported(self, annotation) -> bool:
278325 elif self .annotation_type == "bounding_box" :
279326 is_bounding_box = annotation_type == "bounding_box"
280327 is_supported_polygon = (
281- annotation_type in ["polygon" , "complex_polygon" ] and "bounding_box" in annotation .data
328+ annotation_type in ["polygon" , "complex_polygon" ]
329+ and "bounding_box" in annotation .data
282330 )
283331 return is_bounding_box or is_supported_polygon
284332 elif self .annotation_type == "polygon" :
285333 return annotation_type in ["polygon" , "complex_polygon" ]
286334 else :
287- raise ValueError ("annotation_type should be either 'tag', 'bounding_box', or 'polygon'" )
335+ raise ValueError (
336+ "annotation_type should be either 'tag', 'bounding_box', or 'polygon'"
337+ )
288338
289- def measure_mean_std (self , multi_threaded : bool = True ) -> Tuple [np .ndarray , np .ndarray ]:
339+ def measure_mean_std (
340+ self , multi_threaded : bool = True
341+ ) -> Tuple [np .ndarray , np .ndarray ]:
290342 """
291343 Computes mean and std of trained images, given the train loader.
292344
@@ -309,7 +361,9 @@ def measure_mean_std(self, multi_threaded: bool = True) -> Tuple[np.ndarray, np.
309361 results = pool .map (self ._return_mean , self .images_path )
310362 mean = np .sum (np .array (results ), axis = 0 ) / len (self .images_path )
311363 # Online image_classification deviation
312- results = pool .starmap (self ._return_std , [[item , mean ] for item in self .images_path ])
364+ results = pool .starmap (
365+ self ._return_std , [[item , mean ] for item in self .images_path ]
366+ )
313367 std_sum = np .sum (np .array ([item [0 ] for item in results ]), axis = 0 )
314368 total_pixel_count = np .sum (np .array ([item [1 ] for item in results ]))
315369 std = np .sqrt (std_sum / total_pixel_count )
@@ -355,14 +409,20 @@ def _compute_weights(labels: List[int]) -> np.ndarray:
355409 @staticmethod
356410 def _return_mean (image_path : Path ) -> np .ndarray :
357411 img = np .array (load_pil_image (image_path ))
358- mean = np .array ([np .mean (img [:, :, 0 ]), np .mean (img [:, :, 1 ]), np .mean (img [:, :, 2 ])])
412+ mean = np .array (
413+ [np .mean (img [:, :, 0 ]), np .mean (img [:, :, 1 ]), np .mean (img [:, :, 2 ])]
414+ )
359415 return mean / 255.0
360416
361417 # Loads an image with OpenCV and returns the channel wise std of the image.
362418 @staticmethod
363419 def _return_std (image_path : Path , mean : np .ndarray ) -> Tuple [np .ndarray , float ]:
364420 img = np .array (load_pil_image (image_path )) / 255.0
365- m2 = np .square (np .array ([img [:, :, 0 ] - mean [0 ], img [:, :, 1 ] - mean [1 ], img [:, :, 2 ] - mean [2 ]]))
421+ m2 = np .square (
422+ np .array (
423+ [img [:, :, 0 ] - mean [0 ], img [:, :, 1 ] - mean [1 ], img [:, :, 2 ] - mean [2 ]]
424+ )
425+ )
366426 return np .sum (np .sum (m2 , axis = 1 ), 1 ), m2 .size / 3.0
367427
368428 def __getitem__ (self , index : int ):
@@ -432,7 +492,10 @@ def build_stems(
432492 """
433493
434494 if partition is None :
435- return (str (e .relative_to (annotations_dir ).parent / e .stem ) for e in sorted (annotations_dir .glob ("**/*.json" )))
495+ return (
496+ str (e .relative_to (annotations_dir ).parent / e .stem )
497+ for e in sorted (annotations_dir .glob ("**/*.json" ))
498+ )
436499
437500 if split_type == "random" :
438501 split_filename = f"{ split_type } _{ partition } .txt"
0 commit comments