Skip to content

Commit d8ae659

Browse files
authored
Merge pull request #241 from theislab/fix/tokenattention
fix TokenAttention
2 parents 70eec9a + f3c840c commit d8ae659

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ optional-dependencies.docs = [
6363
"ipython",
6464
"myst-nb>=1.1",
6565
"pandas",
66-
"scvi-tools",
66+
"scvi-tools>=1.3.1",
6767
"setuptools", # Until pybtex >0.23.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/
6868
"sphinx>=8",
6969
"sphinx-autodoc-typehints",
@@ -78,7 +78,7 @@ optional-dependencies.embedding = [
7878
"transformers",
7979
]
8080
optional-dependencies.external = [
81-
"scvi-tools",
81+
"scvi-tools>=1.3.1",
8282
]
8383
optional-dependencies.pp = [
8484
"pertpy",

src/cellflow/networks/_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,11 @@ def __call__(
464464
token_shape = (len(x), 1)
465465
class_token = nn.Embed(num_embeddings=1, features=x.shape[-1])(jnp.int32(jnp.zeros(token_shape)))
466466
z = jnp.concatenate((class_token, x), axis=-2)
467-
token_mask = jnp.zeros((x.shape[0], 1, x.shape[1] + 1, x.shape[1] + 1))
468-
token_mask = token_mask.at[:, :, 0, :].set(1)
469-
token_mask = token_mask.at[:, :, :, 0].set(1)
467+
token_mask = jnp.ones((x.shape[0], 1, x.shape[1] + 1, x.shape[1] + 1))
470468
token_mask = token_mask.at[:, :, 1:, 1:].set(mask)
469+
cls_token_to_data = mask[0, 0, :, :].sum(axis=0) > 0
470+
token_mask = token_mask.at[:, :, 0, 1:].set(cls_token_to_data)
471+
token_mask = token_mask.at[:, :, 1:, 0].set(cls_token_to_data)
471472

472473
# attention
473474
attention = nn.MultiHeadDotProductAttention(

tests/networks/test_aggregators.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import pytest
4+
5+
from cellflow.networks._set_encoders import ConditionEncoder
6+
from cellflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling
7+
8+
9+
class TestAggregator:
    """Tests for the attention-based pooling aggregators."""

    @pytest.mark.parametrize("agg", [TokenAttentionPooling, SeedAttentionPooling])
    def test_mask_impact_on_TokenAttentionPooling(self, agg):
        """Check that the attention mask changes the pooled output.

        The aggregator is initialised once and then applied twice to the same
        conditions: once with the mask returned by
        ``ConditionEncoder._get_masks`` and once with a random Bernoulli mask.
        Both outputs must have the expected shape and must differ from each
        other for every batch element.
        """
        rng = jax.random.PRNGKey(0)
        init_rng, mask_rng = jax.random.split(rng, 2)

        # Three random set elements per sample plus one all-zero element
        # appended along axis 1 (presumably treated as padding by
        # ``_get_masks`` -- verify against ConditionEncoder).
        # NOTE(review): ``rng`` is reused after being split; kept as-is so the
        # drawn values do not change.
        condition = jax.random.normal(rng, (2, 3, 7))
        condition = jnp.concatenate((condition, jnp.zeros((2, 1, 7))), axis=1)

        cond_encoder = ConditionEncoder(32)
        _, attn_mask = cond_encoder._get_masks({"conditions": condition})
        random_mask = jax.random.bernoulli(mask_rng, 0.5, attn_mask.shape).astype(jnp.int32)

        module = agg()
        variables = module.init(init_rng, condition, random_mask, training=True)
        out = module.apply(variables, condition, attn_mask, training=True)
        out_rand = module.apply(variables, condition, random_mask, training=True)

        # output dim = input dim for TokenAttentionPooling, output dim = 64 by
        # default in SeedAttentionPooling.
        # The expected width is computed up front: writing
        # ``assert a == 7 if cond else 64`` parses as
        # ``assert ((a == 7) if cond else 64)``, which always passes in the
        # ``else`` branch because 64 is truthy.
        expected_dim = 7 if isinstance(module, TokenAttentionPooling) else 64

        assert out.shape[0] == 2
        assert out.shape[1] == expected_dim
        assert out_rand.shape[0] == 2
        assert out_rand.shape[1] == expected_dim

        # A different mask must lead to a different pooled representation.
        assert not jnp.allclose(out[0], out_rand[0], atol=1e-6)
        assert not jnp.allclose(out[1], out_rand[1], atol=1e-6)

0 commit comments

Comments
 (0)