diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 039453af7..02ab91c38 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -133,11 +133,15 @@ def _encode_image(self, images, normalize=True):
         image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
         return image_latent, tokens_embs
 
-    def _encode_text(self, text, normalize=True, embed_cls=True):
+    def _encode_text(self, text, normalize=True, embed_cls=True, cache=None):
         text = text[:, :-1] if embed_cls else text # make space for CLS token
-        text_latent, token_emb = self.text(text)
+        if cache is not None:
+            text_latent, token_emb, attentions = self.text(text, cache)
+        else:
+            text_latent, token_emb = self.text(text)
+            attentions = None
         text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
-        return text_latent, token_emb
+        return text_latent, token_emb, attentions
 
     def encode_image(self, images, normalize=True):
         image_latent, _ = self._encode_image(images, normalize=normalize)
@@ -147,19 +151,22 @@ def encode_text(self, text, normalize=True, embed_cls=True):
         text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
         return text_latent
 
-    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
-        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None, cache=None):
+        #TODO: Fix encoder caching
+        text_latent, token_embs, text_attentions = self._encode_text(text, embed_cls=embed_cls, cache=None)
         if image_latent is None or image_embs is None:
             image_latent, image_embs = self._encode_image(image)
 
         # TODO: add assertion to avoid bugs?
         labels = text[:, -token_embs.shape[1]:]
-
-        logits = self.text_decoder(image_embs, token_embs)
+        logits, attentions, cross_attentions = self.text_decoder(image_embs, token_embs, cache=cache["dec"])
         return {
             "image_features": image_latent,
             "text_features": text_latent,
             "logits": logits,
+            "text_attentions": text_attentions,
+            "attentions": attentions,
+            "cross_attentions": cross_attentions,
             "labels": labels,
             "logit_scale": self.logit_scale.exp()
         }
@@ -182,7 +189,8 @@ def generate(
         min_seq_len=5,
         stopping_criteria=None,
         repetition_penalty=1.0,
-        fixed_output_length=False # if True output.shape == (batch_size, seq_len)
+        fixed_output_length=False, # if True output.shape == (batch_size, seq_len)
+        caching=False, # cache previously computed attentions
     ):
         # taking many ideas and components from HuggingFace GenerationMixin
         # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
@@ -220,6 +228,7 @@ def generate(
                     min_seq_len=min_seq_len,
                     stopping_criteria=stopping_criteria,
                     logit_processor=logit_processor,
+                    caching=caching
                 )
                 if fixed_output_length and output.shape[1] < seq_len:
                     return torch.cat(
@@ -252,11 +261,24 @@ def generate(
             cur_len = text.shape[1]
             self.eval()
             out = text
+
+            if caching:
+                cache = {"enc": [], "dec": {"self": [], "cross": []}}
+            else:
+                cache = {"enc": None, "dec": {"self": None, "cross": None}}
 
             while True:
                 x = out[:, -max_seq_len:]
                 cur_len = x.shape[1]
-                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
+
+                outputs = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False, cache=cache)
+
+                if caching:
+                    cache["enc"] = outputs["text_attentions"]
+                    cache["dec"]["self"] = outputs["attentions"]
+                    cache["dec"]["cross"] = outputs["cross_attentions"]
+
+                logits = outputs["logits"][:, -1]
                 mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                 sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
 
@@ -299,6 +321,7 @@ def _generate_beamsearch(
             stopping_criteria=None,
             logit_processor=None,
             logit_warper=None,
+            caching=True,
     ):
         device = image_inputs.device
         batch_size = image_inputs.shape[0]
@@ -338,6 +361,11 @@
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
+        if caching:
+            cache = {"enc": [], "dec": {"self": [], "cross": []}}
+        else:
+            cache = {"enc": None, "dec": {"self": None, "cross": None}}
+
         while True:
 
             # predicted tokens in cur_len step
@@ -353,9 +381,15 @@
                 model_inputs['text'],
                 embed_cls=False,
                 image_latent=image_latent,
-                image_embs=image_embs
+                image_embs=image_embs,
+                cache=cache
             )
 
+            if caching:
+                cache["enc"] = outputs["text_attentions"]
+                cache["dec"]["self"] = outputs["attentions"]
+                cache["dec"]["cross"] = outputs["cross_attentions"]
+
             for beam_group_idx in range(num_beam_groups):
                 group_start_idx = beam_group_idx * num_sub_beams
                 group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 4e0151017..ef1daf1a7 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 import math
-from typing import Callable, Optional, Sequence, Tuple
+from typing import Callable, Optional, Sequence, Tuple, List, Union
 
 import torch
 from torch import nn
@@ -220,28 +220,47 @@ def attention(
             k_x: Optional[torch.Tensor] = None,
             v_x: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
+            cache: Optional[torch.Tensor] = None,
     ):
+
         k_x = k_x if k_x is not None else q_x
         v_x = v_x if v_x is not None else q_x
 
+        if cache is not None:
+            q_x = q_x[-1:]
+            if attn_mask is not None:
+                if attn_mask.dim() == 2:
+                    attn_mask = attn_mask[-1:]
+                elif attn_mask.dim() == 3:
+                    attn_mask = attn_mask[:,-1:]
+
         attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
-        return self.attn(
+
+        out = self.attn(
             q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
         )[0]
+        if cache is not None:
+            out = torch.cat((cache, out), dim=0)
+
+        return out
+
 
     def forward(
             self,
             q_x: torch.Tensor,
             k_x: Optional[torch.Tensor] = None,
             v_x: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
+            cache: Optional[torch.Tensor] = None,
     ):
         k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
         v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
-
-        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
+
+        # print(q_x)
+        attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask, cache=cache)
+        x = q_x + self.ls_1(attn)
         x = x + self.ls_2(self.mlp(self.ln_2(x)))
-        return x
+        return x, attn
 
 
 class CustomResidualAttentionBlock(nn.Module):
@@ -310,13 +329,21 @@ def __init__(
     def get_cast_dtype(self) -> torch.dtype:
         return self.resblocks[0].mlp.c_fc.weight.dtype
 
-    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        for r in self.resblocks:
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, cache: Optional[List[torch.Tensor]] = None):
+        attentions = []
+        for i, r, in enumerate(self.resblocks):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
-                x = checkpoint(r, x, None, None, attn_mask)
+                x, _ = checkpoint(r, x, None, None, attn_mask, None)
+            elif cache is None or len(cache) != len(self.resblocks):
+                x, _ = r(x, attn_mask=attn_mask, cache=None)
             else:
-                x = r(x, attn_mask=attn_mask)
+                x, attn = r(x, attn_mask=attn_mask, cache=cache[i])
+                attentions.append(attn)
+
+        if cache is not None:
+            return x, attentions
+
         return x
 
 
@@ -497,7 +524,7 @@ def forward(self, x: torch.Tensor):
 
         if self.output_tokens:
             return pooled, tokens
-        
+
         return pooled
 
 
@@ -594,7 +621,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype):
     def _repeat(self, t, N: int):
         return t.reshape(1, 1, -1).repeat(N, 1, 1)
 
-    def forward(self, text):
+    def forward(self, text, cache=None):
         cast_dtype = self.transformer.get_cast_dtype()
         seq_len = text.shape[1]
 
@@ -608,7 +635,10 @@ def forward(self, text):
 
         x = x + self.positional_embedding[:seq_len].to(cast_dtype)
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x, attn_mask=attn_mask)
+        if cache is not None:
+            x, attentions = self.transformer(x, attn_mask=attn_mask, cache=cache)
+        else:
+            x = self.transformer(x, attn_mask=attn_mask)
         x = x.permute(1, 0, 2)  # LND -> NLD
 
         # x.shape = [batch_size, n_ctx, transformer.width]
@@ -623,8 +653,12 @@ def forward(self, text):
         if self.text_projection is not None:
             pooled = pooled @ self.text_projection
 
-        if self.output_tokens:
+        if self.output_tokens and cache is None:
             return pooled, tokens
+        elif self.output_tokens and cache is not None:
+            return pooled, tokens, attentions
+        elif cache is not None:
+            return pooled, attentions
 
         return pooled
 
@@ -697,25 +731,42 @@ def build_attention_mask(self):
         mask.triu_(1)  # zero out the lower diagonal
         return mask
 
-    def forward(self, image_embs, text_embs):
+    def forward(self, image_embs, text_embs, cache=None):
         text_embs = text_embs.permute(1, 0, 2)  # NLD -> LNDsq
         image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
         seq_len = text_embs.shape[0]
+        attentions = []
+        cross_attentions = []
+
+        valid_cache = True
+        if cache is None or cache["self"] is None or cache["self"] == []:
+            valid_cache = False
 
-        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
+        for i, (resblock, cross_attn) in enumerate(zip(self.resblocks, self.cross_attn)):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
                 text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
             else:
-                text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
-                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
+                # Could allow for only caching one or the other, but unsure when that
+                # would be beneficial over the alternatives
+                if not valid_cache:
+                    text_embs, attn = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len], cache=None)
+                    text_embs, x_attn = cross_attn(text_embs, k_x=image_embs, v_x=image_embs, cache=None)
+                else:
+                    text_embs, attn = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len], cache=cache["self"][i])
+                    text_embs, x_attn = cross_attn(text_embs, k_x=image_embs, v_x=image_embs, cache=cache["cross"][i])
+                attentions.append(attn)
+                cross_attentions.append(x_attn)
 
         x = text_embs.permute(1, 0, 2)  # LND -> NLD
         x = self.ln_final(x)
 
         if self.text_projection is not None:
             x = x @ self.text_projection
+
+        if cache is not None:
+            return x, attentions, cross_attentions
 
         return x
 
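
Usage sketch (not part of the patch): the snippet below shows how the new `caching` flag added to `CoCa.generate` would be exercised once this diff is applied. The model name and pretrained tag are only illustrative; any CoCa checkpoint exposed by `open_clip.create_model_and_transforms` should behave the same way. Because the cache reuses previously computed attention outputs instead of recomputing the whole prefix, decoding with `caching=True` is intended to reproduce the tokens of the uncached path, so comparing the two is a quick sanity check for the caching logic.

```python
import torch
import open_clip

# Illustrative names; swap in any CoCa model / pretrained tag known to open_clip.
model, _, transform = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2b_s13b_b90k"
)
model.eval()

# Stand-in for a preprocessed image batch; normally this comes from transform(PIL image).
image = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # caching=True makes generate() build the {"enc": ..., "dec": {"self": ..., "cross": ...}}
    # dict and feed each step's attention outputs back into the next forward pass;
    # caching=False keeps the original recompute-everything behaviour.
    cached = model.generate(image, generation_type="top_k", caching=True)
    uncached = model.generate(image, generation_type="top_k", caching=False)

# The cached and uncached paths are intended to return identical token sequences.
print(torch.equal(cached, uncached))
```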