From b713501f1da56d9b76c42f89efd00b97c26c9eac Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 29 Nov 2023 10:57:30 -0800 Subject: [PATCH 1/4] Update pybind11_bazel to latest git commit. PiperOrigin-RevId: 586393896 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 49b6a07..810c71a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -39,9 +39,9 @@ http_archive( # https://github.com/pybind/pybind11_bazel http_archive( name = "pybind11_bazel", - strip_prefix = "pybind11_bazel-ff261d2e9190955d0830040b20ea59ab9dbe66c8", - sha256 = "c68230f540ae99e6acdec9a79f351d003e2dccefa029c3ce8d25060a6e05dc43", - urls = ["https://github.com/pybind/pybind11_bazel/archive/ff261d2e9190955d0830040b20ea59ab9dbe66c8.tar.gz"], + strip_prefix = "pybind11_bazel-23926b00e2b2eb2fc46b17e587cf0c0cfd2f2c4b", + sha256 = "f58c0d5bfd125b08075224c319a02a901c3bce11ff2cf8310c024d40f4af823e", + urls = ["https://github.com/pybind/pybind11_bazel/archive/23926b00e2b2eb2fc46b17e587cf0c0cfd2f2c4b.tar.gz"], ) ## `pybind11` (FLOATING) From 8359a091a9b0bc7deb0233de986c06c885a3ff2d Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Mon, 11 Dec 2023 12:03:17 -0800 Subject: [PATCH 2/4] Remove two pool-membership conditions guarding the C++ equivalent of `obj.SerializePartialToString()` The main change in this CL is to remove two conditions in `PyProtoIsCompatible()`: 1. ``` if (descriptor->file()->pool() != DescriptorPool::generated_pool()) { ``` 2. ``` return py_pool->is(GlobalState::instance()->global_pool()); ``` Rationale for removing these conditions: * All that matters for protobuf compatibility is that the `full_name` is the same. (Thanks @kmoffett for that insight!) * Cross-extension-module ABI compatibility is not a concern because only the Python API is used in the relevant code paths serializing Python protobuf objects to Python `bytes` (equivalent to calling `obj.SerializePartialToString()` from Python). 
All other changes in this CL are secondary: small-scale refactoring, slight naming changes, additional tests for error conditions. PiperOrigin-RevId: 589898116 --- pybind11_protobuf/proto_cast_util.cc | 105 ++++++++++-------------- pybind11_protobuf/proto_cast_util.h | 22 +++-- pybind11_protobuf/proto_caster_impl.h | 26 ++++-- pybind11_protobuf/tests/pass_by_test.py | 28 +++++++ 4 files changed, 103 insertions(+), 78 deletions(-) diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index 1c8727f..a634429 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -21,6 +21,7 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "pybind11_protobuf/check_unknown_fields.h" @@ -534,10 +535,8 @@ class PythonDescriptorPoolWrapper { } } - py::object wire = py_file_descriptor.attr("serialized_pb"); - const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr()); - return output->ParsePartialFromArray(bytes, - PYBIND11_BYTES_SIZE(wire.ptr())); + return output->ParsePartialFromString( + PyBytesAsStringView(py_file_descriptor.attr("serialized_pb"))); } py::object pool_; // never dereferenced. 
@@ -549,6 +548,11 @@ class PythonDescriptorPoolWrapper { } // namespace +absl::string_view PyBytesAsStringView(py::bytes py_bytes) { + return absl::string_view(PyBytes_AsString(py_bytes.ptr()), + PyBytes_Size(py_bytes.ptr())); +} + void InitializePybindProtoCastUtil() { assert(PyGILState_Check()); GlobalState::instance(); @@ -593,7 +597,7 @@ const Message* PyProtoGetCppMessagePointer(py::handle src) { #endif } -absl::optional PyProtoDescriptorName(py::handle py_proto) { +absl::optional PyProtoDescriptorFullName(py::handle py_proto) { assert(PyGILState_Check()); auto py_full_name = ResolveAttrs(py_proto, {"DESCRIPTOR", "full_name"}); if (py_full_name) { @@ -602,66 +606,42 @@ absl::optional PyProtoDescriptorName(py::handle py_proto) { return absl::nullopt; } -bool PyProtoIsCompatible(py::handle py_proto, const Descriptor* descriptor) { - assert(PyGILState_Check()); - if (descriptor->file()->pool() != DescriptorPool::generated_pool()) { - /// This indicates that the C++ descriptor does not come from the C++ - /// DescriptorPool. This may happen if the C++ code has the same proto - /// in different descriptor pools, perhaps from different shared objects, - /// and could be result in undefined behavior. - return false; - } - - auto py_descriptor = ResolveAttrs(py_proto, {"DESCRIPTOR"}); - if (!py_descriptor) { - // Not a valid protobuf -- missing DESCRIPTOR. - return false; - } - - // Test full_name equivalence. - { - auto py_full_name = ResolveAttrs(*py_descriptor, {"full_name"}); - if (!py_full_name) { - // Not a valid protobuf -- missing DESCRIPTOR.full_name - return false; - } - auto full_name = CastToOptionalString(*py_full_name); - if (!full_name || *full_name != descriptor->full_name()) { - // Name mismatch. - return false; - } - } - - // The C++ descriptor is compiled in (see above assert), so the py_proto - // is expected to be from the global pool, i.e. the DESCRIPTOR.file.pool - // instance is the global python pool, and not a custom pool. 
- auto py_pool = ResolveAttrs(*py_descriptor, {"file", "pool"}); - if (py_pool) { - return py_pool->is(GlobalState::instance()->global_pool()); - } - - // The py_proto is missing a DESCRIPTOR.file.pool, but the name matches. - // This will not happen with a native python implementation, but does - // occur with the deprecated :proto_casters, and could happen with other - // mocks. Returning true allows the caster to call PyProtoCopyToCProto. - return true; +bool PyProtoHasMatchingFullName(py::handle py_proto, + const Descriptor* descriptor) { + auto full_name = PyProtoDescriptorFullName(py_proto); + return full_name && *full_name == descriptor->full_name(); } -bool PyProtoCopyToCProto(py::handle py_proto, Message* message) { - assert(PyGILState_Check()); - auto serialize_fn = ResolveAttrMRO(py_proto, "SerializePartialToString"); +py::bytes PyProtoSerializePartialToString(py::handle py_proto, + bool raise_if_error) { + static const char* serialize_fn_name = "SerializePartialToString"; + auto serialize_fn = ResolveAttrMRO(py_proto, serialize_fn_name); if (!serialize_fn) { - throw py::type_error( - "SerializePartialToString method not found; is this a " + - message->GetDescriptor()->full_name()); - } - auto wire = (*serialize_fn)(); - const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr()); - if (!bytes) { - throw py::type_error("SerializePartialToString failed; is this a " + - message->GetDescriptor()->full_name()); + return py::object(); + } + auto serialized_bytes = py::reinterpret_steal( + PyObject_CallObject(serialize_fn->ptr(), nullptr)); + if (!serialized_bytes) { + if (raise_if_error) { + std::string msg = py::repr(py_proto).cast() + "." + + serialize_fn_name + "() function call FAILED"; + py::raise_from(PyExc_TypeError, msg.c_str()); + throw py::error_already_set(); + } + return py::object(); + } + if (!PyBytes_Check(serialized_bytes.ptr())) { + if (raise_if_error) { + std::string msg = py::repr(py_proto).cast() + "." 
+ + serialize_fn_name + + "() function call is expected to return bytes, but the " + "returned value is " + + py::repr(serialized_bytes).cast(); + throw py::type_error(msg); + } + return py::object(); } - return message->ParsePartialFromArray(bytes, PYBIND11_BYTES_SIZE(wire.ptr())); + return serialized_bytes; } void CProtoCopyToPyProto(Message* message, py::handle py_proto) { @@ -686,7 +666,8 @@ std::unique_ptr AllocateCProtoFromPythonSymbolDatabase( assert(PyGILState_Check()); auto pool = ResolveAttrs(src, {"DESCRIPTOR", "file", "pool"}); if (!pool) { - throw py::type_error("Object is not a valid protobuf"); + throw py::type_error(py::repr(src).cast() + + " object is not a valid protobuf"); } auto pool_data = diff --git a/pybind11_protobuf/proto_cast_util.h b/pybind11_protobuf/proto_cast_util.h index 4821e94..954115e 100644 --- a/pybind11_protobuf/proto_cast_util.h +++ b/pybind11_protobuf/proto_cast_util.h @@ -14,6 +14,7 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" // PYBIND11_PROTOBUF_ASSUME_FULL_ABI_COMPATIBILITY can be defined by users @@ -28,6 +29,10 @@ namespace pybind11_protobuf { +// Simple helper. Caller has to ensure that the py_bytes argument outlives the +// returned string_view. +absl::string_view PyBytesAsStringView(pybind11::bytes py_bytes); + // Initialize internal proto cast dependencies, which includes importing // various protobuf-related modules. void InitializePybindProtoCastUtil(); @@ -39,14 +44,16 @@ void ImportProtoDescriptorModule(const ::google::protobuf::Descriptor *); const ::google::protobuf::Message *PyProtoGetCppMessagePointer(pybind11::handle src); // Returns the protocol buffer's py_proto.DESCRIPTOR.full_name attribute. -absl::optional PyProtoDescriptorName(pybind11::handle py_proto); +absl::optional PyProtoDescriptorFullName( + pybind11::handle py_proto); + +// Returns true if py_proto full name matches descriptor full name. 
+bool PyProtoHasMatchingFullName(pybind11::handle py_proto, + const ::google::protobuf::Descriptor *descriptor); -// Return whether py_proto is compatible with the C++ descriptor. -// The py_proto name must match the C++ Descriptor::full_name(), and is -// expected to originate from the python default pool, which means that -// this method will return false for dynamic protos. -bool PyProtoIsCompatible(pybind11::handle py_proto, - const ::google::protobuf::Descriptor *descriptor); +// Caller should enforce any type identity that is required. +pybind11::bytes PyProtoSerializePartialToString(pybind11::handle py_proto, + bool raise_if_error); // Allocates a C++ protocol buffer for a given name. std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatabase( @@ -54,7 +61,6 @@ std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatab // Serialize the py_proto and deserialize it into the provided message. // Caller should enforce any type identity that is required. -bool PyProtoCopyToCProto(pybind11::handle py_proto, ::google::protobuf::Message *message); void CProtoCopyToPyProto(::google::protobuf::Message *message, pybind11::handle py_proto); // Returns a handle to a python protobuf suitably diff --git a/pybind11_protobuf/proto_caster_impl.h b/pybind11_protobuf/proto_caster_impl.h index f54b175..50eb3bc 100644 --- a/pybind11_protobuf/proto_caster_impl.h +++ b/pybind11_protobuf/proto_caster_impl.h @@ -63,16 +63,19 @@ struct proto_caster_load_impl { } } - // The incoming object is not a compatible fast_cpp_proto, so check whether - // it is otherwise compatible, then serialize it and deserialize into a - // native C++ proto type. 
- if (!pybind11_protobuf::PyProtoIsCompatible(src, - ProtoType::GetDescriptor())) { + if (!PyProtoHasMatchingFullName(src, ProtoType::GetDescriptor())) { return false; } + pybind11::bytes serialized_bytes = + PyProtoSerializePartialToString(src, convert); + if (!serialized_bytes) { + return false; + } + owned = std::unique_ptr(new ProtoType()); value = owned.get(); - return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get()); + return owned.get()->ParsePartialFromString( + PyBytesAsStringView(serialized_bytes)); } // ensure_owned ensures that the owned member contains a copy of the @@ -108,16 +111,23 @@ struct proto_caster_load_impl<::google::protobuf::Message> { // `src` is not a C++ proto instance from the generated_pool, // so create a compatible native C++ proto. - auto descriptor_name = pybind11_protobuf::PyProtoDescriptorName(src); + auto descriptor_name = pybind11_protobuf::PyProtoDescriptorFullName(src); if (!descriptor_name) { return false; } + pybind11::bytes serialized_bytes = + PyProtoSerializePartialToString(src, convert); + if (!serialized_bytes) { + return false; + } + owned.reset(static_cast( pybind11_protobuf::AllocateCProtoFromPythonSymbolDatabase( src, *descriptor_name) .release())); value = owned.get(); - return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get()); + return owned.get()->ParsePartialFromString( + PyBytesAsStringView(serialized_bytes)); } // ensure_owned ensures that the owned member contains a copy of the diff --git a/pybind11_protobuf/tests/pass_by_test.py b/pybind11_protobuf/tests/pass_by_test.py index 66a75c9..56a69fc 100644 --- a/pybind11_protobuf/tests/pass_by_test.py +++ b/pybind11_protobuf/tests/pass_by_test.py @@ -170,6 +170,34 @@ def test_pass_fake2(self, check_method): def test_overload_fn(self, message_fn, expected): self.assertEqual(expected, m.fn_overload(message_fn())) + def test_bad_serialize_partial_function_calls(self): + class FakeDescr: + full_name = 'fake_full_name' + + class FakeProto: + DESCRIPTOR = 
FakeDescr() + + def __init__(self, serialize_fn_return_value=None): + self.serialize_fn_return_value = serialize_fn_return_value + + def SerializePartialToString(self): # pylint: disable=invalid-name + if self.serialize_fn_return_value is None: + raise RuntimeError('Broken serialize_fn.') + return self.serialize_fn_return_value + + with self.assertRaisesRegex( + TypeError, r'\.SerializePartialToString\(\) function call FAILED$' + ): + m.fn_overload(FakeProto()) + with self.assertRaisesRegex( + TypeError, + r'\.SerializePartialToString\(\) function call is expected to return' + r' bytes, but the returned value is \[\]$', + ): + m.fn_overload(FakeProto([])) + with self.assertRaisesRegex(TypeError, r' object is not a valid protobuf$'): + m.fn_overload(FakeProto(b'')) + if __name__ == '__main__': absltest.main() From 2f613cae594254d0328d1fad0dbaa357b308836e Mon Sep 17 00:00:00 2001 From: pybind11_protobuf authors Date: Mon, 1 Jan 2024 18:27:28 -0800 Subject: [PATCH 3/4] Fix library dependency list PiperOrigin-RevId: 594993469 --- pybind11_protobuf/BUILD | 3 +++ pybind11_protobuf/proto_cast_util.cc | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index 528db9e..a03fab4 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -59,6 +59,9 @@ pybind_library( deps = [ ":check_unknown_fields", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index a634429..f05936a 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -9,28 +9,33 @@ #include #include #include -#include #include +#include #include "google/protobuf/descriptor.pb.h" 
-#include "google/protobuf/descriptor.h" -#include "google/protobuf/dynamic_message.h" -#include "google/protobuf/message.h" -#include "python/google/protobuf/proto_api.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/numbers.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" +#include "google/protobuf/dynamic_message.h" +#include "python/google/protobuf/proto_api.h" #include "pybind11_protobuf/check_unknown_fields.h" +#if defined(GOOGLE_PROTOBUF_VERSION) +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#endif + namespace py = pybind11; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorDatabase; using ::google::protobuf::DescriptorPool; -using ::google::protobuf::DescriptorProto; using ::google::protobuf::DynamicMessageFactory; using ::google::protobuf::FileDescriptor; using ::google::protobuf::FileDescriptorProto; From 3b11990a99dea5101799e61d98a82c4737d240cc Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 3 Jan 2024 18:43:22 -0800 Subject: [PATCH 4/4] Add `pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy`. 
PiperOrigin-RevId: 595557800 --- README.md | 48 +++++++++++-------- pybind11_protobuf/BUILD | 12 +++++ pybind11_protobuf/check_unknown_fields.cc | 7 ++- pybind11_protobuf/check_unknown_fields.h | 37 +++++++++++++- ...disallow_extensions_with_unknown_fields.cc | 12 +++++ pybind11_protobuf/proto_cast_util.cc | 18 ++++--- 6 files changed, 105 insertions(+), 29 deletions(-) create mode 100644 pybind11_protobuf/disallow_extensions_with_unknown_fields.cc diff --git a/README.md b/README.md index 91a376d..b84b35b 100644 --- a/README.md +++ b/README.md @@ -78,13 +78,18 @@ protobuf extensions are involved, a well-known pitfall is that extensions are silently moved to the `proto2::UnknownFieldSet` when a message is deserialized in C++, but the `cc_proto_library` for the extensions is not linked in. The root -cause is an asymmetry in the handling of Python protos vs C++ protos: when -a Python proto is deserialized, both the Python descriptor pool and the C++ -descriptor pool are inspected, but when a C++ proto is deserialized, only +cause is an asymmetry in the handling of Python protos vs C++ +protos: +when a Python proto is deserialized, both the Python descriptor pool and the +C++ descriptor pool are inspected, but when a C++ proto is deserialized, only the C++ descriptor pool is inspected. Until this asymmetry is resolved, the `cc_proto_library` for all extensions involved must be added to the `deps` of -the relevant `pybind_library` or `pybind_extension`, but this is sufficiently -unobvious to be a setup for regular accidents, potentially with critical +the relevant `pybind_library` or `pybind_extension`, or if this is impractical, +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +or `pybind11_protobuf::AllowUnknownFieldsFor` can be used. + +The pitfall is sufficiently unobvious to be a setup for regular accidents, +potentially with critical consequences.
To guard against the most common type of accident, native_proto_caster.h @@ -97,14 +102,20 @@ in certain situations: * and the `proto2::UnknownFieldSet` for the message or any of its submessages is not empty. -`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in -which +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +is a **global** escape hatch trading off convenience and runtime overhead: the +convenience is that it is not necessary to determine what `cc_proto_library` +dependencies need to be added, the runtime overhead is that +`SerializePartialToString`/`ParseFromString` is used for messages with unknown +fields, instead of the much faster `CopyFrom`. -* unknown fields existed before the safety mechanism was - introduced. -* unknown fields are needed in the future. +Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which +simply disables the safety mechanism for **specific message types**, without +a runtime overhead. This is useful for situations in which unknown fields +are acceptable. -An example of a full error message (with lines breaks here for readability): +An example of a full error message generated by the safety mechanism +(with line breaks here for readability): ``` Proto Message of type pybind11.test.NestRepeated has an Unknown Field with @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use (Warning: suppressions may mask critical bugs.) ``` -The current implementation is a compromise solution, trading off simplicity -of implementation, runtime performance, and precision. Generally, the runtime -overhead is expected to be very small, but fields flagged as unknown may not -necessarily be in extensions.
-Alerting developers of new code to unknown fields is assumed to be generally -helpful, but the unknown fields detection is limited to messages with -extensions, to avoid the runtime overhead for the presumably much more common -case that no extensions are involved. +Note that the current implementation of the safety mechanism is a compromise +solution, trading off simplicity of implementation, runtime performance, +and precision. Alerting developers of new code to unknown fields is assumed +to be generally helpful, but the unknown fields detection is limited to +messages with extensions, to avoid the runtime overhead for the presumably +much more common case that no extensions are involved. Because of this, +the runtime overhead for the safety mechanism is expected to be very small. ### Enumerations diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index a03fab4..deba58a 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -100,3 +100,15 @@ cc_library( "@com_google_protobuf//python:proto_api", ], ) + +cc_library( + name = "disallow_extensions_with_unknown_fields", + srcs = ["disallow_extensions_with_unknown_fields.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":check_unknown_fields", + ], + alwayslink = 1, +) diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index 5b5e2c6..bb67d69 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, unknown_field_parent_message_fqn)); } -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* message) { + const ::google::protobuf::Message* message, bool build_error_message_if_any) { const auto* root_descriptor = message->GetDescriptor(); HasUnknownFields 
search{py_proto_api, root_descriptor}; if (!search.FindUnknownFieldsRecursive(message, 0u)) { @@ -193,6 +193,9 @@ std::optional CheckAndBuildErrorMessageIfAny( search.FieldFQN())) != 0) { return std::nullopt; } + if (!build_error_message_if_any) { + return ""; // This indicates that an unknown field was found. + } return search.BuildErrorMessage(); } diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index 4fc9771..e37adc7 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -9,12 +9,45 @@ namespace pybind11_protobuf::check_unknown_fields { +class ExtensionsWithUnknownFieldsPolicy { + enum State { + // Initial state. + kWeakDisallow, + + // Primary use case: PyCLIF extensions might set this when being imported. + kWeakEnableFallbackToSerializeParse, + + // Primary use case: `:disallow_extensions_with_unknown_fields` in `deps` + // of a binary (or test). + kStrongDisallow + }; + + static State& GetStateSingleton() { + static State singleton = kWeakDisallow; + return singleton; + } + + public: + static void WeakEnableFallbackToSerializeParse() { + State& policy = GetStateSingleton(); + if (policy == kWeakDisallow) { + policy = kWeakEnableFallbackToSerializeParse; + } + } + + static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; } + + static bool UnknownFieldsAreDisallowed() { + return GetStateSingleton() != kWeakEnableFallbackToSerializeParse; + } +}; + void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, absl::string_view unknown_field_parent_message_fqn); -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* top_message); + const ::google::protobuf::Message* top_message, bool build_error_message_if_any); } // namespace pybind11_protobuf::check_unknown_fields diff --git 
a/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc new file mode 100644 index 0000000..5889e21 --- /dev/null +++ b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc @@ -0,0 +1,12 @@ +#include "pybind11_protobuf/check_unknown_fields.h" + +namespace pybind11_protobuf::check_unknown_fields { +namespace { + +static int kSetConfigDone = []() { + ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow(); + return 0; +}(); + +} // namespace +} // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index f05936a..0f7ba29 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -819,18 +819,24 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // 1. The binary does not have a py_proto_api instance, or // 2. a) the proto is from the default pool and // b) the binary is not using fast_cpp_protos. - if ((GlobalState::instance()->py_proto_api() == nullptr) || + if (GlobalState::instance()->py_proto_api() == nullptr || (src->GetDescriptor()->file()->pool() == DescriptorPool::generated_pool() && !GlobalState::instance()->using_fast_cpp())) { return GenericPyProtoCast(src, policy, parent, is_const); } - std::optional emsg = - check_unknown_fields::CheckAndBuildErrorMessageIfAny( - GlobalState::instance()->py_proto_api(), src); - if (emsg) { - throw py::value_error(*emsg); + std::optional unknown_field_message = + check_unknown_fields::CheckRecursively( + GlobalState::instance()->py_proto_api(), src, + check_unknown_fields::ExtensionsWithUnknownFieldsPolicy:: + UnknownFieldsAreDisallowed()); + if (unknown_field_message) { + if (!unknown_field_message->empty()) { + throw py::value_error(*unknown_field_message); + } + // Fall back to serialize/parse. 
+ return GenericPyProtoCast(src, policy, parent, is_const); } // If this is a dynamically generated proto, then we're going to need to