99
1010import keras
1111from keras import layers
12+ from keras .models import load_model
1213from keras .saving import register_keras_serializable
1314
1415from ml4h .defines import IMAGE_EXT
@@ -762,16 +763,22 @@ def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_c
762763@register_keras_serializable ()
763764class DiffusionController (keras .Model ):
764765 def __init__ (
765- self , tensor_map , output_maps , batch_size , widths , block_depth , conv_x , control_size ,
766- attention_start , attention_heads , attention_modulo , diffusion_loss , sigmoid_beta , condition_strategy ,
767- inspect_model , supervisor = None , supervision_scalar = 0.01 ,
766+ self , tensor_map , output_maps , batch_size , widths , block_depth , conv_x , control_size ,
767+ attention_start , attention_heads , attention_modulo , diffusion_loss , sigmoid_beta , condition_strategy ,
768+ inspect_model , supervisor = None , supervision_scalar = 0.01 , encoder_file = None ,
768769 ):
769770 super ().__init__ ()
770771
771772 self .input_map = tensor_map
772773 self .batch_size = batch_size
773774 self .output_maps = output_maps
774- self .control_embed_model = get_control_embed_model (self .output_maps , control_size )
775+ if encoder_file :
776+ self .autoencoder_control = True
777+ self .control_embed_model = load_model (encoder_file , compile = False )
778+ logging .info (f'loaded encoder for DiffAE at: { encoder_file } ' )
779+ else :
780+ self .autoencoder_control = False
781+ self .control_embed_model = get_control_embed_model (self .output_maps , control_size )
775782 self .normalizer = layers .Normalization ()
776783 self .network = get_control_network (self .input_map .shape , widths , block_depth , conv_x , control_size ,
777784 attention_start , attention_heads , attention_modulo , condition_strategy )
@@ -780,7 +787,7 @@ def __init__(
780787 self .beta = sigmoid_beta
781788 self .supervisor = supervisor
782789 self .supervision_scalar = supervision_scalar
783- self .inspect_model = False # inspect_model
790+ self .inspect_model = False # inspect_model
784791
785792 def get_config (self ):
786793 config = super ().get_config ().copy ()
@@ -796,7 +803,7 @@ def compile(self, **kwargs):
796803 if self .supervisor is not None :
797804 self .supervised_loss_tracker = keras .metrics .Mean (name = "supervised_loss" )
798805 if self .input_map .axes () == 3 and self .inspect_model :
799- self .kid = KernelInceptionDistance (name = "kid" , input_shape = self .input_map .shape , kernel_image_size = 299 )
806+ self .kid = KernelInceptionDistance (name = "kid" , input_shape = self .input_map .shape , kernel_image_size = 299 )
800807 self .ms_ssim = MultiScaleSSIM ()
801808
802809 @property
@@ -895,10 +902,13 @@ def generate_from_noise(self, control_embed, num_images, diffusion_steps, initia
895902 def train_step (self , batch ):
896903 # normalize images to have standard deviation of 1, like the noises
897904 images = batch [0 ][self .input_map .input_name ()]
898- #self.normalizer.adapt(images)
905+ # self.normalizer.adapt(images)
899906 images = self .normalizer (images , training = True )
900907
901- control_embed = self .control_embed_model (batch [1 ])
908+ if self .autoencoder_control :
909+ control_embed = self .control_embed_model (batch [0 ])
910+ else :
911+ control_embed = self .control_embed_model (batch [1 ])
902912
903913 noises = tf .random .normal (shape = (self .batch_size ,) + self .input_map .shape )
904914
@@ -960,10 +970,13 @@ def train_step(self, batch):
960970 def test_step (self , batch ):
961971 # normalize images to have standard deviation of 1, like the noises
962972 images = batch [0 ][self .input_map .input_name ()]
963- #self.normalizer.adapt(images)
973+ # self.normalizer.adapt(images)
964974 images = self .normalizer (images , training = False )
965975
966- control_embed = self .control_embed_model (batch [1 ])
976+ if self .autoencoder_control :
977+ control_embed = self .control_embed_model (batch [0 ])
978+ else :
979+ control_embed = self .control_embed_model (batch [1 ])
967980
968981 noises = tf .random .normal (shape = (self .batch_size ,) + self .input_map .shape )
969982
@@ -999,7 +1012,7 @@ def test_step(self, batch):
9991012 supervised_loss = loss_fn (batch [1 ][self .output_maps [0 ].output_name ()], supervised_preds )
10001013 self .supervised_loss_tracker .update_state (supervised_loss )
10011014 # Combine losses: add noise_loss and supervised_loss
1002- noise_loss += self .supervision_scalar * supervised_loss
1015+ noise_loss += self .supervision_scalar * supervised_loss
10031016
10041017 self .image_loss_tracker .update_state (image_loss )
10051018 self .noise_loss_tracker .update_state (noise_loss )
@@ -1011,8 +1024,8 @@ def test_step(self, batch):
10111024 if self .input_map .axes () == 3 and self .inspect_model :
10121025 images = self .denormalize (images )
10131026 generated_images = self .generate (control_embed ,
1014- num_images = self .batch_size , diffusion_steps = 20
1015- )
1027+ num_images = self .batch_size , diffusion_steps = 20
1028+ )
10161029 self .kid .update_state (images , generated_images )
10171030 self .ms_ssim .update_state (images , generated_images , 255 )
10181031
@@ -1025,7 +1038,10 @@ def call(self, batch, training=False):
10251038 2. You can use model((noisy_images, noise_rates)) for inference
10261039 """
10271040 noisy_images , noise_rates = batch [0 ]
1028- control_embed = self .control_embed_model (batch [1 ])
1041+ if self .autoencoder_control :
1042+ control_embed = self .control_embed_model (noisy_images )
1043+ else :
1044+ control_embed = self .control_embed_model (batch [1 ])
10291045 # re-compute signal_rates
10301046 signal_rates = tf .sqrt (1.0 - tf .square (noise_rates ))
10311047 # this returns (pred_noises, pred_images)
@@ -1063,8 +1079,8 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None
10631079 plt .close ()
10641080
10651081 def plot_reconstructions (
1066- self , batch , diffusion_amount = 0 ,
1067- epoch = None , logs = None , num_rows = 4 , num_cols = 4 , prefix = './figures/' ,
1082+ self , batch , diffusion_amount = 0 ,
1083+ epoch = None , logs = None , num_rows = 4 , num_cols = 4 , prefix = './figures/' ,
10681084 ):
10691085 images = batch [0 ][self .input_map .input_name ()]
10701086 self .normalizer .adapt (images )
@@ -1075,7 +1091,10 @@ def plot_reconstructions(
10751091 # mix the images with noises accordingly
10761092 noisy_images = signal_rates * images + noise_rates * noises
10771093
1078- control_embed = self .control_embed_model (batch [1 ])
1094+ if self .autoencoder_control :
1095+ control_embed = self .control_embed_model (batch [0 ])
1096+ else :
1097+ control_embed = self .control_embed_model (batch [1 ])
10791098
10801099 # use the network to separate noisy images to their components
10811100 pred_noises , generated_images = self .denoise (
@@ -1111,10 +1130,9 @@ def plot_reconstructions(
11111130 plt .close ()
11121131 return generated_images
11131132
1114-
11151133 def control_plot_images (
1116- self , control_batch , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
1117- renoise = None ,
1134+ self , control_batch , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
1135+ renoise = None ,
11181136 ):
11191137 control_embed = self .control_embed_model (control_batch )
11201138 # plot random generated images for visual evaluation of generation quality
@@ -1139,6 +1157,31 @@ def control_plot_images(
11391157
11401158 return generated_images
11411159
1160+ def control_plot_images_embed (
1161+ self , control_embed , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
1162+ renoise = None ,
1163+ ):
1164+ generated_images = self .generate (
1165+ control_embed ,
1166+ num_images = max (self .batch_size , num_rows * num_cols ),
1167+ diffusion_steps = plot_diffusion_steps ,
1168+ reseed = reseed ,
1169+ renoise = renoise ,
1170+ )
1171+
1172+ plt .figure (figsize = (num_cols * 2.0 , num_rows * 2.0 ), dpi = 300 )
1173+ for row in range (num_rows ):
1174+ for col in range (num_cols ):
1175+ index = row * num_cols + col
1176+ plt .subplot (num_rows , num_cols , index + 1 )
1177+ plt .imshow (generated_images [index ], cmap = 'gray' )
1178+ plt .axis ("off" )
1179+ plt .tight_layout ()
1180+ plt .show ()
1181+ plt .close ()
1182+
1183+ return generated_images
1184+
11421185 def control_plot_images_noise (self , control_batch , initial_noise , epoch = None , logs = None , num_rows = 2 , num_cols = 8 ):
11431186 control_embed = self .control_embed_model (control_batch )
11441187 # plot random generated images for visual evaluation of generation quality
@@ -1163,8 +1206,8 @@ def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, lo
11631206 return generated_images
11641207
11651208 def control_plot_ecgs (
1166- self , control_batch , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
1167- renoise = None ,
1209+ self , control_batch , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
1210+ renoise = None ,
11681211 ):
11691212 control_embed = self .control_embed_model (control_batch )
11701213 # plot random generated images for visual evaluation of generation quality
0 commit comments