diff --git a/taming/data/custom.py b/taming/data/custom.py index 33f302a4..543966ab 100644 --- a/taming/data/custom.py +++ b/taming/data/custom.py @@ -23,16 +23,46 @@ def __getitem__(self, i): class CustomTrain(CustomBase): def __init__(self, size, training_images_list_file): super().__init__() - with open(training_images_list_file, "r") as f: - paths = f.read().splitlines() + + isFile = os.path.isfile(training_images_list_file) + isDirectory = os.path.isdir(training_images_list_file) + + if isFile: + with open(training_images_list_file, "r") as f: + paths = f.read().splitlines() + + if isDirectory: + paths = [] + for images in os.listdir(training_images_list_file): + + # check if the image ends with png or jpg or jpeg + if (images.endswith(".png") or images.endswith(".jpg")\ + or images.endswith(".jpeg")): + paths.append(os.path.join(training_images_list_file, images)) + self.data = ImagePaths(paths=paths, size=size, random_crop=False) class CustomTest(CustomBase): def __init__(self, size, test_images_list_file): super().__init__() - with open(test_images_list_file, "r") as f: - paths = f.read().splitlines() + + isFile = os.path.isfile(test_images_list_file) + isDirectory = os.path.isdir(test_images_list_file) + + if isFile: + with open(test_images_list_file, "r") as f: + paths = f.read().splitlines() + + if isDirectory: + paths = [] + for images in os.listdir(test_images_list_file): + + # check if the image ends with png or jpg or jpeg + if (images.endswith(".png") or images.endswith(".jpg")\ + or images.endswith(".jpeg")): + paths.append(os.path.join(test_images_list_file, images)) + self.data = ImagePaths(paths=paths, size=size, random_crop=False)