Add Esm #2244 (Merged)

Changes from 14 of 33 commits
16 changes: 16 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -189,6 +189,22 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone
from keras_hub.src.models.esm.esm_classifier import (
ESMProteinClassifier as ESMProteinClassifier,
)
from keras_hub.src.models.esm.esm_classifier_preprocessor import (
ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor,
)
from keras_hub.src.models.esm.esm_masked_plm import (
ESMMaskedPLM as ESM2MaskedPLM,
)
from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM

Review comment (medium):
This import of ESMMaskedPLM is redundant. The name ESMMaskedPLM is already available from the import on lines 200-202. Removing this line will improve code clarity. Even though this file is autogenerated, it's good practice to address such issues in the source generator if possible.

from keras_hub.src.models.esm.esm_masked_plm_preprocessor import (
ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_backbone import (
FNetBackbone as FNetBackbone,
)
1 change: 1 addition & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -28,6 +28,7 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_tokenizer import (
FNetTokenizer as FNetTokenizer,
)
Empty file.
94 changes: 94 additions & 0 deletions keras_hub/src/models/esm/esm_attention.py
@@ -0,0 +1,94 @@
import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
RoformerAttention,
)


class ESMRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_embedding(self, x, position=1):
dim = x.shape[-1]
inv_freq = self.scaling_factor / (
self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
)
t = ops.arange(x.shape[position], dtype=x.dtype)
freqs = ops.outer(t, inv_freq)
emb = ops.concatenate((freqs, freqs), axis=-1)

cos_emb = ops.cos(emb)[None, :, None, :]
sin_emb = ops.sin(emb)[None, :, None, :]
return cos_emb, sin_emb

def call(self, q, k, position=1):
cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position)

return (
self.apply_rotary_pos_emb(q, cos_emb, sin_emb),
self.apply_rotary_pos_emb(k, cos_emb, sin_emb),
)

def rotate_half(self, x):
x1, x2 = ops.split(x, 2, -1)
return ops.concatenate((-x2, x1), axis=-1)

def apply_rotary_pos_emb(self, x, cos, sin):
cos = cos[:, : x.shape[1], :, :]
sin = sin[:, : x.shape[1], :, :]

return (x * cos) + (self.rotate_half(x) * sin)


class EsmSelfAttention(RoformerAttention):
"""MultiHeadAttention by ESM2

Referred to the implementation of HuggingFace.
In fact, this part of the calculation is exactly the same as RoFormer.
Only the calculation of the rotary part is different.
"""

def __init__(self, use_rotary=True, **kwargs):
super().__init__(**kwargs)
self.use_rotary = use_rotary

def build(self, input_shape):
super().build(input_shape)
if self.use_rotary:
self.rotary_embedding_layer = ESMRotaryEmbedding(
max_wavelength=self.max_wavelength, dtype=self.dtype_policy
)
self.rotary_embedding_layer.build([])

def call(self, x, attention_mask=None):
qw = self.q_dense(x)
kw = self.k_dense(x)
vw = self.v_dense(x)

b, s = ops.shape(qw)[:2]
qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
vw = ops.reshape(vw, (b, s, self.heads, self.head_size))

if self.use_rotary:
qw, kw = self.rotary_embedding_layer(qw, kw)
if keras.__version__ < "3.6":
raise ("Please make sure your Keras version is >=3.6.")

Review comment (critical):
Raising a string or a tuple does not work as intended in Python 3 and will result in a TypeError. You should raise an instance of an exception class, such as ValueError.

Suggested change:
    raise ValueError("Please make sure your Keras version is >=3.6.")
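For illustration only (not part of the PR), a minimal sketch of the runtime behavior being flagged here:

```python
# Not from the PR: raising a bare string fails in Python 3.
try:
    raise ("Please make sure your Keras version is >=3.6.")  # raises TypeError
except TypeError as err:
    print(err)  # "exceptions must derive from BaseException"

# The suggested fix raises an actual exception instance instead:
# raise ValueError("Please make sure your Keras version is >=3.6.")
```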

flash_attention = keras.config.is_flash_attention_enabled()
attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
if keras.config.backend() == "torch":
attention_mask = ops.repeat(attention_mask, s, -1)
attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
o = ops.dot_product_attention(
qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
)
return self.o_dense(ops.reshape(o, [b, s, -1]))

def get_config(self):
config = super().get_config()
config.update(
{
"use_rotary": self.use_rotary,
}
)
return config
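As a side note (not part of the PR), the rotate_half trick above splits the last dimension in half and recombines it as (-x2, x1), which apply_rotary_pos_emb then mixes with the cos/sin embeddings. A minimal sketch with keras.ops on a toy tensor:

```python
# Not from the PR: a toy demonstration of rotate_half as defined above.
from keras import ops

x = ops.arange(8, dtype="float32")              # [0, 1, 2, 3, 4, 5, 6, 7]
x1, x2 = ops.split(x, 2, axis=-1)               # [0, 1, 2, 3] and [4, 5, 6, 7]
rotated = ops.concatenate((-x2, x1), axis=-1)   # [-4, -5, -6, -7, 0, 1, 2, 3]
print(ops.convert_to_numpy(rotated))
```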
221 changes: 221 additions & 0 deletions keras_hub/src/models/esm/esm_backbone.py
@@ -0,0 +1,221 @@
import keras
from keras import activations

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.esm.esm_encoder import ESMEncoder


def esm2_kernel_initializer(stddev=0.02):
return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export(
["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"]
)
class ESMBackbone(Backbone):
"""A ESM2 and ESM encoder network.

This class implements a bi-directional Transformer-based encoder as
described in ["Roformer"](https://github.com/facebookresearch/esm).

Review comment (medium):
The docstring mentions "Roformer" but links to the ESM repository. To avoid confusion, the link text should be updated to "ESM" to match the model being implemented.

Suggested change:
-    described in ["Roformer"](https://github.com/facebookresearch/esm).
+    described in ["ESM"](https://github.com/facebookresearch/esm).


The default constructor gives a fully customizable, randomly initialized
ESM2 encoder with any number of layers, heads, and embed dim.To
load preset architectures and weights, use the `from_preset()` constructor.


Args:
Reviewer comment (collaborator):
Add "Defaults to" in the arg descriptions wherever you're using default values. The max_wavelength arg detail is missing.

Reviewer comment (collaborator):
Still, the activation and max_wavelength descriptions are missing!

Reviewer comment (collaborator):
Add an arg description for pad_token_id as well.

vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
Reviewer comment (collaborator):
Defaults are missing for the activation, max_wavelength, and pad_token_id arguments.

num_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The size of the transformer encoding and pooler layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.
Defaults to 0.1
layer_norm_eps:bool.If true, then layer norm will be used before
entering the transformer block.
Since it's pre-norm, the default is false.
max_sequence_length: int. The maximum sequence length that this encoder
can consume. If None, `max_sequence_length` uses the value from
sequence length. This determines the variable shape for positional
embeddings.
position_embedding_type:esm1 use abs position embeding,esm2 use rope.
so this parameter is only except for absolute and rotary.
Reviewer comment (collaborator):
Change to: position_embedding_type: str. The position embedding type to use. One of "absolute" and "rotary". Use "absolute" for ESM1. Use "rotary" for ESM2. Defaults to "rotary".

Reviewer comment (@sachinprasadhs, Jun 2, 2025):
This still needs to be changed to:

position_embedding_type: str. The position embedding type to use. One of "absolute" and "rotary". Use "absolute" for ESM1. Use "rotary" for ESM2. Defaults to "rotary".

dtype: None or str or .keras.mixed_precision.DTypePolicy. The dtype to
use for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.

Examples:
```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
}

# Pretrained ESM2 encoder.
model = keras_hub.models.ESM2Backbone.from_preset('hf://facebook/esm2_t6_8M_UR50D')
model(input_data)

# Randomly initialized ESM2 encoder with a custom config.
model = keras_hub.models.ESM2Backbone(
vocabulary_size=30552,
num_layers=4,
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
head_size = 64,
Reviewer comment (collaborator):
Update this example: head_size is not a constructor argument and should not appear in the example. Have these examples been tested? Do they run successfully?

)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
use_bias=True,
activation="gelu",
dropout=0.1,
dtype=None,
max_sequence_length=1024,
max_wavelength=10000,
layer_norm_eps=1e-12,
emb_layer_norm_before=False,
Reviewer comment (collaborator):
Instead of emb_layer_norm_before, use something like use_pre_layer_norm.

Reviewer comment (@sachinprasadhs, Jun 2, 2025):
Pending change: emb_layer_norm_before --> use_pre_layer_norm.

position_embedding_type="rotary",
pad_token_id=0,
**kwargs,
):
if position_embedding_type not in (
"rotary",
"absolute",
):
raise ValueError(
'`position_embedding_type` must be either `"rotary"`, or '
'`"absolute"`. Received '
"position_embedding_type={position_embedding_type}."

Review comment (medium):
The string in the ValueError is intended to be an f-string to include the value of position_embedding_type, but it's missing the f prefix. This will result in the literal string {position_embedding_type} being part of the error message.

Suggested change:
-    "position_embedding_type={position_embedding_type}."
+    f"position_embedding_type={position_embedding_type}."

)
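Following the f-string note above, a quick illustration (not part of the PR) of what the missing f prefix does to the error message:

```python
# Not from the PR: without the `f` prefix the placeholder stays literal.
position_embedding_type = "sinusoidal"  # hypothetical invalid value

print("position_embedding_type={position_embedding_type}.")
# -> position_embedding_type={position_embedding_type}.

print(f"position_embedding_type={position_embedding_type}.")
# -> position_embedding_type=sinusoidal.
```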
head_size = hidden_dim // num_heads
# === Layers ===
self.token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=esm2_kernel_initializer(),
dtype=dtype,
name="token_embedding",
)
if position_embedding_type == "absolute":
self.position_embedding = PositionEmbedding(
initializer=esm2_kernel_initializer(),
sequence_length=max_sequence_length,
dtype=dtype,
name="position_embedding",
)
self.embeddings_add = keras.layers.Add(
dtype=dtype,
name="embeddings_add",
)

self.output_layer_norm = keras.layers.LayerNormalization(
epsilon=layer_norm_eps,
dtype=dtype,
name="output_layer_norm",
)
if emb_layer_norm_before:
self.emb_layer_norm = keras.layers.LayerNormalization(
epsilon=layer_norm_eps,
dtype=dtype,
name="emb_layer_norm",
)
self.transformer_layers = []
for i in range(num_layers):
layer = ESMEncoder(
heads=num_heads,
head_size=head_size,
intermediate_size=intermediate_dim,
use_bias=use_bias,
max_wavelength=max_wavelength,
dropout=dropout,
activation=activation,
kernel_initializer=esm2_kernel_initializer(),
layer_norm_eps=layer_norm_eps,
dtype=dtype,
use_rotary=position_embedding_type == "rotary",
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)

# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

attention_mask = keras.ops.not_equal(token_id_input, pad_token_id)

token_vector = self.token_embedding(token_id_input)
if position_embedding_type == "absolute":
position_vector = self.position_embedding(
token_vector, start_index=pad_token_id
)
x = self.embeddings_add([token_vector, position_vector])
else:
x = token_vector
if emb_layer_norm_before:
x = self.emb_layer_norm(x)
for transformer_layer in self.transformer_layers:
x = transformer_layer(x, attention_mask=attention_mask)
output = self.output_layer_norm(x)
super().__init__(
inputs={
"token_ids": token_id_input,
},
outputs=output,
dtype=dtype,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_wavelength = max_wavelength
self.head_size = head_size
self.dropout = dropout

Review comment (medium):
The self.dropout attribute is assigned twice in the __init__ method (lines 195 and 198). The second assignment is redundant and can be removed.

self.activation = activations.get(activation)
self.use_bias = use_bias
self.start_token_index = 0
self.layer_norm_eps = layer_norm_eps
self.max_sequence_length = max_sequence_length
self.emb_layer_norm_before = emb_layer_norm_before
self.position_embedding_type = position_embedding_type
self.pad_token_id = pad_token_id

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_wavelength": self.max_wavelength,
"use_bias": self.use_bias,
"activation": activations.serialize(self.activation),
"layer_norm_eps": self.layer_norm_eps,
"emb_layer_norm_before": self.emb_layer_norm_before,
"position_embedding_type": self.position_embedding_type,
"max_sequence_length": self.max_sequence_length,
"pad_token_id": self.pad_token_id,
}
)
return config
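Picking up the reviewer comment above that head_size is not a constructor argument (the backbone derives it as hidden_dim // num_heads), a corrected version of the docstring example might look like the sketch below. It is untested here, and the Hugging Face preset handle is simply the one quoted in the docstring.

```python
# A hedged sketch of the docstring example with `head_size` removed; untested.
import numpy as np

import keras_hub

input_data = {
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
}

# Randomly initialized ESM2 encoder with a custom config.
model = keras_hub.models.ESM2Backbone(
    vocabulary_size=30552,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
)
model(input_data)

# Pretrained ESM2 encoder (preset handle taken from the docstring above).
# model = keras_hub.models.ESM2Backbone.from_preset("hf://facebook/esm2_t6_8M_UR50D")
# model(input_data)
```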
30 changes: 30 additions & 0 deletions keras_hub/src/models/esm/esm_backbone_test.py
@@ -0,0 +1,30 @@
import keras
from keras import ops

from keras_hub.src.models.esm.esm_backbone import ESMBackbone
from keras_hub.src.tests.test_case import TestCase


class ESMBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_heads": 1,
"hidden_dim": 2,
"intermediate_dim": 4,
}
self.input_data = {
"token_ids": ops.ones((2, 5), dtype="int32"),
"segment_ids": ops.zeros((2, 5), dtype="int32"),
}

Review comment (medium):
The ESMBackbone model only accepts "token_ids" as input. The "segment_ids" key in self.input_data is not used by the model and should be removed to make the test accurately reflect the model's usage.

Suggested change:
    self.input_data = {
        "token_ids": ops.ones((2, 5), dtype="int32"),
    }


def test_backbone_basics(self):
if keras.__version__ < "3.6":
self.skipTest("Failing on keras lower version")
Reviewer comment (collaborator):
Are the tests failing due to some bug which was addressed in the 3.6 release?

self.run_backbone_test(
cls=ESMBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
)