diff --git a/src/crested/tl/_explainer_tf.py b/src/crested/tl/_explainer_tf.py
index 4dc132a9..95e266d6 100644
--- a/src/crested/tl/_explainer_tf.py
+++ b/src/crested/tl/_explainer_tf.py
@@ -4,6 +4,8 @@
 Adapted from: https://github.com/p-koo/tfomics/blob/master/tfomics/
 """
 
+from __future__ import annotations
+
 import numpy as np
 import tensorflow as tf
 
@@ -97,18 +99,38 @@ def set_baseline(self, x, baseline, num_samples):
     return baseline
 
 
-def saliency_map(X, model, class_index=None, func=tf.math.reduce_mean):
+def saliency_map(
+    X, model, class_index: int | list[int] | None = None, func=tf.math.reduce_mean
+):
     """Fast function to generate saliency maps."""
     if not tf.is_tensor(X):
         X = tf.Variable(X)
 
-    with tf.GradientTape() as tape:
+    # use a persistent tape so a gradient can be computed for each output in
+    # class_index, in case class_index is a list of indices
+    with tf.GradientTape(persistent=True) as tape:
         tape.watch(X)
-        if class_index is not None:
-            outputs = model(X)[:, class_index]
+        output = model(X)
+        if isinstance(class_index, int):
+            # get the output for a single class (C)
+            outputs_C = [output[:, class_index]]
+        elif isinstance(class_index, list):
+            # get the outputs for multiple classes
+            outputs_C = [output[:, c] for c in class_index]
+        elif class_index is None:
+            # legacy mode -- unclear whether `func` is still needed here
+            outputs_C = [func(output)]
         else:
-            outputs = func(model(X))
-    return tape.gradient(outputs, X)
+            raise ValueError(
+                f"class_index should be an integer, a list of integers, or None, not: {class_index}."
+            )
+    grads = np.empty((len(outputs_C), *X.shape))
+    for i in range(len(outputs_C)):
+        grads[i] = tape.gradient(outputs_C[i], X)
+    # explicitly delete the tape; needed because persistent=True
+    del tape
+    # drop the leading class axis when class_index is a single int or None
+    return grads.squeeze(axis=0) if grads.shape[0] == 1 else grads
 
 
 @tf.function
@@ -148,7 +170,12 @@ def smoothgrad(
 
 
 def integrated_grad(
-    x, model, baseline, num_steps=25, class_index=None, func=tf.math.reduce_mean
+    x,
+    model,
+    baseline,
+    num_steps=25,
+    class_index: int | list[int] | None = None,
+    func=tf.math.reduce_mean,
 ):
     """Calculate integrated gradients for a given sequence."""
 
@@ -167,8 +194,17 @@ def interpolate_data(baseline, x, steps):
 
     steps = tf.linspace(start=0.0, stop=1.0, num=num_steps + 1)
     x_interp = interpolate_data(baseline, x, steps)
     grad = saliency_map(x_interp, model, class_index=class_index, func=func)
+    # at this point the shape of grad is either:
+    # - (num_steps + 1, *x.shape) if class_index is None or a single int
+    # - (len(class_index), num_steps + 1, *x.shape) if class_index is a list of ints
+    if len(grad.shape) == 4:
+        # second case: move num_steps + 1 to the first axis
+        grad = grad.swapaxes(0, 1)
     avg_grad = integral_approximation(grad)
-    avg_grad = np.expand_dims(avg_grad, axis=0)
+    if len(avg_grad.shape) != 3:
+        # first case: the class dimension still needs to be expanded;
+        # in the second case it is already there
+        avg_grad = np.expand_dims(avg_grad, axis=0)
     return avg_grad
 

diff --git a/tests/test_refactor.py b/tests/test_refactor.py
index 87aa6d98..5ce29fea 100644
--- a/tests/test_refactor.py
+++ b/tests/test_refactor.py
@@ -1,6 +1,5 @@
 """Test that ensures the outputs after the functional refactor are the same as before."""
 
-import os
 import keras
 import numpy as np
@@ -19,11 +18,6 @@
 np.random.seed(42)
 keras.utils.set_random_seed(42)
 
-if os.environ["KERAS_BACKEND"] == "tensorflow":
-    import tensorflow as tf
-
-    tf.config.experimental.enable_op_determinism()
-
 
 @pytest.fixture(scope="module")
 def crested_object(keras_model, adata, genome):
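
Shape check for the new saliency_map behavior. This is a minimal sketch, not part of the diff: the toy model, the input sizes, and the import from the private module crested.tl._explainer_tf are assumptions for illustration; the asserted shapes follow directly from the branches in the first hunk.

import numpy as np
import tensorflow as tf

from crested.tl._explainer_tf import saliency_map

# toy stand-in model: (batch, length, 4) one-hot input -> 3 output classes
inputs = tf.keras.Input(shape=(100, 4))
outputs = tf.keras.layers.Dense(3)(tf.keras.layers.Flatten()(inputs))
model = tf.keras.Model(inputs, outputs)

X = np.random.rand(8, 100, 4).astype(np.float32)

# single int: the leading class axis is squeezed away
assert saliency_map(X, model, class_index=1).shape == (8, 100, 4)
# list of ints: one gradient per requested class on the first axis
assert saliency_map(X, model, class_index=[0, 2]).shape == (2, 8, 100, 4)
# None (legacy mode): func is applied over all outputs, class axis squeezed
assert saliency_map(X, model, class_index=None).shape == (8, 100, 4)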
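
Likewise for integrated_grad, a hedged sketch of the expected return shapes, based on the shape comments in the second hunk. Assumptions: x is a single sequence of shape (length, 4) with no batch axis, and integral_approximation reduces over the leading steps axis (as in the tfomics original); shapes are printed rather than asserted because integral_approximation is outside this diff.

import numpy as np
import tensorflow as tf

from crested.tl._explainer_tf import integrated_grad

inputs = tf.keras.Input(shape=(100, 4))
model = tf.keras.Model(inputs, tf.keras.layers.Dense(3)(tf.keras.layers.Flatten()(inputs)))

x = np.random.rand(100, 4).astype(np.float32)  # assumed: single sequence, no batch axis
baseline = np.zeros_like(x)

# single class: expected (1, 100, 4) after the final expand_dims
print(integrated_grad(x, model, baseline, num_steps=25, class_index=2).shape)
# list of classes: expected (3, 100, 4), one map per class, no extra expand
print(integrated_grad(x, model, baseline, num_steps=25, class_index=[0, 1, 2]).shape)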