From e6cc3b1f24274c2a9b6a54e79b0ceef4909e9bd4 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 22 Nov 2023 11:58:29 -0800 Subject: [PATCH] Add `pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy`. PiperOrigin-RevId: 584687154 --- README.md | 48 +++++++++++-------- pybind11_protobuf/BUILD | 12 +++++ pybind11_protobuf/check_unknown_fields.cc | 7 ++- pybind11_protobuf/check_unknown_fields.h | 37 +++++++++++++- ...disallow_extensions_with_unknown_fields.cc | 12 +++++ pybind11_protobuf/proto_cast_util.cc | 18 ++++--- 6 files changed, 105 insertions(+), 29 deletions(-) create mode 100644 pybind11_protobuf/disallow_extensions_with_unknown_fields.cc diff --git a/README.md b/README.md index 91a376d..b84b35b 100644 --- a/README.md +++ b/README.md @@ -78,13 +78,18 @@ protobuf extensions are involved, a well-known pitfall is that extensions are silently moved to the `proto2::UnknownFieldSet` when a message is deserialized in C++, but the `cc_proto_library` for the extensions is not linked in. The root -cause is an asymmetry in the handling of Python protos vs C++ protos: when -a Python proto is deserialized, both the Python descriptor pool and the C++ -descriptor pool are inspected, but when a C++ proto is deserialized, only +cause is an asymmetry in the handling of Python protos vs C++ +protos: +when a Python proto is deserialized, both the Python descriptor pool and the +C++ descriptor pool are inspected, but when a C++ proto is deserialized, only the C++ descriptor pool is inspected. Until this asymmetry is resolved, the `cc_proto_library` for all extensions involved must be added to the `deps` of -the relevant `pybind_library` or `pybind_extension`, but this is sufficiently -unobvious to be a setup for regular accidents, potentially with critical +the relevant `pybind_library` or `pybind_extension`, or if this is impractial, +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +or `pybind11_protobuf::AllowUnknownFieldsFor` can be used. + +The pitfall is sufficiently unobvious to be a setup for regular accidents, +potentially with critical consequences. To guard against the most common type of accident, native_proto_caster.h @@ -97,14 +102,20 @@ in certain situations: * and the `proto2::UnknownFieldSet` for the message or any of its submessages is not empty. -`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in -which +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +is a **global** escape hatch trading off convenience and runtime overhead: the +convenience is that it is not necessary to determine what `cc_proto_library` +dependencies need to be added, the runtime overhead is that +`SerializePartialToString`/`ParseFromString` is used for messages with unknown +fields, instead of the much faster `CopyFrom`. -* unknown fields existed before the safety mechanism was - introduced. -* unknown fields are needed in the future. +Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which +simply disables the safety mechanism for **specific message types**, without +a runtime overhead. This is useful for situations in which unknown fields +are acceptable. -An example of a full error message (with lines breaks here for readability): +An example of a full error message generated by the safety mechanism +(with lines breaks here for readability): ``` Proto Message of type pybind11.test.NestRepeated has an Unknown Field with @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use (Warning: suppressions may mask critical bugs.) ``` -The current implementation is a compromise solution, trading off simplicity -of implementation, runtime performance, and precision. Generally, the runtime -overhead is expected to be very small, but fields flagged as unknown may not -necessarily be in extensions. -Alerting developers of new code to unknown fields is assumed to be generally -helpful, but the unknown fields detection is limited to messages with -extensions, to avoid the runtime overhead for the presumably much more common -case that no extensions are involved. +Note that the current implementation of the safety mechanism is a compromise +solution, trading off simplicity of implementation, runtime performance, +and precision. Alerting developers of new code to unknown fields is assumed +to be generally helpful, but the unknown fields detection is limited to +messages with extensions, to avoid the runtime overhead for the presumably +much more common case that no extensions are involved. Because of this, +the runtime overhead for the safety mechanism is expected to be very small. ### Enumerations diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index a03fab4..deba58a 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -100,3 +100,15 @@ cc_library( "@com_google_protobuf//python:proto_api", ], ) + +cc_library( + name = "disallow_extensions_with_unknown_fields", + srcs = ["disallow_extensions_with_unknown_fields.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":check_unknown_fields", + ], + alwayslink = 1, +) diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index 5b5e2c6..bb67d69 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, unknown_field_parent_message_fqn)); } -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* message) { + const ::google::protobuf::Message* message, bool build_error_message_if_any) { const auto* root_descriptor = message->GetDescriptor(); HasUnknownFields search{py_proto_api, root_descriptor}; if (!search.FindUnknownFieldsRecursive(message, 0u)) { @@ -193,6 +193,9 @@ std::optional CheckAndBuildErrorMessageIfAny( search.FieldFQN())) != 0) { return std::nullopt; } + if (!build_error_message_if_any) { + return ""; // This indicates that an unknown field was found. + } return search.BuildErrorMessage(); } diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index 4fc9771..e37adc7 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -9,12 +9,45 @@ namespace pybind11_protobuf::check_unknown_fields { +class ExtensionsWithUnknownFieldsPolicy { + enum State { + // Initial state. + kWeakDisallow, + + // Primary use case: PyCLIF extensions might set this when being imported. + kWeakEnableFallbackToSerializeParse, + + // Primary use case: `:disallow_extensions_with_unknown_fields` in `deps` + // of a binary (or test). + kStrongDisallow + }; + + static State& GetStateSingleton() { + static State singleton = kWeakDisallow; + return singleton; + } + + public: + static void WeakEnableFallbackToSerializeParse() { + State& policy = GetStateSingleton(); + if (policy == kWeakDisallow) { + policy = kWeakEnableFallbackToSerializeParse; + } + } + + static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; } + + static bool UnknownFieldsAreDisallowed() { + return GetStateSingleton() != kWeakEnableFallbackToSerializeParse; + } +}; + void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, absl::string_view unknown_field_parent_message_fqn); -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* top_message); + const ::google::protobuf::Message* top_message, bool build_error_message_if_any); } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc new file mode 100644 index 0000000..5889e21 --- /dev/null +++ b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc @@ -0,0 +1,12 @@ +#include "pybind11_protobuf/check_unknown_fields.h" + +namespace pybind11_protobuf::check_unknown_fields { +namespace { + +static int kSetConfigDone = []() { + ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow(); + return 0; +}(); + +} // namespace +} // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index f05936a..0f7ba29 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -819,18 +819,24 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // 1. The binary does not have a py_proto_api instance, or // 2. a) the proto is from the default pool and // b) the binary is not using fast_cpp_protos. - if ((GlobalState::instance()->py_proto_api() == nullptr) || + if (GlobalState::instance()->py_proto_api() == nullptr || (src->GetDescriptor()->file()->pool() == DescriptorPool::generated_pool() && !GlobalState::instance()->using_fast_cpp())) { return GenericPyProtoCast(src, policy, parent, is_const); } - std::optional emsg = - check_unknown_fields::CheckAndBuildErrorMessageIfAny( - GlobalState::instance()->py_proto_api(), src); - if (emsg) { - throw py::value_error(*emsg); + std::optional unknown_field_message = + check_unknown_fields::CheckRecursively( + GlobalState::instance()->py_proto_api(), src, + check_unknown_fields::ExtensionsWithUnknownFieldsPolicy:: + UnknownFieldsAreDisallowed()); + if (unknown_field_message) { + if (!unknown_field_message->empty()) { + throw py::value_error(*unknown_field_message); + } + // Fall back to serialize/parse. + return GenericPyProtoCast(src, policy, parent, is_const); } // If this is a dynamically generated proto, then we're going to need to