-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented Coca architecture #2371
Open
VarunS1997
wants to merge
13
commits into
master
Choose a base branch
from
model-impl/CoCa
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
20cdf41
Implemented Coca architecture
VarunS1997 b8c0ba4
Minor clean-up
VarunS1997 bbe17c4
Fixed depth of decoders
VarunS1997 202526f
Updated config to match args
VarunS1997 367dd39
Moved layer definitions to build and added build calls for each layer
VarunS1997 80ea7d3
Unabbreviated 'contrastive' and 'captioning'
VarunS1997 3feacb6
Improved documentation and added output sizing to call(), also built …
VarunS1997 f15408f
Lowercased coca model directory and added to kokoro build
VarunS1997 960873f
Addressed comments by Matt; reformatted as well
VarunS1997 33cff54
Addressed comments related to attn pooling size, attn pooling name
VarunS1997 145d7b5
Wrote a test for coca saving and loading, which prompted some model c…
VarunS1997 e8623a9
Updated to functional model
VarunS1997 c9e1ec1
added size inputs for functional model
VarunS1997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.models.feature_extractor.coca.coca_model import CoCa | ||
from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from keras import layers | ||
|
||
|
||
class CoCaAttentionPooling(layers.Layer): | ||
"""Implements the Pooled Attention Layer used in "coca": Contrastive Captioners are Image-Text Foundation Models" | ||
(https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. | ||
|
||
Args: | ||
head_dim: The dimensions of the attention heads | ||
num_heads: The number of attention heads in the multi-headed attention layer | ||
""" | ||
|
||
def __init__(self, head_dim, num_heads, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.head_dim = head_dim | ||
self.num_heads = num_heads | ||
|
||
self.multi_head_attn = layers.MultiHeadAttention( | ||
self.num_heads, self.head_dim | ||
) | ||
|
||
self.layer_norm = layers.LayerNormalization() | ||
|
||
def build(self, input_shape): | ||
# super().build(input_shape) | ||
|
||
if(len(input_shape) < 2): | ||
raise ValueError("Building CoCa Attention Pooling requires input shape of shape (query_shape, value_shape)") | ||
|
||
query_shape = input_shape[0] | ||
value_shape = input_shape[1] | ||
|
||
self.multi_head_attn._build_from_signature(query_shape, value_shape) | ||
self.layer_norm.build(query_shape) | ||
|
||
def call(self, query, value): | ||
x = self.multi_head_attn(query, value) | ||
return self.layer_norm(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
# Copyright 2024 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
from keras import Sequential | ||
from keras_nlp.layers import RotaryEmbedding | ||
from keras_nlp.layers import TransformerDecoder | ||
|
||
from keras_cv.api_export import keras_cv_export | ||
from keras_cv.backend import ops | ||
from keras_cv.layers import TransformerEncoder as CVTransformerEncoder | ||
from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling | ||
from keras_cv.layers.vit_layers import PatchingAndEmbedding | ||
from keras_cv.models.task import Task | ||
|
||
|
||
@keras_cv_export(["keras_cv.models.coca"]) | ||
class CoCa(Task): | ||
"""Contrastive Captioner foundational model implementation. | ||
|
||
This model implements the "Contrastive Captioners are image-Text Foundational Models" by Yu, et al. | ||
(https://arxiv.org/pdf/2205.01917.pdf). In short, the coca model combines the ideas of Contrastive techniques | ||
such as CLIP, with Generative Captioning approaches such as SimVLM. | ||
|
||
The architecture of clip can be described as an Image Visual Transformer Encoder in parallel to self-attention-only | ||
Text Transformer Decoder, the outputs of both of which are passed into a multimodal Transformer Decoder. The | ||
contrastive loss from the ViT and the uni-modal Text Decoder is combined with a captioning loss from the multi-modal | ||
Decoder in order to produce the combined total loss. | ||
|
||
Basic Usage: | ||
```python | ||
|
||
images = ... # [batch_size, height, width, channel] | ||
text = ... # [batch_size, text_dim, sequence_length] | ||
|
||
coca = coca() | ||
|
||
# [batch_size, sequence_length, captioning_query_length] | ||
output = coca(images, text) | ||
``` | ||
|
||
All default arguments should be consistent with the original paper's details. | ||
|
||
Args: | ||
img_shape: The shape of a single image, typically expressed as [height, weight, channels] | ||
caption_shape: The shape of a single caption, typically expressed as [sequence_length, text_dim] | ||
img_patch_size: N of each NxN patch generated from linearization of the input images | ||
encoder_depth: number of image encoder blocks | ||
encoder_heads: number of attention heads used in each image encoder block | ||
encoder_intermediate_dim: dimensionality of the encoder blocks' intermediate representation (MLP dimensionality) | ||
encoder_width: dimensionality of the encoder's projection, consistent with wording used in coca paper. | ||
unimodal_decoder_depth: number of decoder blocks used for text self-attention/embedding | ||
multimodal_decoder_depth: number of decoder blocks used for image-text cross-attention and captioning | ||
decoder_intermediate_dim: dimensionality of the decoder blocks' MLPs | ||
unimodal_decoder_heads: number of attention heads in the unimodal decoder | ||
multimodal_decoder_heads: number of attention heads in the multimodal decoder | ||
contrastive_query_length: number of tokens to use to represent contrastive query | ||
captioning_query_length: number of tokens to use to represent captioning query | ||
contrastive_attn_heads: number of attention heads used for the contrastive attention pooling | ||
captioning_attn_heads: number of attention heads used for the captioning attention pooling | ||
contrastive_loss_weight: weighting of contrastive loss | ||
captioning_loss_weight: weighting of captioning loss | ||
""" | ||
|
||
def __init__( | ||
self, | ||
img_shape=(512, 512, 3), | ||
caption_shape = (10, 48), | ||
img_patch_size=18, | ||
encoder_depth=40, | ||
encoder_heads=16, | ||
encoder_intermediate_dim=6144, | ||
encoder_width=1408, | ||
unimodal_decoder_depth=18, | ||
multimodal_decoder_depth=18, | ||
decoder_intermediate_dim=5632, | ||
unimodal_decoder_heads=16, | ||
multimodal_decoder_heads=16, | ||
contrastive_query_length=1, | ||
captioning_query_length=256, | ||
contrastive_attn_heads=16, | ||
captioning_attn_heads=16, | ||
contrastive_loss_weight=0.5, | ||
captioning_loss_weight=0.5, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
|
||
# | ||
# Save Details | ||
# | ||
self.img_shape = img_shape | ||
self.caption_shape = caption_shape | ||
|
||
self.img_patch_size = img_patch_size | ||
|
||
self.encoder_depth = encoder_depth | ||
self.encoder_heads = encoder_heads | ||
self.encoder_width = encoder_width | ||
self.encoder_intermediate_dim = encoder_intermediate_dim | ||
|
||
self.unimodal_decoder_depth = unimodal_decoder_depth | ||
self.multimodal_decoder_depth = multimodal_decoder_depth | ||
self.decoder_intermediate_dim = decoder_intermediate_dim | ||
self.unimodal_decoder_heads = unimodal_decoder_heads | ||
self.multimodal_decoder_heads = multimodal_decoder_heads | ||
|
||
self.contrastive_query_length = contrastive_query_length | ||
self.contrastive_attn_heads = contrastive_attn_heads | ||
self.contrastive_loss_weight = contrastive_loss_weight | ||
|
||
self.captioning_query_length = captioning_query_length | ||
self.captioning_attn_heads = captioning_attn_heads | ||
self.captioning_loss_weight = captioning_loss_weight | ||
|
||
# | ||
# Layer Definitions | ||
# | ||
self.image_patching = PatchingAndEmbedding( | ||
self.encoder_width, self.img_patch_size | ||
) | ||
self.image_encoder = Sequential( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sequential might not work, the model will not build properly. |
||
[ | ||
CVTransformerEncoder( | ||
self.encoder_width, | ||
self.encoder_heads, | ||
self.encoder_intermediate_dim, | ||
) | ||
for _ in range(self.encoder_depth) | ||
] | ||
) | ||
|
||
self.text_embedding = RotaryEmbedding() | ||
self.unimodal_text_decoder = Sequential( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, sequential might not work well. Please double check. |
||
[ | ||
TransformerDecoder( | ||
self.decoder_intermediate_dim, self.unimodal_decoder_heads | ||
) | ||
for _ in range(self.unimodal_decoder_depth) | ||
] | ||
) | ||
self.multimodal_text_decoders = [ | ||
TransformerDecoder( | ||
self.decoder_intermediate_dim, self.multimodal_decoder_heads | ||
) | ||
for _ in range(self.multimodal_decoder_depth) | ||
] | ||
|
||
self.contrastive_attn_pooling = CoCaAttentionPooling( | ||
self.encoder_width, self.contrastive_attn_heads | ||
) | ||
self.captioning_attn_pooling = CoCaAttentionPooling( | ||
self.encoder_width, self.captioning_attn_heads | ||
) | ||
|
||
# These are learnable weights defined in build as per Keras recommendations | ||
self.contrastive_query = None | ||
self.captioning_query = None | ||
|
||
# | ||
# Functional Model | ||
# | ||
images = keras.Input(shape=self.img_shape, name="images") | ||
captions = keras.Input(shape=self.caption_shape, name="caption") | ||
|
||
img_encoding = self.image_patching( | ||
images | ||
) # [batch_size, img_patches_len+1, encoder_width] | ||
img_encoding = self.image_encoder( | ||
img_encoding | ||
) # [batch_size, img_patches_len+1, encoder_width] | ||
|
||
# Learnable Weights | ||
self.contrastive_query = self.add_weight( | ||
shape=( | ||
None, | ||
self.contrastive_query_length, | ||
self.encoder_width, | ||
), | ||
trainable=True, | ||
) | ||
self.captioning_query = self.add_weight( | ||
shape=( | ||
None, | ||
self.captioning_query_length, | ||
self.encoder_width, | ||
), | ||
trainable=True, | ||
) | ||
|
||
# This is for contrastive loss; [batch_size, contrastive_query_length, encoder_width] | ||
contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) | ||
|
||
# [batch_size, captioning_query_length, encoder_width] | ||
captioning_feature = self.captioning_attn_pooling( | ||
self.captioning_query, img_encoding | ||
) | ||
|
||
# Learnable CLs Token | ||
self.cls_token = self.add_weight( | ||
shape=(None, 1, self.caption_shape[-1]), name="cls_token", trainable=True | ||
) | ||
|
||
# [batch_size, sequence_length+1, text_dim] | ||
text_tokens = ops.concatenate(captions, self.cls_token) | ||
mask = ops.concatenate( | ||
(ops.ones_like(captions), ops.zeros_like(self.cls_token)) | ||
) | ||
|
||
# [batch_size, sequence_length+1, text_dim] | ||
embed_text = self.text_embedding(text_tokens) | ||
unimodal_out = self.unimodal_text_decoder( | ||
embed_text, attention_mask=mask | ||
) | ||
|
||
# [batch_size, sequence_length, captioning_query_length], notice we remove the CLs token | ||
multimodal_out = unimodal_out[:, :-1, :] | ||
for decoder in self.multimodal_text_decoders: | ||
multimodal_out = decoder( | ||
multimodal_out, | ||
encoder_sequence=captioning_feature, | ||
decoder_attention_mask=mask | ||
) | ||
|
||
super().__init__( | ||
inputs={ | ||
"images": images, | ||
"captions": captions, | ||
}, | ||
outputs={ | ||
"multimodal_out": multimodal_out, | ||
"contrastive_feature": contrastive_feature | ||
}, | ||
) | ||
|
||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"img_shape": self.img_shape, | ||
"caption_shape": self.caption_shape, | ||
"img_patch_size": self.img_patch_size, | ||
"encoder_depth": self.encoder_depth, | ||
"encoder_heads": self.encoder_heads, | ||
"encoder_width": self.encoder_width, | ||
"encoder_intermediate_dim": self.encoder_intermediate_dim, | ||
"unimodal_decoder_depth": self.unimodal_decoder_depth, | ||
"multimodal_decoder_depth": self.multimodal_decoder_depth, | ||
"decoder_intermediate_dim": self.decoder_intermediate_dim, | ||
"unimodal_decoder_heads": self.unimodal_decoder_heads, | ||
"multimodal_decoder_heads": self.multimodal_decoder_heads, | ||
"contrastive_query_length": self.contrastive_query_length, | ||
"contrastive_attn_heads": self.contrastive_attn_heads, | ||
"contrastive_loss_weight": self.contrastive_loss_weight, | ||
"captioning_query_length": self.captioning_query_length, | ||
"captioning_attn_heads": self.captioning_attn_heads, | ||
"captioning_loss_weight": self.captioning_loss_weight, | ||
} | ||
) | ||
return config | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
return cls(**config) | ||
|
||
def load_own_variables(self, store): | ||
print(store) | ||
super().load_own_variables(store) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import keras.saving | ||
import numpy as np | ||
import pytest | ||
import os | ||
|
||
from keras_cv.models.feature_extractor.coca import CoCa | ||
from keras_cv.tests.test_case import TestCase | ||
|
||
class CoCaTest(TestCase): | ||
|
||
@pytest.mark.large | ||
def test_coca_model_save(self): | ||
# TODO: Transformer encoder breaks if you have project dim < num heads | ||
model = CoCa() | ||
|
||
save_path = os.path.join(self.get_temp_dir(), "coca.keras") | ||
model.save(save_path) | ||
|
||
restored_model = keras.models.load_model(save_path, custom_objects={"CoCa": CoCa}) | ||
|
||
self.assertIsInstance(restored_model, CoCa) | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has to be changed to
Example:
since we follow onlyExample
orExamples:
as a standard format.