Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions aitlas/datasets/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ class So2SatDataset(BaseDataset):
def __init__(self, config):
super().__init__(config)

self.file_path = self.config.h5_file
self.data = h5py.File(
self.file_path
) # TODO: we should close this file eventually
self.h5_file = self.config.h5_file
self.data = self.load_dataset(self.h5_file)

def __getitem__(self, index):
label = self.data["label"][index]
Expand All @@ -72,6 +70,11 @@ def __getitem__(self, index):
label = self.target_transform(label)

return img, np.where(label == 1.0)[0][0]

def load_dataset(self, h5_file, csv_file=None):
self.data = h5py.File(h5_file, 'r')

return self.data

def __len__(self):
return self.data["label"].shape[0]
Expand Down Expand Up @@ -126,4 +129,4 @@ def data_distribution_barchart(self):
ax.set_title(
"Labels distribution for {}".format(self.get_name()), pad=20, fontsize=18
)
return fig
return fig
166 changes: 123 additions & 43 deletions aitlas/datasets/spacenet6.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@
from skimage.morphology import dilation, erosion, square
from skimage.segmentation import watershed
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

from ..base import BaseDataset
from ..datasets.schemas import SpaceNet6DatasetSchema
from ..utils import parse_img_id
from aitlas.base import BaseDataset
from aitlas.datasets.schemas import SpaceNet6DatasetSchema
from aitlas.utils import parse_img_id
from aitlas.utils import image_loader


