Add Esm #2244

Open · wants to merge 23 commits into master

33 changes: 33 additions & 0 deletions esm2_t6_8M/assets/tokenizer/vocabulary.txt
@@ -0,0 +1,33 @@
<cls>
<pad>
<eos>
<unk>
L
A
G
V
S
E
R
T
I
D
P
K
Q
N
F
Y
M
H
W
C
X
B
U
Z
O
.
-
<null_1>
<mask>
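
The vocabulary is one token per line, and ids appear to follow the usual keras_hub convention of file order (so `<cls>` is 0, `<pad>` is 1, `<mask>` is 32). A minimal sketch of loading it, using the path added by this PR:

# Token ids are line indices: "<cls>" -> 0, "<pad>" -> 1, ..., "<mask>" -> 32.
with open("esm2_t6_8M/assets/tokenizer/vocabulary.txt") as f:
    vocab = [line.rstrip("\n") for line in f]
token_to_id = {token: i for i, token in enumerate(vocab)}

assert len(vocab) == 33           # matches "vocabulary_size": 33 in config.json
assert token_to_id["<pad>"] == 1  # matches "pad_token_id": 1 in config.json
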
23 changes: 23 additions & 0 deletions esm2_t6_8M/config.json
@@ -0,0 +1,23 @@
{
"module": "keras_hub.src.models.esm.esm_backbone",
"class_name": "ESMBackbone",
"config": {
"name": "esm_backbone",
"trainable": true,
"vocabulary_size": 33,
"num_layers": 6,
"num_heads": 20,
"hidden_dim": 320,
"intermediate_dim": 1280,
"dropout": 0.0,
"max_wavelength": 10000,
"use_bias": true,
"activation": "gelu",
"layer_norm_eps": 1e-05,
"use_pre_layer_norm": false,
"position_embedding_type": "rotary",
"max_sequence_length": 1026,
"pad_token_id": 1
},
"registered_name": "keras_hub>ESMBackbone"
}
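
As a sanity check, the same hyperparameters can be passed straight to the backbone constructor. A sketch, assuming the usual `keras_hub.models` public namespace for the classes exported in this PR:

import keras_hub

backbone = keras_hub.models.ESMBackbone(
    vocabulary_size=33,
    num_layers=6,
    num_heads=20,
    hidden_dim=320,
    intermediate_dim=1280,
    dropout=0.0,
    max_wavelength=10000,
    use_bias=True,
    activation="gelu",
    layer_norm_eps=1e-05,
    use_pre_layer_norm=False,
    position_embedding_type="rotary",
    max_sequence_length=1026,
    pad_token_id=1,
)
backbone.summary()  # should report roughly 7.4M parameters (see metadata.json)
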
10 changes: 10 additions & 0 deletions esm2_t6_8M/metadata.json
@@ -0,0 +1,10 @@
{
"keras_version": "3.10.0",
"keras_hub_version": "0.21.0.dev0",
"parameter_count": 7408960,
"date_saved": "2025-06-17@15:50:50",
"tasks": [
"MaskedLM",
"TextClassifier"
]
}
Binary file added esm2_t6_8M/model.weights.h5 (not shown)
51 changes: 51 additions & 0 deletions esm2_t6_8M/preprocessor.json
@@ -0,0 +1,51 @@
{
"module": "keras_hub.src.models.esm.esm_masked_plm_preprocessor",
"class_name": "ESMMaskedPLMPreprocessor",
"config": {
"name": "esm_masked_plm_preprocessor_1",
"trainable": true,
"dtype": {
"module": "keras",
"class_name": "DTypePolicy",
"config": {
"name": "float32"
},
"registered_name": null
},
"tokenizer": {
"module": "keras_hub.src.models.esm.esm_tokenizer",
"class_name": "ESMTokenizer",
"config": {
"name": "esm_tokenizer",
"trainable": true,
"dtype": {
"module": "keras",
"class_name": "DTypePolicy",
"config": {
"name": "int32"
},
"registered_name": null
},
"config_file": "tokenizer.json",
"vocabulary": null,
"sequence_length": null,
"lowercase": false,
"strip_accents": false,
"split": true,
"suffix_indicator": "##",
"oov_token": "<unk>",
"special_tokens": null,
"special_tokens_in_strings": false
},
"registered_name": "keras_hub>ESMTokenizer"
},
"config_file": "preprocessor.json",
"sequence_length": 512,
"truncate": "round_robin",
"mask_selection_rate": 0.15,
"mask_selection_length": 96,
"mask_token_rate": 0.8,
"random_token_rate": 0.1
},
"registered_name": "keras_hub>ESMMaskedPLMPreprocessor"
}
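
A sketch of building the same preprocessor by hand. The keyword arguments mirror the JSON above; the tokenizer's `vocabulary` file-path argument and the `(x, y, sample_weight)` output follow the usual keras_hub WordPiece / masked-LM conventions and are assumptions here:

import keras_hub

tokenizer = keras_hub.models.ESMTokenizer(
    vocabulary="esm2_t6_8M/assets/tokenizer/vocabulary.txt",
)
preprocessor = keras_hub.models.ESMMaskedPLMPreprocessor(
    tokenizer=tokenizer,
    sequence_length=512,       # pad/truncate to 512 tokens
    mask_selection_rate=0.15,  # select ~15% of tokens for masking
    mask_selection_length=96,  # at most 96 masked positions per sequence
    mask_token_rate=0.8,       # 80% of selections become <mask>
    random_token_rate=0.1,     # 10% become a random vocabulary token
)
# Raw protein strings in, (features, labels, sample weights) out.
x, y, sw = preprocessor(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
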
82 changes: 82 additions & 0 deletions esm2_t6_8M/task.json
@@ -0,0 +1,82 @@
{
"module": "keras_hub.src.models.esm.esm_masked_plm",
"class_name": "ESMMaskedPLM",
"config": {
"backbone": {
"module": "keras_hub.src.models.esm.esm_backbone",
"class_name": "ESMBackbone",
"config": {
"name": "esm_backbone",
"trainable": true,
"vocabulary_size": 33,
"num_layers": 6,
"num_heads": 20,
"hidden_dim": 320,
"intermediate_dim": 1280,
"dropout": 0.0,
"max_wavelength": 10000,
"use_bias": true,
"activation": "gelu",
"layer_norm_eps": 1e-05,
"use_pre_layer_norm": false,
"position_embedding_type": "rotary",
"max_sequence_length": 1026,
"pad_token_id": 1
},
"registered_name": "keras_hub>ESMBackbone"
},
"preprocessor": {
"module": "keras_hub.src.models.esm.esm_masked_plm_preprocessor",
"class_name": "ESMMaskedPLMPreprocessor",
"config": {
"name": "esm_masked_plm_preprocessor_1",
"trainable": true,
"dtype": {
"module": "keras",
"class_name": "DTypePolicy",
"config": {
"name": "float32"
},
"registered_name": null
},
"tokenizer": {
"module": "keras_hub.src.models.esm.esm_tokenizer",
"class_name": "ESMTokenizer",
"config": {
"name": "esm_tokenizer",
"trainable": true,
"dtype": {
"module": "keras",
"class_name": "DTypePolicy",
"config": {
"name": "int32"
},
"registered_name": null
},
"config_file": "tokenizer.json",
"vocabulary": null,
"sequence_length": null,
"lowercase": false,
"strip_accents": false,
"split": true,
"suffix_indicator": "##",
"oov_token": "<unk>",
"special_tokens": null,
"special_tokens_in_strings": false
},
"registered_name": "keras_hub>ESMTokenizer"
},
"config_file": "preprocessor.json",
"sequence_length": 512,
"truncate": "round_robin",
"mask_selection_rate": 0.15,
"mask_selection_length": 96,
"mask_token_rate": 0.8,
"random_token_rate": 0.1
},
"registered_name": "keras_hub>ESMMaskedPLMPreprocessor"
},
"name": "esm_masked_plm"
},
"registered_name": "keras_hub>ESMMaskedPLM"
}
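
The task file simply nests the backbone and preprocessor configs shown above. A sketch of loading it, assuming `from_preset` accepts a local preset directory as it does for other keras_hub models:

import keras_hub

# Loads backbone, preprocessor, and task head from the files in this preset.
masked_plm = keras_hub.models.ESMMaskedPLM.from_preset("./esm2_t6_8M")
masked_plm.summary()

With the preprocessor attached, calling `fit` on raw protein strings should then follow the standard keras_hub MaskedLM workflow.
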
Binary file added esm2_t6_8M/task.weights.h5 (not shown)
27 changes: 27 additions & 0 deletions esm2_t6_8M/tokenizer.json
@@ -0,0 +1,27 @@
{
"module": "keras_hub.src.models.esm.esm_tokenizer",
"class_name": "ESMTokenizer",
"config": {
"name": "esm_tokenizer",
"trainable": true,
"dtype": {
"module": "keras",
"class_name": "DTypePolicy",
"config": {
"name": "int32"
},
"registered_name": null
},
"config_file": "tokenizer.json",
"vocabulary": null,
"sequence_length": null,
"lowercase": false,
"strip_accents": false,
"split": true,
"suffix_indicator": "##",
"oov_token": "<unk>",
"special_tokens": null,
"special_tokens_in_strings": false
},
"registered_name": "keras_hub>ESMTokenizer"
}
16 changes: 16 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -189,6 +189,22 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone
from keras_hub.src.models.esm.esm_classifier import (
ESMProteinClassifier as ESMProteinClassifier,
)
from keras_hub.src.models.esm.esm_classifier_preprocessor import (
ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor,
)
from keras_hub.src.models.esm.esm_masked_plm import (
ESMMaskedPLM as ESM2MaskedPLM,
)
from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM
from keras_hub.src.models.esm.esm_masked_plm_preprocessor import (
ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_backbone import (
FNetBackbone as FNetBackbone,
)
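
The hunk above exports each class under both an `ESM` and an `ESM2` name. A quick sketch of the aliasing, assuming the usual `keras_hub.models` re-export of `keras_hub/api/models`:

import keras_hub

# Both spellings resolve to the same class objects.
assert keras_hub.models.ESM2Backbone is keras_hub.models.ESMBackbone
assert keras_hub.models.ESM2MaskedPLM is keras_hub.models.ESMMaskedPLM
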
1 change: 1 addition & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -28,6 +28,7 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_tokenizer import (
FNetTokenizer as FNetTokenizer,
)
Empty file.
95 changes: 95 additions & 0 deletions keras_hub/src/models/esm/esm_attention.py
@@ -0,0 +1,95 @@
import keras
from keras import ops
from packaging import version

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
RoformerAttention,
)


class ESMRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_embedding(self, x, position=1):
dim = x.shape[-1]
inv_freq = self.scaling_factor / (
self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
)
t = ops.arange(x.shape[position], dtype=x.dtype)
freqs = ops.outer(t, inv_freq)
emb = ops.concatenate((freqs, freqs), axis=-1)

cos_emb = ops.cos(emb)[None, :, None, :]
sin_emb = ops.sin(emb)[None, :, None, :]
return cos_emb, sin_emb

def call(self, q, k, position=1):
cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position)

return (
self.apply_rotary_pos_emb(q, cos_emb, sin_emb),
self.apply_rotary_pos_emb(k, cos_emb, sin_emb),
)

def rotate_half(self, x):
x1, x2 = ops.split(x, 2, -1)
return ops.concatenate((-x2, x1), axis=-1)

def apply_rotary_pos_emb(self, x, cos, sin):
cos = cos[:, : x.shape[1], :, :]
sin = sin[:, : x.shape[1], :, :]

return (x * cos) + (self.rotate_half(x) * sin)


class EsmSelfAttention(RoformerAttention):
"""MultiHeadAttention by ESM2

Referred to the implementation of HuggingFace.
In fact, this part of the calculation is exactly the same as RoFormer.
Only the calculation of the rotary part is different.
"""

def __init__(self, use_rotary=True, **kwargs):
super().__init__(**kwargs)
self.use_rotary = use_rotary

def build(self, input_shape):
super().build(input_shape)
if self.use_rotary:
self.rotary_embedding_layer = ESMRotaryEmbedding(
max_wavelength=self.max_wavelength, dtype=self.dtype_policy
)
self.rotary_embedding_layer.build([])

def call(self, x, attention_mask=None):
qw = self.q_dense(x)
kw = self.k_dense(x)
vw = self.v_dense(x)

b, s = ops.shape(qw)[:2]
qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
vw = ops.reshape(vw, (b, s, self.heads, self.head_size))

if self.use_rotary:
qw, kw = self.rotary_embedding_layer(qw, kw)
        if version.parse(keras.__version__) < version.parse("3.6"):
            raise ImportError(
                "`EsmSelfAttention` requires Keras version >= 3.6. "
                f"Received Keras version {keras.__version__}."
            )
flash_attention = keras.config.is_flash_attention_enabled()
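        # Reshape the padding mask to rank 4 for `ops.dot_product_attention`;
        # on the torch backend it is additionally expanded into an explicit
        # (batch, 1, seq, seq) mask below.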
attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
if keras.config.backend() == "torch":
attention_mask = ops.repeat(attention_mask, s, -1)
attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
o = ops.dot_product_attention(
qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
)
return self.o_dense(ops.reshape(o, [b, s, -1]))

def get_config(self):
config = super().get_config()
config.update(
{
"use_rotary": self.use_rotary,
}
)
return config
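
A small sketch exercising `ESMRotaryEmbedding` on dummy query/key tensors; shapes are illustrative and the import assumes the module path added by this PR:

import keras

from keras_hub.src.models.esm.esm_attention import ESMRotaryEmbedding

# (batch, seq_len, num_heads, head_dim): the layout EsmSelfAttention.call
# produces after reshaping the dense projections.
q = keras.random.normal((2, 8, 4, 16))
k = keras.random.normal((2, 8, 4, 16))

rope = ESMRotaryEmbedding(max_wavelength=10000)
q_rot, k_rot = rope(q, k)
assert q_rot.shape == q.shape  # rotation leaves the tensor shape unchanged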