diff --git a/README.md b/README.md index 91a376d..b84b35b 100644 --- a/README.md +++ b/README.md @@ -78,13 +78,18 @@ protobuf extensions are involved, a well-known pitfall is that extensions are silently moved to the `proto2::UnknownFieldSet` when a message is deserialized in C++, but the `cc_proto_library` for the extensions is not linked in. The root -cause is an asymmetry in the handling of Python protos vs C++ protos: when -a Python proto is deserialized, both the Python descriptor pool and the C++ -descriptor pool are inspected, but when a C++ proto is deserialized, only +cause is an asymmetry in the handling of Python protos vs C++ +protos: +when a Python proto is deserialized, both the Python descriptor pool and the +C++ descriptor pool are inspected, but when a C++ proto is deserialized, only the C++ descriptor pool is inspected. Until this asymmetry is resolved, the `cc_proto_library` for all extensions involved must be added to the `deps` of -the relevant `pybind_library` or `pybind_extension`, but this is sufficiently -unobvious to be a setup for regular accidents, potentially with critical +the relevant `pybind_library` or `pybind_extension`, or if this is impractial, +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +or `pybind11_protobuf::AllowUnknownFieldsFor` can be used. + +The pitfall is sufficiently unobvious to be a setup for regular accidents, +potentially with critical consequences. To guard against the most common type of accident, native_proto_caster.h @@ -97,14 +102,20 @@ in certain situations: * and the `proto2::UnknownFieldSet` for the message or any of its submessages is not empty. -`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in -which +`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse` +is a **global** escape hatch trading off convenience and runtime overhead: the +convenience is that it is not necessary to determine what `cc_proto_library` +dependencies need to be added, the runtime overhead is that +`SerializePartialToString`/`ParseFromString` is used for messages with unknown +fields, instead of the much faster `CopyFrom`. -* unknown fields existed before the safety mechanism was - introduced. -* unknown fields are needed in the future. +Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which +simply disables the safety mechanism for **specific message types**, without +a runtime overhead. This is useful for situations in which unknown fields +are acceptable. -An example of a full error message (with lines breaks here for readability): +An example of a full error message generated by the safety mechanism +(with lines breaks here for readability): ``` Proto Message of type pybind11.test.NestRepeated has an Unknown Field with @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use (Warning: suppressions may mask critical bugs.) ``` -The current implementation is a compromise solution, trading off simplicity -of implementation, runtime performance, and precision. Generally, the runtime -overhead is expected to be very small, but fields flagged as unknown may not -necessarily be in extensions. -Alerting developers of new code to unknown fields is assumed to be generally -helpful, but the unknown fields detection is limited to messages with -extensions, to avoid the runtime overhead for the presumably much more common -case that no extensions are involved. +Note that the current implementation of the safety mechanism is a compromise +solution, trading off simplicity of implementation, runtime performance, +and precision. Alerting developers of new code to unknown fields is assumed +to be generally helpful, but the unknown fields detection is limited to +messages with extensions, to avoid the runtime overhead for the presumably +much more common case that no extensions are involved. Because of this, +the runtime overhead for the safety mechanism is expected to be very small. ### Enumerations diff --git a/WORKSPACE b/WORKSPACE index 05c291a..900f9bc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1 +1 @@ -workspace(name = "com_google_pybind11_protobuf") \ No newline at end of file +workspace(name = "com_google_pybind11_protobuf") diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD index ea6f4f7..4a6d743 100644 --- a/pybind11_protobuf/BUILD +++ b/pybind11_protobuf/BUILD @@ -100,3 +100,15 @@ cc_library( "@com_google_protobuf//python:proto_api", ], ) + +cc_library( + name = "disallow_extensions_with_unknown_fields", + srcs = ["disallow_extensions_with_unknown_fields.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":check_unknown_fields", + ], + alwayslink = 1, +) diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index 5b5e2c6..bb67d69 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, unknown_field_parent_message_fqn)); } -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* message) { + const ::google::protobuf::Message* message, bool build_error_message_if_any) { const auto* root_descriptor = message->GetDescriptor(); HasUnknownFields search{py_proto_api, root_descriptor}; if (!search.FindUnknownFieldsRecursive(message, 0u)) { @@ -193,6 +193,9 @@ std::optional CheckAndBuildErrorMessageIfAny( search.FieldFQN())) != 0) { return std::nullopt; } + if (!build_error_message_if_any) { + return ""; // This indicates that an unknown field was found. + } return search.BuildErrorMessage(); } diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index 4fc9771..e37adc7 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -9,12 +9,45 @@ namespace pybind11_protobuf::check_unknown_fields { +class ExtensionsWithUnknownFieldsPolicy { + enum State { + // Initial state. + kWeakDisallow, + + // Primary use case: PyCLIF extensions might set this when being imported. + kWeakEnableFallbackToSerializeParse, + + // Primary use case: `:disallow_extensions_with_unknown_fields` in `deps` + // of a binary (or test). + kStrongDisallow + }; + + static State& GetStateSingleton() { + static State singleton = kWeakDisallow; + return singleton; + } + + public: + static void WeakEnableFallbackToSerializeParse() { + State& policy = GetStateSingleton(); + if (policy == kWeakDisallow) { + policy = kWeakEnableFallbackToSerializeParse; + } + } + + static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; } + + static bool UnknownFieldsAreDisallowed() { + return GetStateSingleton() != kWeakEnableFallbackToSerializeParse; + } +}; + void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, absl::string_view unknown_field_parent_message_fqn); -std::optional CheckAndBuildErrorMessageIfAny( +std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, - const ::google::protobuf::Message* top_message); + const ::google::protobuf::Message* top_message, bool build_error_message_if_any); } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc new file mode 100644 index 0000000..5889e21 --- /dev/null +++ b/pybind11_protobuf/disallow_extensions_with_unknown_fields.cc @@ -0,0 +1,12 @@ +#include "pybind11_protobuf/check_unknown_fields.h" + +namespace pybind11_protobuf::check_unknown_fields { +namespace { + +static int kSetConfigDone = []() { + ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow(); + return 0; +}(); + +} // namespace +} // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index 1c8727f..0f7ba29 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -9,27 +9,33 @@ #include #include #include -#include #include +#include #include "google/protobuf/descriptor.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor_database.h" #include "google/protobuf/dynamic_message.h" -#include "google/protobuf/message.h" #include "python/google/protobuf/proto_api.h" -#include "absl/container/flat_hash_map.h" +#include "pybind11_protobuf/check_unknown_fields.h" + +#if defined(GOOGLE_PROTOBUF_VERSION) #include "absl/strings/numbers.h" -#include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" -#include "absl/types/optional.h" -#include "pybind11_protobuf/check_unknown_fields.h" +#endif namespace py = pybind11; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorDatabase; using ::google::protobuf::DescriptorPool; -using ::google::protobuf::DescriptorProto; using ::google::protobuf::DynamicMessageFactory; using ::google::protobuf::FileDescriptor; using ::google::protobuf::FileDescriptorProto; @@ -534,10 +540,8 @@ class PythonDescriptorPoolWrapper { } } - py::object wire = py_file_descriptor.attr("serialized_pb"); - const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr()); - return output->ParsePartialFromArray(bytes, - PYBIND11_BYTES_SIZE(wire.ptr())); + return output->ParsePartialFromString( + PyBytesAsStringView(py_file_descriptor.attr("serialized_pb"))); } py::object pool_; // never dereferenced. @@ -549,6 +553,11 @@ class PythonDescriptorPoolWrapper { } // namespace +absl::string_view PyBytesAsStringView(py::bytes py_bytes) { + return absl::string_view(PyBytes_AsString(py_bytes.ptr()), + PyBytes_Size(py_bytes.ptr())); +} + void InitializePybindProtoCastUtil() { assert(PyGILState_Check()); GlobalState::instance(); @@ -593,7 +602,7 @@ const Message* PyProtoGetCppMessagePointer(py::handle src) { #endif } -absl::optional PyProtoDescriptorName(py::handle py_proto) { +absl::optional PyProtoDescriptorFullName(py::handle py_proto) { assert(PyGILState_Check()); auto py_full_name = ResolveAttrs(py_proto, {"DESCRIPTOR", "full_name"}); if (py_full_name) { @@ -602,66 +611,42 @@ absl::optional PyProtoDescriptorName(py::handle py_proto) { return absl::nullopt; } -bool PyProtoIsCompatible(py::handle py_proto, const Descriptor* descriptor) { - assert(PyGILState_Check()); - if (descriptor->file()->pool() != DescriptorPool::generated_pool()) { - /// This indicates that the C++ descriptor does not come from the C++ - /// DescriptorPool. This may happen if the C++ code has the same proto - /// in different descriptor pools, perhaps from different shared objects, - /// and could be result in undefined behavior. - return false; - } - - auto py_descriptor = ResolveAttrs(py_proto, {"DESCRIPTOR"}); - if (!py_descriptor) { - // Not a valid protobuf -- missing DESCRIPTOR. - return false; - } - - // Test full_name equivalence. - { - auto py_full_name = ResolveAttrs(*py_descriptor, {"full_name"}); - if (!py_full_name) { - // Not a valid protobuf -- missing DESCRIPTOR.full_name - return false; - } - auto full_name = CastToOptionalString(*py_full_name); - if (!full_name || *full_name != descriptor->full_name()) { - // Name mismatch. - return false; - } - } - - // The C++ descriptor is compiled in (see above assert), so the py_proto - // is expected to be from the global pool, i.e. the DESCRIPTOR.file.pool - // instance is the global python pool, and not a custom pool. - auto py_pool = ResolveAttrs(*py_descriptor, {"file", "pool"}); - if (py_pool) { - return py_pool->is(GlobalState::instance()->global_pool()); - } - - // The py_proto is missing a DESCRIPTOR.file.pool, but the name matches. - // This will not happen with a native python implementation, but does - // occur with the deprecated :proto_casters, and could happen with other - // mocks. Returning true allows the caster to call PyProtoCopyToCProto. - return true; +bool PyProtoHasMatchingFullName(py::handle py_proto, + const Descriptor* descriptor) { + auto full_name = PyProtoDescriptorFullName(py_proto); + return full_name && *full_name == descriptor->full_name(); } -bool PyProtoCopyToCProto(py::handle py_proto, Message* message) { - assert(PyGILState_Check()); - auto serialize_fn = ResolveAttrMRO(py_proto, "SerializePartialToString"); +py::bytes PyProtoSerializePartialToString(py::handle py_proto, + bool raise_if_error) { + static const char* serialize_fn_name = "SerializePartialToString"; + auto serialize_fn = ResolveAttrMRO(py_proto, serialize_fn_name); if (!serialize_fn) { - throw py::type_error( - "SerializePartialToString method not found; is this a " + - message->GetDescriptor()->full_name()); - } - auto wire = (*serialize_fn)(); - const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr()); - if (!bytes) { - throw py::type_error("SerializePartialToString failed; is this a " + - message->GetDescriptor()->full_name()); + return py::object(); + } + auto serialized_bytes = py::reinterpret_steal( + PyObject_CallObject(serialize_fn->ptr(), nullptr)); + if (!serialized_bytes) { + if (raise_if_error) { + std::string msg = py::repr(py_proto).cast() + "." + + serialize_fn_name + "() function call FAILED"; + py::raise_from(PyExc_TypeError, msg.c_str()); + throw py::error_already_set(); + } + return py::object(); + } + if (!PyBytes_Check(serialized_bytes.ptr())) { + if (raise_if_error) { + std::string msg = py::repr(py_proto).cast() + "." + + serialize_fn_name + + "() function call is expected to return bytes, but the " + "returned value is " + + py::repr(serialized_bytes).cast(); + throw py::type_error(msg); + } + return py::object(); } - return message->ParsePartialFromArray(bytes, PYBIND11_BYTES_SIZE(wire.ptr())); + return serialized_bytes; } void CProtoCopyToPyProto(Message* message, py::handle py_proto) { @@ -686,7 +671,8 @@ std::unique_ptr AllocateCProtoFromPythonSymbolDatabase( assert(PyGILState_Check()); auto pool = ResolveAttrs(src, {"DESCRIPTOR", "file", "pool"}); if (!pool) { - throw py::type_error("Object is not a valid protobuf"); + throw py::type_error(py::repr(src).cast() + + " object is not a valid protobuf"); } auto pool_data = @@ -833,18 +819,24 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // 1. The binary does not have a py_proto_api instance, or // 2. a) the proto is from the default pool and // b) the binary is not using fast_cpp_protos. - if ((GlobalState::instance()->py_proto_api() == nullptr) || + if (GlobalState::instance()->py_proto_api() == nullptr || (src->GetDescriptor()->file()->pool() == DescriptorPool::generated_pool() && !GlobalState::instance()->using_fast_cpp())) { return GenericPyProtoCast(src, policy, parent, is_const); } - std::optional emsg = - check_unknown_fields::CheckAndBuildErrorMessageIfAny( - GlobalState::instance()->py_proto_api(), src); - if (emsg) { - throw py::value_error(*emsg); + std::optional unknown_field_message = + check_unknown_fields::CheckRecursively( + GlobalState::instance()->py_proto_api(), src, + check_unknown_fields::ExtensionsWithUnknownFieldsPolicy:: + UnknownFieldsAreDisallowed()); + if (unknown_field_message) { + if (!unknown_field_message->empty()) { + throw py::value_error(*unknown_field_message); + } + // Fall back to serialize/parse. + return GenericPyProtoCast(src, policy, parent, is_const); } // If this is a dynamically generated proto, then we're going to need to diff --git a/pybind11_protobuf/proto_cast_util.h b/pybind11_protobuf/proto_cast_util.h index 4821e94..954115e 100644 --- a/pybind11_protobuf/proto_cast_util.h +++ b/pybind11_protobuf/proto_cast_util.h @@ -14,6 +14,7 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" // PYBIND11_PROTOBUF_ASSUME_FULL_ABI_COMPATIBILITY can be defined by users @@ -28,6 +29,10 @@ namespace pybind11_protobuf { +// Simple helper. Caller has to ensure that the py_bytes argument outlives the +// returned string_view. +absl::string_view PyBytesAsStringView(pybind11::bytes py_bytes); + // Initialize internal proto cast dependencies, which includes importing // various protobuf-related modules. void InitializePybindProtoCastUtil(); @@ -39,14 +44,16 @@ void ImportProtoDescriptorModule(const ::google::protobuf::Descriptor *); const ::google::protobuf::Message *PyProtoGetCppMessagePointer(pybind11::handle src); // Returns the protocol buffer's py_proto.DESCRIPTOR.full_name attribute. -absl::optional PyProtoDescriptorName(pybind11::handle py_proto); +absl::optional PyProtoDescriptorFullName( + pybind11::handle py_proto); + +// Returns true if py_proto full name matches descriptor full name. +bool PyProtoHasMatchingFullName(pybind11::handle py_proto, + const ::google::protobuf::Descriptor *descriptor); -// Return whether py_proto is compatible with the C++ descriptor. -// The py_proto name must match the C++ Descriptor::full_name(), and is -// expected to originate from the python default pool, which means that -// this method will return false for dynamic protos. -bool PyProtoIsCompatible(pybind11::handle py_proto, - const ::google::protobuf::Descriptor *descriptor); +// Caller should enforce any type identity that is required. +pybind11::bytes PyProtoSerializePartialToString(pybind11::handle py_proto, + bool raise_if_error); // Allocates a C++ protocol buffer for a given name. std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatabase( @@ -54,7 +61,6 @@ std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatab // Serialize the py_proto and deserialize it into the provided message. // Caller should enforce any type identity that is required. -bool PyProtoCopyToCProto(pybind11::handle py_proto, ::google::protobuf::Message *message); void CProtoCopyToPyProto(::google::protobuf::Message *message, pybind11::handle py_proto); // Returns a handle to a python protobuf suitably diff --git a/pybind11_protobuf/proto_caster_impl.h b/pybind11_protobuf/proto_caster_impl.h index f54b175..50eb3bc 100644 --- a/pybind11_protobuf/proto_caster_impl.h +++ b/pybind11_protobuf/proto_caster_impl.h @@ -63,16 +63,19 @@ struct proto_caster_load_impl { } } - // The incoming object is not a compatible fast_cpp_proto, so check whether - // it is otherwise compatible, then serialize it and deserialize into a - // native C++ proto type. - if (!pybind11_protobuf::PyProtoIsCompatible(src, - ProtoType::GetDescriptor())) { + if (!PyProtoHasMatchingFullName(src, ProtoType::GetDescriptor())) { return false; } + pybind11::bytes serialized_bytes = + PyProtoSerializePartialToString(src, convert); + if (!serialized_bytes) { + return false; + } + owned = std::unique_ptr(new ProtoType()); value = owned.get(); - return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get()); + return owned.get()->ParsePartialFromString( + PyBytesAsStringView(serialized_bytes)); } // ensure_owned ensures that the owned member contains a copy of the @@ -108,16 +111,23 @@ struct proto_caster_load_impl<::google::protobuf::Message> { // `src` is not a C++ proto instance from the generated_pool, // so create a compatible native C++ proto. - auto descriptor_name = pybind11_protobuf::PyProtoDescriptorName(src); + auto descriptor_name = pybind11_protobuf::PyProtoDescriptorFullName(src); if (!descriptor_name) { return false; } + pybind11::bytes serialized_bytes = + PyProtoSerializePartialToString(src, convert); + if (!serialized_bytes) { + return false; + } + owned.reset(static_cast( pybind11_protobuf::AllocateCProtoFromPythonSymbolDatabase( src, *descriptor_name) .release())); value = owned.get(); - return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get()); + return owned.get()->ParsePartialFromString( + PyBytesAsStringView(serialized_bytes)); } // ensure_owned ensures that the owned member contains a copy of the diff --git a/pybind11_protobuf/tests/pass_by_test.py b/pybind11_protobuf/tests/pass_by_test.py index 66a75c9..56a69fc 100644 --- a/pybind11_protobuf/tests/pass_by_test.py +++ b/pybind11_protobuf/tests/pass_by_test.py @@ -170,6 +170,34 @@ def test_pass_fake2(self, check_method): def test_overload_fn(self, message_fn, expected): self.assertEqual(expected, m.fn_overload(message_fn())) + def test_bad_serialize_partial_function_calls(self): + class FakeDescr: + full_name = 'fake_full_name' + + class FakeProto: + DESCRIPTOR = FakeDescr() + + def __init__(self, serialize_fn_return_value=None): + self.serialize_fn_return_value = serialize_fn_return_value + + def SerializePartialToString(self): # pylint: disable=invalid-name + if self.serialize_fn_return_value is None: + raise RuntimeError('Broken serialize_fn.') + return self.serialize_fn_return_value + + with self.assertRaisesRegex( + TypeError, r'\.SerializePartialToString\(\) function call FAILED$' + ): + m.fn_overload(FakeProto()) + with self.assertRaisesRegex( + TypeError, + r'\.SerializePartialToString\(\) function call is expected to return' + r' bytes, but the returned value is \[\]$', + ): + m.fn_overload(FakeProto([])) + with self.assertRaisesRegex(TypeError, r' object is not a valid protobuf$'): + m.fn_overload(FakeProto(b'')) + if __name__ == '__main__': absltest.main()