|
import jax
import jax.numpy as jnp
import pytest

from cellflow.networks._set_encoders import ConditionEncoder
from cellflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling


class TestAggregator:
    """Tests that attention-pooling aggregators actually respond to the attention mask."""

    # NOTE(review): despite the method name, this test is parametrized over BOTH
    # pooling classes, not only TokenAttentionPooling. Name kept for test-id stability.
    @pytest.mark.parametrize("agg", [TokenAttentionPooling, SeedAttentionPooling])
    def test_mask_impact_on_TokenAttentionPooling(self, agg):
        """Pooling the same tokens under two different masks must yield different outputs.

        Parameters
        ----------
        agg
            The pooling class under test (``TokenAttentionPooling`` or
            ``SeedAttentionPooling``); instantiated with default arguments below.
        """
        rng = jax.random.PRNGKey(0)
        init_rng, mask_rng = jax.random.split(rng, 2)
        # Batch of 2 conditions, 3 random tokens of width 7 each, plus one
        # all-zero token appended so the encoder's mask has something to mask out.
        condition = jax.random.normal(rng, (2, 3, 7))
        condition = jnp.concatenate((condition, jnp.zeros((2, 1, 7))), axis=1)
        cond_encoder = ConditionEncoder(32)
        _, attn_mask = cond_encoder._get_masks({"conditions": condition})
        # A second, random mask with the same shape — pooled output must differ
        # from the one produced under the encoder-derived mask.
        random_mask = jax.random.bernoulli(mask_rng, 0.5, attn_mask.shape).astype(jnp.int32)
        agg = agg()
        variables = agg.init(init_rng, condition, random_mask, training=True)
        out = agg.apply(variables, condition, attn_mask, training=True)
        out_rand = agg.apply(variables, condition, random_mask, training=True)
        # Output dim equals the input token dim (7) for TokenAttentionPooling;
        # SeedAttentionPooling projects to 64 by default.
        # BUGFIX: the original `assert x == 7 if cond else 64` parsed as
        # `assert (x == 7) if cond else 64`, so the Seed branch asserted the
        # truthy constant 64 and could never fail.
        expected_dim = 7 if isinstance(agg, TokenAttentionPooling) else 64
        assert out.shape == (2, expected_dim)
        assert out_rand.shape == (2, expected_dim)
        # Differing masks must produce measurably different pooled outputs.
        assert not jnp.allclose(out[0], out_rand[0], atol=1e-6)
        assert not jnp.allclose(out[1], out_rand[1], atol=1e-6)