InstructorEmbedding/instructor.py: 105 changes (96 additions, 9 deletions)
@@ -4,14 +4,16 @@
import os
from collections import OrderedDict
from typing import Union

import copy
import warnings
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer
-from torch import Tensor, nn
+from torch import Tensor, device, nn
from tqdm.autonotebook import trange
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, T5EncoderModel, T5Config
from transformers.modeling_utils import ModuleUtilsMixin  # used in T5CustomStack.get_extended_attention_mask
from transformers.models.t5.modeling_t5 import T5Stack
from typing import Tuple


def batch_to_device(batch, target_device: str):
@@ -20,6 +22,73 @@ def batch_to_device(batch, target_device: str):
batch[key] = batch[key].to(target_device)
return batch

class T5CustomStack(T5Stack):
def __init__(self, config, embed_tokens=None):
super().__init__(config, embed_tokens)

def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.

Returns:
`torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
"""
if not (attention_mask.dim() == 2 and self.config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
input_shape, attention_mask, device
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.repeat(1, 1, extended_attention_mask.shape[-1], 1) # expand to an explicit [batch_size, 1, seq_length, seq_length] grid so individual (query, key) entries can be edited
extended_attention_mask[:, :, 0, 0] = 0 # the first query position no longer attends to the first key position, i.e. the prepended sentinel token ignores itself
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
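
For a concrete picture of what this override returns, here is a standalone sketch (toy values, not part of the diff) that replays the last four lines on a batch of one sequence of length 4 whose final position is padding:

import torch

attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # [batch, seq]
ext = attention_mask[:, None, None, :]                  # [batch, 1, 1, seq]
ext = ext.repeat(1, 1, ext.shape[-1], 1)                # [batch, 1, seq, seq]
ext[:, :, 0, 0] = 0                                     # first query ignores the first key
ext = (1.0 - ext) * -10000.0
# ext[0, 0] is now:
# [[-10000.,      0.,      0., -10000.],
#  [     0.,      0.,      0., -10000.],
#  [     0.,      0.,      0., -10000.],
#  [     0.,      0.,      0., -10000.]]

So every query position still masks out padding, and only the first position additionally masks itself.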

class T5CustomEncoderModel(T5EncoderModel):
def __init__(self, config: T5Config):
super().__init__(config)
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5CustomStack(encoder_config, self.shared)

class T5Custom(T5CustomEncoderModel):
def __init__(self, config: T5Config):
super().__init__(config)
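
T5CustomEncoderModel swaps the stock T5Stack for T5CustomStack, and T5Custom adds nothing on top of it. A toy instantiation (config values made up; relies on the classes and imports above) shows that the patched stack is what actually serves as the encoder:

cfg = T5Config(d_model=32, d_ff=64, d_kv=8, num_layers=2, num_heads=2, vocab_size=128)
model = T5Custom(cfg)
assert isinstance(model.encoder, T5CustomStack)  # every forward pass now builds the 4-D mask above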


class InstructorPooling(nn.Module):
"""Performs pooling (max or mean) on the token embeddings.
@@ -340,6 +409,7 @@ def forward(self, features):
instruction_mask = features["instruction_mask"]
output_states = self.auto_model(**input_features, return_dict=False)
output_tokens = output_states[0]
attention_mask = features["attention_mask"]
instruction_mask = features["instruction_mask"]
attention_mask = attention_mask * instruction_mask
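
The multiplication above is what keeps the instruction out of the sentence embedding: instruction positions are zeroed in the mask handed to the pooling layer, so only the tokens of the actual input text get pooled. A toy illustration (lengths made up; torch as imported above):

attention_mask   = torch.tensor([[1, 1, 1, 1, 1, 0]])  # 5 real tokens + 1 pad
instruction_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])  # first two tokens belong to the instruction
print(attention_mask * instruction_mask)               # tensor([[0, 0, 1, 1, 1, 0]])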
Expand Down Expand Up @@ -418,20 +488,22 @@ def tokenize(self, texts):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
-input_features = Instructor.prepare_input_features(
+input_features = self.prepare_input_features(
input_features, instruction_features
)
else:
raise ValueError("not support other modes")

output.update(input_features)
return output

def _load_t5_model(self, model_name_or_path, config, cache_dir):
"""Loads the encoder model from T5"""
T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
self.auto_model = T5Custom.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
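
With this override in place, any T5-based checkpoint handled by this module is loaded through the patched classes, and unexpected decoder weights are ignored. A rough usage sketch; the checkpoint name and its file layout are assumptions, not taken from this diff:

config = AutoConfig.from_pretrained("hkunlp/instructor-large")
model = T5Custom.from_pretrained("hkunlp/instructor-large", config=config)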


-class Instructor(SentenceTransformer):
-@staticmethod
def prepare_input_features(
-input_features, instruction_features, return_data_type: str = "pt"
+self, input_features, instruction_features, return_data_type: str = "pt"
):
if return_data_type == "np":
input_features["attention_mask"] = torch.from_numpy(
@@ -443,6 +515,18 @@ def prepare_input_features(

input_attention_mask_shape = input_features["attention_mask"].shape
instruction_attention_mask = instruction_features["attention_mask"]
input_ids = input_features["input_ids"]
attention_mask = input_features["attention_mask"]

# drop the first token (and its attention entry) and prepend the "<extra_id_0>" sentinel in its
# place, so the sequence length is unchanged and position 0 is always the sentinel token that
# T5CustomStack prevents from attending to itself
input_ids = input_ids[:, 1:]
attention_mask = attention_mask[:, 1:]
selfless_token_id = self.tokenizer.convert_tokens_to_ids(["<extra_id_0>"])[0]

selfless_token_id_column = torch.full((input_ids.size(0), 1), selfless_token_id)
selfless_token_attention_column = torch.full((input_ids.size(0), 1), 1)

input_ids = torch.cat((selfless_token_id_column, input_ids), dim=1)
attention_mask = torch.cat((selfless_token_attention_column, attention_mask), dim=1)

# reducing the attention length by 1 in order to omit the attention corresponding to the end_token
instruction_attention_mask = instruction_attention_mask[:, 1:]
@@ -469,6 +553,8 @@ def prepare_input_features(
# and not the instruction. This is achieved by inverting the
# attention_mask corresponding to the instruction.
expanded_instruction_attention_mask = 1 - expanded_instruction_attention_mask
input_features["input_ids"] = input_ids
input_features["attention_mask"] = attention_mask
input_features["instruction_mask"] = expanded_instruction_attention_mask
if return_data_type == "np":
input_features["attention_mask"] = input_features["attention_mask"].numpy()
@@ -477,6 +563,7 @@ def prepare_input_features(
].numpy()
return input_features
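
For reference, the net effect of the sentinel swap above on a single tokenized sequence (ids made up; the real code looks the sentinel id up from the tokenizer):

input_ids      = torch.tensor([[21, 22, 23, 24, 25, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
sentinel_id = 32099  # "<extra_id_0>" in the stock T5 vocabulary

input_ids = torch.cat((torch.full((1, 1), sentinel_id), input_ids[:, 1:]), dim=1)
attention_mask = torch.cat((torch.full((1, 1), 1), attention_mask[:, 1:]), dim=1)
# input_ids      -> [[32099, 22, 23, 24, 25, 1]]
# attention_mask -> [[1, 1, 1, 1, 1, 1]]  (length preserved)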

+class Instructor(SentenceTransformer):
def smart_batching_collate(self, batch):
num_texts = len(batch[0].texts)
texts = [[] for _ in range(num_texts)]
@@ -508,7 +595,7 @@ def smart_batching_collate(self, batch):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
-input_features = Instructor.prepare_input_features(
+input_features = self._first_module().prepare_input_features(
input_features, instruction_features
)
batched_input_features.append(input_features)