diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index 4a6d743..eb19832 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -87,9 +87,6 @@ cc_library( name = "check_unknown_fields", srcs = ["check_unknown_fields.cc"], hdrs = ["check_unknown_fields.h"], - visibility = [ - "//visibility:private", - ], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index bb67d69..0639d09 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -183,7 +183,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* message, bool build_error_message_if_any) { + const ::google::protobuf::Message* message) { const auto* root_descriptor = message->GetDescriptor(); HasUnknownFields search{py_proto_api, root_descriptor}; if (!search.FindUnknownFieldsRecursive(message, 0u)) { @@ -193,9 +193,6 @@ std::optional CheckRecursively( search.FieldFQN())) != 0) { return std::nullopt; } - if (!build_error_message_if_any) { - return ""; // This indicates that an unknown field was found. - } return search.BuildErrorMessage(); } diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index e37adc7..79ac001 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -47,7 +47,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* top_message, bool build_error_message_if_any); + const ::google::protobuf::Message* top_message); } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index 0f7ba29..4b506d9 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -828,14 +829,18 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, std::optional unknown_field_message = check_unknown_fields::CheckRecursively( - GlobalState::instance()->py_proto_api(), src, - check_unknown_fields::ExtensionsWithUnknownFieldsPolicy:: - UnknownFieldsAreDisallowed()); + GlobalState::instance()->py_proto_api(), src); if (unknown_field_message) { - if (!unknown_field_message->empty()) { + if (check_unknown_fields::ExtensionsWithUnknownFieldsPolicy:: + UnknownFieldsAreDisallowed()) { throw py::value_error(*unknown_field_message); } - // Fall back to serialize/parse. + // Emit one LOG(WARNING) per unique unknown_field_message: + static auto fall_back_log_shown = new std::unordered_set(); + if (fall_back_log_shown->insert(*unknown_field_message).second) { + LOG(WARNING) << "FALL BACK TO PROTOBUF SERIALIZE/PARSE: " + << *unknown_field_message; + } return GenericPyProtoCast(src, policy, parent, is_const); } diff --git a/pybind11_protobuf/tests/BUILD b/pybind11_protobuf/tests/BUILD index 41e3ea0..0ac166c 100644 --- a/pybind11_protobuf/tests/BUILD +++ b/pybind11_protobuf/tests/BUILD @@ -160,6 +160,16 @@ pybind_extension( ], ) +EXTENSION_TEST_DEPS_COMMON = [ + ":extension_in_other_file_in_deps_py_pb2", + ":extension_in_other_file_py_pb2", + ":extension_nest_repeated_py_pb2", + ":extension_py_pb2", + ":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811 + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", +] + py_test( name = "extension_test", srcs = ["extension_test.py"], @@ -169,14 +179,20 @@ py_test( ], python_version = "PY3", srcs_version = "PY3", - deps = [ - ":extension_in_other_file_in_deps_py_pb2", - ":extension_in_other_file_py_pb2", - ":extension_nest_repeated_py_pb2", - ":extension_py_pb2", - ":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811 - "@com_google_absl_py//absl/testing:absltest", - "@com_google_absl_py//absl/testing:parameterized", + deps = EXTENSION_TEST_DEPS_COMMON + ["@com_google_protobuf//:protobuf_python"], +) + +py_test( + name = "extension_disallow_unknown_fields_test", + srcs = ["extension_test.py"], + data = [ + ":extension_module.so", + ":proto_enum_module.so", + ], + main = "extension_test.py", + python_version = "PY3", + srcs_version = "PY3", + deps = EXTENSION_TEST_DEPS_COMMON + [ "@com_google_protobuf//:protobuf_python", ], ) diff --git a/pybind11_protobuf/tests/extension_module.cc b/pybind11_protobuf/tests/extension_module.cc index ee5f132..6f76976 100644 --- a/pybind11_protobuf/tests/extension_module.cc +++ b/pybind11_protobuf/tests/extension_module.cc @@ -48,6 +48,11 @@ void DefReserialize(py::module_& m, const char* py_name) { PYBIND11_MODULE(extension_module, m) { pybind11_protobuf::ImportNativeProtoCasters(); + m.def("extensions_with_unknown_fields_are_disallowed", []() { + return pybind11_protobuf::check_unknown_fields:: + ExtensionsWithUnknownFieldsPolicy::UnknownFieldsAreDisallowed(); + }); + m.def("get_base_message", []() -> BaseMessage { return {}; }); m.def( diff --git a/pybind11_protobuf/tests/extension_test.py b/pybind11_protobuf/tests/extension_test.py index 0340882..6c9e297 100644 --- a/pybind11_protobuf/tests/extension_test.py +++ b/pybind11_protobuf/tests/extension_test.py @@ -20,6 +20,13 @@ from pybind11_protobuf.tests import extension_pb2 +def unknown_field_exception_is_expected(): + return ( + api_implementation.Type() == 'cpp' + and m.extensions_with_unknown_fields_are_disallowed() + ) + + def get_py_message(value=5, in_other_file_in_deps_value=None, in_other_file_value=None): @@ -103,7 +110,7 @@ def test_extension_in_other_file_roundtrip(self): def test_reserialize_base_message(self): a = get_py_message(in_other_file_value=63) - if api_implementation.Type() == 'cpp': + if unknown_field_exception_is_expected(): with self.assertRaises(ValueError) as ctx: m.reserialize_base_message(a) self.assertStartsWith( @@ -127,7 +134,7 @@ def test_reserialize_nest_level2(self): a = extension_pb2.NestLevel2( nest_lvl1=extension_pb2.NestLevel1( base_msg=get_py_message(in_other_file_value=52))) - if api_implementation.Type() == 'cpp': + if unknown_field_exception_is_expected(): with self.assertRaises(ValueError) as ctx: m.reserialize_nest_level2(a) self.assertStartsWith( @@ -154,7 +161,7 @@ def test_reserialize_nest_repeated(self): get_py_message(in_other_file_value=74), get_py_message(in_other_file_value=85) ]) - if api_implementation.Type() == 'cpp': + if unknown_field_exception_is_expected(): with self.assertRaises(ValueError) as ctx: m.reserialize_nest_repeated(a) self.assertStartsWith(