Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the SAM end-to-end model predict test #2385

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
import tensorflow as tf
import tree

if hasattr(keras, "src"):
keras_backend = keras.src.backend
else:
keras_backend = keras.backend

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config
Expand Down
109 changes: 60 additions & 49 deletions keras_cv/models/segmentation/segment_anything/sam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,55 +233,66 @@ def test_end_to_end_model_predict(self, dtype_policy):
with threading.Lock():
# We are changing the global dtype policy here but don't want any
# other tests to use that policy, so compute under a lock until
# we reset the global policy.
old_policy = getattr(
keras.mixed_precision, "dtype_policy", lambda: "float32"
)()
keras.mixed_precision.set_global_policy(dtype_policy)
model = SegmentAnythingModel(
backbone=self.image_encoder,
prompt_encoder=self.prompt_encoder,
mask_decoder=self.mask_decoder,
)

# We use box-only prompting for this test.
mask_prompts = self.get_prompts(1, "boxes")
inputs = {
"images": np.ones((1, 1024, 1024, 3)),
}
inputs.update(mask_prompts)

# Check the number of parameters
num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)

# Forward pass through the model
outputs = model.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]

# Check the output is equal to the one we expect if we
# run each component separately. This is to confirm that
# the graph is getting compiled correctly i.e. the jitted
# execution is equivalent to the eager execution.
features = self.image_encoder(inputs["images"])
outputs_ex = self.prompt_encoder(
{k: v for k, v in inputs.items() if k != "images"}
)
outputs_ex = self.mask_decoder(
{
"image_embeddings": features,
"image_pe": outputs_ex["dense_positional_embeddings"],
"sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
"dense_prompt_embeddings": outputs_ex["dense_embeddings"],
},
)
masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]

self.assertAllClose(masks, masks_ex, atol=1e-4)
self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)

# Reset the global policy
keras.mixed_precision.set_global_policy(old_policy)
# we reset the global policy. We also want to make sure even if
# the test fails, other tests remain unaffected.
try:
old_policy = getattr(
keras.mixed_precision, "dtype_policy", lambda: "float32"
)()
keras.mixed_precision.set_global_policy(dtype_policy)
model = SegmentAnythingModel(
backbone=self.image_encoder,
prompt_encoder=self.prompt_encoder,
mask_decoder=self.mask_decoder,
)

# We use box-only prompting for this test.
mask_prompts = self.get_prompts(1, "boxes")
inputs = {
"images": np.ones((1, 1024, 1024, 3)),
}
inputs.update(mask_prompts)

# Check the number of parameters
num_parameters = np.sum(
[np.prod(x.shape) for x in model.weights]
)
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)

# Forward pass through the model
outputs = model.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]

# Check the output is equal to the one we expect if we
# run each component separately. This is to confirm that
# the graph is getting compiled correctly i.e. the jitted
# execution is equivalent to the eager execution.
features = self.image_encoder(inputs["images"])
outputs_ex = self.prompt_encoder(
{k: v for k, v in inputs.items() if k != "images"}
)
outputs_ex = self.mask_decoder(
{
"image_embeddings": features,
"image_pe": outputs_ex["dense_positional_embeddings"],
"sparse_prompt_embeddings": outputs_ex[
"sparse_embeddings"
],
"dense_prompt_embeddings": outputs_ex[
"dense_embeddings"
],
},
)
masks_ex, iou_pred_ex = (
outputs_ex["masks"],
outputs_ex["iou_pred"],
)

self.assertAllClose(masks, masks_ex, atol=1e-4)
self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
finally:
# Reset the global policy
keras.mixed_precision.set_global_policy(old_policy)

@pytest.mark.extra_large
def test_end_to_end_model_save(self):
Expand Down
Loading