# Ignore the "low-contrast" warnings
Expand Down Expand Up @@ -123,13 +127,15 @@ def process_image(
class SpaceNet6Dataset(BaseDataset):
"""SpaceNet6 dataset."""

name = "SpaceNet 6"
schema = SpaceNet6DatasetSchema
labels = ["background","building boundary","building"]
color_mapping = [[0,0,0],[255,255,0],[255,0,0]]

def __init__(self, config):
super().__init__(config)
self.image_paths = list()
self.mask_paths = list()
self.orients = pd.read_csv(config.orients, index_col=0)
self.image_paths, self.mask_paths = self.load_directory()
self.orients = pd.read_csv(self.config.orients, index_col=0)
self.orients["val"] = list(range(len(self.orients.index)))

def __getitem__(self, index):
Expand All @@ -143,52 +149,67 @@ def __getitem__(self, index):
# Get image paths
image_path = self.image_paths[index]
# Read image
image = io.imread(image_path)
mask = None # placeholder, ignores the "might be referenced before assignment" warning
image = io.imread(image_path)[:,:,[0, 3, 1]]
# Calculate min/max x/y for the black parts
m = np.where((image.sum(axis=2) > 0).any(1))
y_min, y_max = np.amin(m), np.amax(m) + 1
m = np.where((image.sum(axis=2) > 0).any(0))
x_min, x_max = np.amin(m), np.amax(m) + 1
# Remove black parts
image = image[y_min:y_max, x_min:x_max]
# Apply transformations, (should be available only for training data)
if self.config.transforms:
# Get mask path
mask_path = self.mask_paths[index]
# Read mask
mask = io.imread(mask_path)
# Remove black parts
mask = mask[y_min:y_max, x_min:x_max]
image, mask = self.transform({"image": image, "mask": mask})

# Get mask path
mask_path = self.mask_paths[index]
# Read mask
mask = io.imread(mask_path)
# Remove black parts
mask = mask[y_min:y_max, x_min:x_max]
# Convert 2D mask with class IDs to one-hot encoded 3D mask
single_band_mask = np.zeros([mask.shape[0], mask.shape[1]], np.uint8)

for i in np.arange(mask.shape[0]):
for j in np.arange(mask.shape[1]):
if mask[i,j,0] == 0 and mask[i,j,1] == 0 and mask[i,j,2] == 0:
single_band_mask[i,j] = 0
if mask[i,j,0] == 255 and mask[i,j,1] == 255 and mask[i,j,2] == 0:
single_band_mask[i,j] = 1
if mask[i,j,0] == 255 and mask[i,j,1] == 0 and mask[i,j,2] == 0:
single_band_mask[i,j] = 2

masks = [(single_band_mask == v) for v, label in enumerate(self.labels)]
mask = np.stack(masks, axis=-1).astype("float32")

# Extract direction, strip and coordinates from image
direction, strip, coordinate = parse_img_id(image_path, self.orients)
if direction.item():
image = np.fliplr(np.flipud(image))
if self.config.transforms:
mask = np.fliplr(np.flipud(mask))
image = (
image - np.array([28.62501827, 36.09922463, 33.84483687, 26.21196667])
) / np.array([8.41487376, 8.26645475, 8.32328472, 8.63668993])

#normalize image for better visualization
img_min, img_max = image.min(), image.max()
if img_max - img_min > 1e-6:
image = (image - img_min) / (img_max - img_min + 1e-8)
# Transpose image
image = torch.from_numpy(image.transpose((2, 0, 1)).copy()).float()
# Reorder bands
image = image[[0, 3, 1, 2]]
if self.config.transforms:
weights = np.ones_like(mask[:, :, :1], dtype=float)
region_labels, region_count = measure.label(
mask[:, :, 0], background=0, connectivity=1, return_num=True
)
region_properties = measure.regionprops(region_labels)
for bl in range(region_count):
weights[region_labels == bl + 1] = 1024.0 / region_properties[bl].area
mask[:, :, :3] = (mask[:, :, :3] > 1) * 1
weights = torch.from_numpy(weights.transpose((2, 0, 1)).copy()).float()
mask = torch.from_numpy(mask.transpose((2, 0, 1)).copy()).float()
rgb = torch.Tensor([0])
else:
mask = rgb = weights = region_count = torch.Tensor([0])
return {
image = torch.from_numpy(image.transpose((2, 0, 1)))#.copy()).float()

weights = np.ones_like(mask[:, :, :1], dtype=float)
region_labels, region_count = measure.label(
mask[:, :, 0], background=0, connectivity=1, return_num=True
)
region_properties = measure.regionprops(region_labels)
for bl in range(region_count):
weights[region_labels == bl + 1] = 1024.0 / region_properties[bl].area
weights = torch.from_numpy(weights.transpose((2, 0, 1)).copy()).float()
mask = torch.from_numpy(mask.transpose((2, 0, 1)).copy()).float()
rgb = torch.Tensor([0])

if self.transform:
image = self.transform(image)
if self.target_transform:
mask = self.target_transform(mask)

output = {
"image": image,
"mask": mask,
"rgb": rgb,
Expand All @@ -202,13 +223,20 @@ def __getitem__(self, index):
"weights": weights,
}

return image, mask

def __len__(self):
return len(self.image_paths)

def get_labels(self):
return self.labels

def load_directory(self):
"""Loads the *.tif images from the specified directory."""
self.image_paths = glob.glob(os.path.join(self.config.test_directory, "*.tif"))
self.mask_paths = None
self.image_paths = glob.glob(os.path.join(self.config.root_directory, "*.tif"))
self.mask_paths = glob.glob(os.path.join(self.config.segmentation_directory, "*.tif"))

return self.image_paths, self.mask_paths

def load_other_folds(self, fold):
"""Loads all images (and masks) except the ones from this fold."""
Expand Down Expand Up @@ -243,9 +271,6 @@ def load_fold(self, fold):
]
self.mask_paths = None

def labels(self):
pass

def prepare(self):
"""
Prepares the SpaceNet6 data set for model training and validation by:
Expand Down Expand Up @@ -345,3 +370,58 @@ def prepare(self):
(orientations["mean_y"] - 5746153.106161971) / 11000
) + 0.2
orientations.to_csv(self.config.orients_output, index=True)

def show_image(self, index, show_title=False):
img, mask = self[index]
img = img.permute(1,2,0)
mask = mask.permute(1,2,0)
img_mask = np.zeros([mask.shape[0], mask.shape[1], 3], np.uint8)
legend_elements = []
for i, label in enumerate(self.labels):
legend_elements.append(
Patch(
facecolor=tuple([x / 255 for x in self.color_mapping[i]]),
label=self.labels[i],
)
)
img_mask[np.where(mask[:, :, i] == 1)] = self.color_mapping[i]

fig = plt.figure(figsize=(10, 8))
height_factor = math.ceil(len(self.labels)/3)
if height_factor == 4:
height_factor = 0.73
elif height_factor == 2:
height_factor = 0.80
else:
height_factor = 0.81
fig.legend(handles=legend_elements, bbox_to_anchor=(0.2, height_factor, 0.6, 0.2), ncol=3, mode='expand',
loc='lower left', prop={'size': 12})
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(img_mask)
plt.axis("off")
fig.tight_layout()
plt.show()
return fig

def data_distribution_table(self):
label_dist = {key: 0 for key in self.labels}
for i in np.arange(len(self.image_paths)):
_, mask = self.__getitem__(i)
for index, label in enumerate(self.labels):
label_dist[self.labels[index]] += mask[:, :, index].sum()
label_count = pd.DataFrame.from_dict(label_dist, orient='index')
label_count.columns = ["Number of pixels"]
return label_count

def data_distribution_barchart(self, show_title=True):
label_count = self.data_distribution_table()
fig, ax = plt.subplots(figsize=(12, 10))
sns.barplot(data=label_count, x=label_count.index, y='Number of pixels', ax=ax)
if show_title:
ax.set_title(
"Labels distribution for {}".format(self.get_name()), pad=20, fontsize=18
)
return fig

Large diffs are not rendered by default.

1,009 changes: 1,009 additions & 0 deletions examples/multiclass_classification_example_optimal_31.ipynb

Large diffs are not rendered by default.

1,212 changes: 1,212 additions & 0 deletions examples/multiclass_classification_example_so2sat.ipynb

Large diffs are not rendered by default.

519 changes: 519 additions & 0 deletions examples/semantic_segmentation_example_spacenet6.ipynb

Large diffs are not rendered by default.