diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index 167da7ad0b..a1f8e521f1 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -12,15 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras import tensorflow as tf import tree -if hasattr(keras, "src"): - keras_backend = keras.src.backend -else: - keras_backend = keras.backend - from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export from keras_cv.backend import config diff --git a/keras_cv/models/segmentation/segment_anything/sam_test.py b/keras_cv/models/segmentation/segment_anything/sam_test.py index 295355a716..8d5a163cb9 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_test.py +++ b/keras_cv/models/segmentation/segment_anything/sam_test.py @@ -233,55 +233,66 @@ def test_end_to_end_model_predict(self, dtype_policy): with threading.Lock(): # We are changing the global dtype policy here but don't want any # other tests to use that policy, so compute under a lock until - # we reset the global policy. - old_policy = getattr( - keras.mixed_precision, "dtype_policy", lambda: "float32" - )() - keras.mixed_precision.set_global_policy(dtype_policy) - model = SegmentAnythingModel( - backbone=self.image_encoder, - prompt_encoder=self.prompt_encoder, - mask_decoder=self.mask_decoder, - ) - - # We use box-only prompting for this test. - mask_prompts = self.get_prompts(1, "boxes") - inputs = { - "images": np.ones((1, 1024, 1024, 3)), - } - inputs.update(mask_prompts) - - # Check the number of parameters - num_parameters = np.sum([np.prod(x.shape) for x in model.weights]) - self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) - - # Forward pass through the model - outputs = model.predict(inputs) - masks, iou_pred = outputs["masks"], outputs["iou_pred"] - - # Check the output is equal to the one we expect if we - # run each component separately. This is to confirm that - # the graph is getting compiled correctly i.e. the jitted - # execution is equivalent to the eager execution. - features = self.image_encoder(inputs["images"]) - outputs_ex = self.prompt_encoder( - {k: v for k, v in inputs.items() if k != "images"} - ) - outputs_ex = self.mask_decoder( - { - "image_embeddings": features, - "image_pe": outputs_ex["dense_positional_embeddings"], - "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"], - "dense_prompt_embeddings": outputs_ex["dense_embeddings"], - }, - ) - masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"] - - self.assertAllClose(masks, masks_ex, atol=1e-4) - self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4) - - # Reset the global policy - keras.mixed_precision.set_global_policy(old_policy) + # we reset the global policy. We also want to make sure even if + # the test fails, other tests remain unaffected. + try: + old_policy = getattr( + keras.mixed_precision, "dtype_policy", lambda: "float32" + )() + keras.mixed_precision.set_global_policy(dtype_policy) + model = SegmentAnythingModel( + backbone=self.image_encoder, + prompt_encoder=self.prompt_encoder, + mask_decoder=self.mask_decoder, + ) + + # We use box-only prompting for this test. + mask_prompts = self.get_prompts(1, "boxes") + inputs = { + "images": np.ones((1, 1024, 1024, 3)), + } + inputs.update(mask_prompts) + + # Check the number of parameters + num_parameters = np.sum( + [np.prod(x.shape) for x in model.weights] + ) + self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340) + + # Forward pass through the model + outputs = model.predict(inputs) + masks, iou_pred = outputs["masks"], outputs["iou_pred"] + + # Check the output is equal to the one we expect if we + # run each component separately. This is to confirm that + # the graph is getting compiled correctly i.e. the jitted + # execution is equivalent to the eager execution. + features = self.image_encoder(inputs["images"]) + outputs_ex = self.prompt_encoder( + {k: v for k, v in inputs.items() if k != "images"} + ) + outputs_ex = self.mask_decoder( + { + "image_embeddings": features, + "image_pe": outputs_ex["dense_positional_embeddings"], + "sparse_prompt_embeddings": outputs_ex[ + "sparse_embeddings" + ], + "dense_prompt_embeddings": outputs_ex[ + "dense_embeddings" + ], + }, + ) + masks_ex, iou_pred_ex = ( + outputs_ex["masks"], + outputs_ex["iou_pred"], + ) + + self.assertAllClose(masks, masks_ex, atol=1e-4) + self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4) + finally: + # Reset the global policy + keras.mixed_precision.set_global_policy(old_policy) @pytest.mark.extra_large def test_end_to_end_model_save(self):