
Commit 002f4a7

Diffusion clean up (#605)

* Diffusion model conditioned on autoencoders

Co-authored-by: Samuel Friedman, Rakesh Rathod <[email protected]>

1 parent 9079a17 commit 002f4a7

13 files changed: +308 -92 lines changed

ml4h/TensorMap.py

Lines changed: 2 additions & 0 deletions

@@ -181,6 +181,8 @@ def __init__(
         # Infer loss from interpretation
         if self.loss is None and self.is_categorical() and self.shape[0] == 1:
             self.loss = 'sparse_categorical_crossentropy'
+        elif self.loss is None and self.is_categorical() and self.shape[0] == 2:
+            self.loss = 'binary_crossentropy'
         elif self.loss is None and self.is_categorical():
             self.loss = 'categorical_crossentropy'
         elif self.loss is None and self.is_continuous() and self.sentinel is not None:
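
Note: the new branch gives two-channel categorical maps a binary cross-entropy loss. A minimal sketch of the resulting inference chain, simplified from TensorMap.__init__ (a standalone function for illustration only, not the TensorMap API):

    # Simplified: channel_count plays the role of self.shape[0] above.
    def infer_categorical_loss(channel_count: int) -> str:
        if channel_count == 1:
            return 'sparse_categorical_crossentropy'  # integer-encoded labels
        elif channel_count == 2:
            return 'binary_crossentropy'              # new case: binary one-hot
        return 'categorical_crossentropy'             # general one-hot

    assert infer_categorical_loss(2) == 'binary_crossentropy'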

ml4h/arguments.py

Lines changed: 11 additions & 1 deletion

@@ -252,6 +252,15 @@ def parse_args():
         '--supervision_scalar', default=0.01, type=float,
         help='For `train_diffusion_supervise` mode, this weights the supervision loss from phenotype prediction on denoised data.',
     )
+    parser.add_argument('--encoder_file', help='Diffusion model encoder path for DiffAE training.')
+    parser.add_argument('--interpolate_min', type=float, default=-2.0,
+                        help='Diffusion model synthetic interpolation minimum for the continuous condition')
+    parser.add_argument('--interpolate_max', type=float, default=2.01,
+                        help='Diffusion model synthetic interpolation maximum for the continuous condition')
+    parser.add_argument('--interpolate_step', type=float, default=1.0,
+                        help='Diffusion model synthetic interpolation step size for the continuous condition')
+
+
     parser.add_argument(
         '--transformer_size', default=32, type=int,
         help='Number of output neurons in Transformer encoders and decoders, '

@@ -437,7 +446,8 @@ def parse_args():
     #Parent Sort enable or disable
     parser.add_argument('--parent_sort', default=True, type=lambda x: x.lower() == 'true', help='disable or enable parent_sort on output tmaps')
     #Dictionary outputs
-    parser.add_argument('--named_outputs', default=False, type=lambda x: x.lower() == 'true', help='pass output tmaps as dictionaries if true else pass as list')
+    parser.add_argument('--named_outputs', default=True, type=lambda x: x.lower() == 'true', help='pass output tmaps as dictionaries if true else pass as list')
+
     args = parser.parse_args()
     _process_args(args)
     return args
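
Note: the three interpolate flags plausibly parameterize a sweep over a continuous condition during synthetic interpolation; the asymmetric default maximum of 2.01 keeps 2.0 inside a half-open range. A sketch of how such flags are typically consumed (the np.arange consumer is an assumption, not shown in this diff):

    import numpy as np

    # Defaults from the flags above: min=-2.0, max=2.01, step=1.0.
    # The 2.01 endpoint makes the half-open arange include 2.0.
    condition_values = np.arange(-2.0, 2.01, 1.0)
    print(condition_values)  # [-2. -1.  0.  1.  2.]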

ml4h/ml4ht_integration/tensor_map.py

Lines changed: 4 additions & 1 deletion

@@ -88,7 +88,10 @@ def __call__(self, path: str) -> Batch:
                 dependents = {dep.name: dep for dep in tm.dependent_map}
             else:
                 dependents = {tm.dependent_map.name: tm.dependent_map}
-            out_batch[tm.output_name()] = tm.postprocess_tensor(
+            if tm in dependents:
+                out_batch[tm.output_name()] = dependents[tm]
+            else:
+                out_batch[tm.output_name()] = tm.postprocess_tensor(
                 tm.tensor_from_file(tm, hd5, dependents),
                 augment=self.augment, hd5=hd5,
             )
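
Note: the new branch reuses a tensor that was already produced as another map's dependent instead of recomputing it from the HD5 file. The underlying pattern, as a runnable toy (dict and names are illustrative, not the ml4ht API):

    # Hypothetical cache keyed like `dependents`: reuse if present, else compute.
    cache = {}

    def get_tensor(key, compute):
        if key in cache:
            return cache[key]      # dependent already materialized: reuse it
        cache[key] = compute()     # otherwise read/compute and remember
        return cache[key]

    print(get_tensor('ecg', lambda: [0.0] * 5))  # computed on first request
    print(get_tensor('ecg', lambda: [1.0] * 5))  # cached value returned, still zeros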

ml4h/models/diffusion_blocks.py

Lines changed: 65 additions & 22 deletions

@@ -9,6 +9,7 @@

 import keras
 from keras import layers
+from keras.models import load_model
 from keras.saving import register_keras_serializable

 from ml4h.defines import IMAGE_EXT

@@ -762,16 +763,22 @@ def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_c
 @register_keras_serializable()
 class DiffusionController(keras.Model):
     def __init__(
-        self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
-        attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
-        inspect_model, supervisor = None, supervision_scalar = 0.01,
+            self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
+            attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
+            inspect_model, supervisor=None, supervision_scalar=0.01, encoder_file=None,
     ):
         super().__init__()

         self.input_map = tensor_map
         self.batch_size = batch_size
         self.output_maps = output_maps
-        self.control_embed_model = get_control_embed_model(self.output_maps, control_size)
+        if encoder_file:
+            self.autoencoder_control = True
+            self.control_embed_model = load_model(encoder_file, compile=False)
+            logging.info(f'loaded encoder for DiffAE at: {encoder_file}')
+        else:
+            self.autoencoder_control = False
+            self.control_embed_model = get_control_embed_model(self.output_maps, control_size)
         self.normalizer = layers.Normalization()
         self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size,
                                            attention_start, attention_heads, attention_modulo, condition_strategy)
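
Note: with encoder_file set, the controller conditions on a learned embedding of the input itself (a diffusion autoencoder, DiffAE) rather than on phenotype labels. A runnable toy of that conditioning path (architecture and sizes are illustrative; only the load_model(..., compile=False) call and the branch come from the diff):

    import numpy as np
    import keras
    from keras import layers

    # Toy stand-in for a DiffAE control-embedding encoder: the image itself,
    # not a label vector, is mapped to the conditioning embedding.
    encoder = keras.Sequential([
        layers.Conv2D(8, 3, strides=2, activation='relu'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(16),  # plays the role of control_size
    ])

    images = np.random.rand(4, 32, 32, 1).astype('float32')
    control_embed = encoder(images)  # built on first call
    print(control_embed.shape)       # (4, 16)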

@@ -780,7 +787,7 @@ def __init__(
         self.beta = sigmoid_beta
         self.supervisor = supervisor
         self.supervision_scalar = supervision_scalar
-        self.inspect_model = False# inspect_model
+        self.inspect_model = False  # inspect_model

     def get_config(self):
         config = super().get_config().copy()

@@ -796,7 +803,7 @@ def compile(self, **kwargs):
         if self.supervisor is not None:
             self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
         if self.input_map.axes() == 3 and self.inspect_model:
-            self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)
+            self.kid = KernelInceptionDistance(name="kid", input_shape=self.input_map.shape, kernel_image_size=299)
             self.ms_ssim = MultiScaleSSIM()

     @property

@@ -895,10 +902,13 @@ def generate_from_noise(self, control_embed, num_images, diffusion_steps, initia
     def train_step(self, batch):
         # normalize images to have standard deviation of 1, like the noises
         images = batch[0][self.input_map.input_name()]
-        #self.normalizer.adapt(images)
+        # self.normalizer.adapt(images)
         images = self.normalizer(images, training=True)

-        control_embed = self.control_embed_model(batch[1])
+        if self.autoencoder_control:
+            control_embed = self.control_embed_model(batch[0])
+        else:
+            control_embed = self.control_embed_model(batch[1])

         noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape)
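
Note: batch[0] is the inputs dictionary (the images) and batch[1] the outputs dictionary (the labels), so the branch picks the image itself as the conditioning source in DiffAE mode and the phenotypes otherwise. The same switch repeats in test_step, call, and plot_reconstructions below. As a runnable toy (key names are hypothetical):

    import numpy as np

    def conditioning_source(batch, autoencoder_control):
        inputs, labels = batch  # mirrors the (inputs, outputs) generator pairs
        return inputs if autoencoder_control else labels

    batch = ({'input_mri': np.zeros((2, 32, 32, 1))}, {'output_age': np.array([[40.0], [60.0]])})
    print(list(conditioning_source(batch, True)))   # ['input_mri']
    print(list(conditioning_source(batch, False)))  # ['output_age']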

@@ -960,10 +970,13 @@ def test_step(self, batch):
     def test_step(self, batch):
         # normalize images to have standard deviation of 1, like the noises
         images = batch[0][self.input_map.input_name()]
-        #self.normalizer.adapt(images)
+        # self.normalizer.adapt(images)
        images = self.normalizer(images, training=False)

-        control_embed = self.control_embed_model(batch[1])
+        if self.autoencoder_control:
+            control_embed = self.control_embed_model(batch[0])
+        else:
+            control_embed = self.control_embed_model(batch[1])

         noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape)

@@ -999,7 +1012,7 @@ def test_step(self, batch):
             supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
             self.supervised_loss_tracker.update_state(supervised_loss)
             # Combine losses: add noise_loss and supervised_loss
-            noise_loss += self.supervision_scalar*supervised_loss
+            noise_loss += self.supervision_scalar * supervised_loss

         self.image_loss_tracker.update_state(image_loss)
         self.noise_loss_tracker.update_state(noise_loss)

@@ -1011,8 +1024,8 @@ def test_step(self, batch):
         if self.input_map.axes() == 3 and self.inspect_model:
             images = self.denormalize(images)
             generated_images = self.generate(control_embed,
-                num_images=self.batch_size, diffusion_steps=20
-            )
+                                             num_images=self.batch_size, diffusion_steps=20
+                                             )
             self.kid.update_state(images, generated_images)
             self.ms_ssim.update_state(images, generated_images, 255)

@@ -1025,7 +1038,10 @@ def call(self, batch, training=False):
         2. You can use model((noisy_images, noise_rates)) for inference
         """
         noisy_images, noise_rates = batch[0]
-        control_embed = self.control_embed_model(batch[1])
+        if self.autoencoder_control:
+            control_embed = self.control_embed_model(noisy_images)
+        else:
+            control_embed = self.control_embed_model(batch[1])
         # re-compute signal_rates
         signal_rates = tf.sqrt(1.0 - tf.square(noise_rates))
         # this returns (pred_noises, pred_images)
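
Note: the recomputed signal rates enforce the variance-preserving constraint signal_rates**2 + noise_rates**2 = 1, so signal_rates * images + noise_rates * noises keeps unit variance when images and noises each have unit variance. A quick numeric check:

    import numpy as np

    noise_rates = np.array([0.0, 0.5, 1.0])
    signal_rates = np.sqrt(1.0 - np.square(noise_rates))  # same formula as above
    assert np.allclose(signal_rates ** 2 + noise_rates ** 2, 1.0)
    print(signal_rates)  # [1.        0.8660254 0.       ]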

@@ -1063,8 +1079,8 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None
         plt.close()

     def plot_reconstructions(
-        self, batch, diffusion_amount=0,
-        epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/',
+            self, batch, diffusion_amount=0,
+            epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/',
     ):
         images = batch[0][self.input_map.input_name()]
         self.normalizer.adapt(images)

@@ -1075,7 +1091,10 @@ def plot_reconstructions(
         # mix the images with noises accordingly
         noisy_images = signal_rates * images + noise_rates * noises

-        control_embed = self.control_embed_model(batch[1])
+        if self.autoencoder_control:
+            control_embed = self.control_embed_model(batch[0])
+        else:
+            control_embed = self.control_embed_model(batch[1])

         # use the network to separate noisy images to their components
         pred_noises, generated_images = self.denoise(

@@ -1111,10 +1130,9 @@ def plot_reconstructions(
         plt.close()
         return generated_images

-
     def control_plot_images(
-        self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
-        renoise=None,
+            self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
+            renoise=None,
     ):
         control_embed = self.control_embed_model(control_batch)
         # plot random generated images for visual evaluation of generation quality

@@ -1139,6 +1157,31 @@ def control_plot_images(

         return generated_images

+    def control_plot_images_embed(
+            self, control_embed, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
+            renoise=None,
+    ):
+        generated_images = self.generate(
+            control_embed,
+            num_images=max(self.batch_size, num_rows * num_cols),
+            diffusion_steps=plot_diffusion_steps,
+            reseed=reseed,
+            renoise=renoise,
+        )
+
+        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
+        for row in range(num_rows):
+            for col in range(num_cols):
+                index = row * num_cols + col
+                plt.subplot(num_rows, num_cols, index + 1)
+                plt.imshow(generated_images[index], cmap='gray')
+                plt.axis("off")
+        plt.tight_layout()
+        plt.show()
+        plt.close()
+
+        return generated_images
+
     def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, logs=None, num_rows=2, num_cols=8):
         control_embed = self.control_embed_model(control_batch)
         # plot random generated images for visual evaluation of generation quality
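
Note: unlike control_plot_images, which embeds a control batch first, the new control_plot_images_embed takes precomputed embeddings, which is what enables latent-space interpolation: embed two inputs, blend the embeddings, and render each blend. A sketch of such a driver (the blending loop is an assumption; diffuser and the embeddings stand in for real objects):

    import numpy as np

    embed_a = np.random.rand(1, 16).astype('float32')  # e.g. encoder(image_a)
    embed_b = np.random.rand(1, 16).astype('float32')  # e.g. encoder(image_b)

    for alpha in np.arange(0.0, 1.01, 0.25):  # cf. the --interpolate_* flags
        control_embed = (1 - alpha) * embed_a + alpha * embed_b
        # diffuser.control_plot_images_embed(control_embed)  # render this blend
        print(round(float(alpha), 2), control_embed[0, :3])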
@@ -1163,8 +1206,8 @@ def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, lo
11631206
return generated_images
11641207

11651208
def control_plot_ecgs(
1166-
self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
1167-
renoise=None,
1209+
self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
1210+
renoise=None,
11681211
):
11691212
control_embed = self.control_embed_model(control_batch)
11701213
# plot random generated images for visual evaluation of generation quality
