Skip to content

Commit

Permalink
Merge branch 'fs-bzlmod' of github.com:phaedon/pybind11_protobuf into…
Browse files Browse the repository at this point in the history
… fs-bzlmod
  • Loading branch information
Fædon Jóhannes Sinis committed Jan 8, 2024
2 parents 28b3684 + 12aab41 commit b820f0d
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 115 deletions.
48 changes: 29 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,18 @@ protobuf extensions
are involved, a well-known pitfall is that extensions are silently moved
to the `proto2::UnknownFieldSet` when a message is deserialized in C++,
but the `cc_proto_library` for the extensions is not linked in. The root
cause is an asymmetry in the handling of Python protos vs C++ protos: when
a Python proto is deserialized, both the Python descriptor pool and the C++
descriptor pool are inspected, but when a C++ proto is deserialized, only
cause is an asymmetry in the handling of Python protos vs C++
protos:
when a Python proto is deserialized, both the Python descriptor pool and the
C++ descriptor pool are inspected, but when a C++ proto is deserialized, only
the C++ descriptor pool is inspected. Until this asymmetry is resolved, the
`cc_proto_library` for all extensions involved must be added to the `deps` of
the relevant `pybind_library` or `pybind_extension`, but this is sufficiently
unobvious to be a setup for regular accidents, potentially with critical
the relevant `pybind_library` or `pybind_extension`, or if this is impractial,
`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse`
or `pybind11_protobuf::AllowUnknownFieldsFor` can be used.

The pitfall is sufficiently unobvious to be a setup for regular accidents,
potentially with critical
consequences.

To guard against the most common type of accident, native_proto_caster.h
Expand All @@ -97,14 +102,20 @@ in certain situations:
* and the `proto2::UnknownFieldSet` for the message or any of its submessages
is not empty.

`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in
which
`pybind11_protobuf::check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::WeakEnableFallbackToSerializeParse`
is a **global** escape hatch trading off convenience and runtime overhead: the
convenience is that it is not necessary to determine what `cc_proto_library`
dependencies need to be added, the runtime overhead is that
`SerializePartialToString`/`ParseFromString` is used for messages with unknown
fields, instead of the much faster `CopyFrom`.

* unknown fields existed before the safety mechanism was
introduced.
* unknown fields are needed in the future.
Another escape hatch is `pybind11_protobuf::AllowUnknownFieldsFor`, which
simply disables the safety mechanism for **specific message types**, without
a runtime overhead. This is useful for situations in which unknown fields
are acceptable.

An example of a full error message (with lines breaks here for readability):
An example of a full error message generated by the safety mechanism
(with lines breaks here for readability):

```
Proto Message of type pybind11.test.NestRepeated has an Unknown Field with
Expand All @@ -117,14 +128,13 @@ Only if there is no alternative to suppressing this error, use
(Warning: suppressions may mask critical bugs.)
```

The current implementation is a compromise solution, trading off simplicity
of implementation, runtime performance, and precision. Generally, the runtime
overhead is expected to be very small, but fields flagged as unknown may not
necessarily be in extensions.
Alerting developers of new code to unknown fields is assumed to be generally
helpful, but the unknown fields detection is limited to messages with
extensions, to avoid the runtime overhead for the presumably much more common
case that no extensions are involved.
Note that the current implementation of the safety mechanism is a compromise
solution, trading off simplicity of implementation, runtime performance,
and precision. Alerting developers of new code to unknown fields is assumed
to be generally helpful, but the unknown fields detection is limited to
messages with extensions, to avoid the runtime overhead for the presumably
much more common case that no extensions are involved. Because of this,
the runtime overhead for the safety mechanism is expected to be very small.

### Enumerations

Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
@@ -1 +1 @@
workspace(name = "com_google_pybind11_protobuf")
workspace(name = "com_google_pybind11_protobuf")
12 changes: 12 additions & 0 deletions pybind11_protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ cc_library(
"@com_google_protobuf//python:proto_api",
],
)

cc_library(
name = "disallow_extensions_with_unknown_fields",
srcs = ["disallow_extensions_with_unknown_fields.cc"],
visibility = [
"//visibility:public",
],
deps = [
":check_unknown_fields",
],
alwayslink = 1,
)
7 changes: 5 additions & 2 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
unknown_field_parent_message_fqn));
}

std::optional<std::string> CheckAndBuildErrorMessageIfAny(
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
const ::google::protobuf::Message* message, bool build_error_message_if_any) {
const auto* root_descriptor = message->GetDescriptor();
HasUnknownFields search{py_proto_api, root_descriptor};
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
Expand All @@ -193,6 +193,9 @@ std::optional<std::string> CheckAndBuildErrorMessageIfAny(
search.FieldFQN())) != 0) {
return std::nullopt;
}
if (!build_error_message_if_any) {
return ""; // This indicates that an unknown field was found.
}
return search.BuildErrorMessage();
}

Expand Down
37 changes: 35 additions & 2 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,45 @@

namespace pybind11_protobuf::check_unknown_fields {

class ExtensionsWithUnknownFieldsPolicy {
enum State {
// Initial state.
kWeakDisallow,

// Primary use case: PyCLIF extensions might set this when being imported.
kWeakEnableFallbackToSerializeParse,

// Primary use case: `:disallow_extensions_with_unknown_fields` in `deps`
// of a binary (or test).
kStrongDisallow
};

static State& GetStateSingleton() {
static State singleton = kWeakDisallow;
return singleton;
}

public:
static void WeakEnableFallbackToSerializeParse() {
State& policy = GetStateSingleton();
if (policy == kWeakDisallow) {
policy = kWeakEnableFallbackToSerializeParse;
}
}

static void StrongSetDisallow() { GetStateSingleton() = kStrongDisallow; }

static bool UnknownFieldsAreDisallowed() {
return GetStateSingleton() != kWeakEnableFallbackToSerializeParse;
}
};

void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

std::optional<std::string> CheckAndBuildErrorMessageIfAny(
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
const ::google::protobuf::Message* top_message, bool build_error_message_if_any);

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
12 changes: 12 additions & 0 deletions pybind11_protobuf/disallow_extensions_with_unknown_fields.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "pybind11_protobuf/check_unknown_fields.h"

namespace pybind11_protobuf::check_unknown_fields {
namespace {

static int kSetConfigDone = []() {
ExtensionsWithUnknownFieldsPolicy::StrongSetDisallow();
return 0;
}();

} // namespace
} // namespace pybind11_protobuf::check_unknown_fields
142 changes: 67 additions & 75 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,33 @@
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "google/protobuf/descriptor.pb.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "python/google/protobuf/proto_api.h"
#include "absl/container/flat_hash_map.h"
#include "pybind11_protobuf/check_unknown_fields.h"

#if defined(GOOGLE_PROTOBUF_VERSION)
#include "absl/strings/numbers.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "pybind11_protobuf/check_unknown_fields.h"
#endif

namespace py = pybind11;

using ::google::protobuf::Descriptor;
using ::google::protobuf::DescriptorDatabase;
using ::google::protobuf::DescriptorPool;
using ::google::protobuf::DescriptorProto;
using ::google::protobuf::DynamicMessageFactory;
using ::google::protobuf::FileDescriptor;
using ::google::protobuf::FileDescriptorProto;
Expand Down Expand Up @@ -534,10 +540,8 @@ class PythonDescriptorPoolWrapper {
}
}

py::object wire = py_file_descriptor.attr("serialized_pb");
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
return output->ParsePartialFromArray(bytes,
PYBIND11_BYTES_SIZE(wire.ptr()));
return output->ParsePartialFromString(
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
}

py::object pool_; // never dereferenced.
Expand All @@ -549,6 +553,11 @@ class PythonDescriptorPoolWrapper {

} // namespace

absl::string_view PyBytesAsStringView(py::bytes py_bytes) {
return absl::string_view(PyBytes_AsString(py_bytes.ptr()),
PyBytes_Size(py_bytes.ptr()));
}

void InitializePybindProtoCastUtil() {
assert(PyGILState_Check());
GlobalState::instance();
Expand Down Expand Up @@ -593,7 +602,7 @@ const Message* PyProtoGetCppMessagePointer(py::handle src) {
#endif
}

absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
absl::optional<std::string> PyProtoDescriptorFullName(py::handle py_proto) {
assert(PyGILState_Check());
auto py_full_name = ResolveAttrs(py_proto, {"DESCRIPTOR", "full_name"});
if (py_full_name) {
Expand All @@ -602,66 +611,42 @@ absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
return absl::nullopt;
}

bool PyProtoIsCompatible(py::handle py_proto, const Descriptor* descriptor) {
assert(PyGILState_Check());
if (descriptor->file()->pool() != DescriptorPool::generated_pool()) {
/// This indicates that the C++ descriptor does not come from the C++
/// DescriptorPool. This may happen if the C++ code has the same proto
/// in different descriptor pools, perhaps from different shared objects,
/// and could be result in undefined behavior.
return false;
}

auto py_descriptor = ResolveAttrs(py_proto, {"DESCRIPTOR"});
if (!py_descriptor) {
// Not a valid protobuf -- missing DESCRIPTOR.
return false;
}

// Test full_name equivalence.
{
auto py_full_name = ResolveAttrs(*py_descriptor, {"full_name"});
if (!py_full_name) {
// Not a valid protobuf -- missing DESCRIPTOR.full_name
return false;
}
auto full_name = CastToOptionalString(*py_full_name);
if (!full_name || *full_name != descriptor->full_name()) {
// Name mismatch.
return false;
}
}

// The C++ descriptor is compiled in (see above assert), so the py_proto
// is expected to be from the global pool, i.e. the DESCRIPTOR.file.pool
// instance is the global python pool, and not a custom pool.
auto py_pool = ResolveAttrs(*py_descriptor, {"file", "pool"});
if (py_pool) {
return py_pool->is(GlobalState::instance()->global_pool());
}

// The py_proto is missing a DESCRIPTOR.file.pool, but the name matches.
// This will not happen with a native python implementation, but does
// occur with the deprecated :proto_casters, and could happen with other
// mocks. Returning true allows the caster to call PyProtoCopyToCProto.
return true;
bool PyProtoHasMatchingFullName(py::handle py_proto,
const Descriptor* descriptor) {
auto full_name = PyProtoDescriptorFullName(py_proto);
return full_name && *full_name == descriptor->full_name();
}

bool PyProtoCopyToCProto(py::handle py_proto, Message* message) {
assert(PyGILState_Check());
auto serialize_fn = ResolveAttrMRO(py_proto, "SerializePartialToString");
py::bytes PyProtoSerializePartialToString(py::handle py_proto,
bool raise_if_error) {
static const char* serialize_fn_name = "SerializePartialToString";
auto serialize_fn = ResolveAttrMRO(py_proto, serialize_fn_name);
if (!serialize_fn) {
throw py::type_error(
"SerializePartialToString method not found; is this a " +
message->GetDescriptor()->full_name());
}
auto wire = (*serialize_fn)();
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
if (!bytes) {
throw py::type_error("SerializePartialToString failed; is this a " +
message->GetDescriptor()->full_name());
return py::object();
}
auto serialized_bytes = py::reinterpret_steal<py::object>(
PyObject_CallObject(serialize_fn->ptr(), nullptr));
if (!serialized_bytes) {
if (raise_if_error) {
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
serialize_fn_name + "() function call FAILED";
py::raise_from(PyExc_TypeError, msg.c_str());
throw py::error_already_set();
}
return py::object();
}
if (!PyBytes_Check(serialized_bytes.ptr())) {
if (raise_if_error) {
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
serialize_fn_name +
"() function call is expected to return bytes, but the "
"returned value is " +
py::repr(serialized_bytes).cast<std::string>();
throw py::type_error(msg);
}
return py::object();
}
return message->ParsePartialFromArray(bytes, PYBIND11_BYTES_SIZE(wire.ptr()));
return serialized_bytes;
}

void CProtoCopyToPyProto(Message* message, py::handle py_proto) {
Expand All @@ -686,7 +671,8 @@ std::unique_ptr<Message> AllocateCProtoFromPythonSymbolDatabase(
assert(PyGILState_Check());
auto pool = ResolveAttrs(src, {"DESCRIPTOR", "file", "pool"});
if (!pool) {
throw py::type_error("Object is not a valid protobuf");
throw py::type_error(py::repr(src).cast<std::string>() +
" object is not a valid protobuf");
}

auto pool_data =
Expand Down Expand Up @@ -833,18 +819,24 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// 1. The binary does not have a py_proto_api instance, or
// 2. a) the proto is from the default pool and
// b) the binary is not using fast_cpp_protos.
if ((GlobalState::instance()->py_proto_api() == nullptr) ||
if (GlobalState::instance()->py_proto_api() == nullptr ||
(src->GetDescriptor()->file()->pool() ==
DescriptorPool::generated_pool() &&
!GlobalState::instance()->using_fast_cpp())) {
return GenericPyProtoCast(src, policy, parent, is_const);
}

std::optional<std::string> emsg =
check_unknown_fields::CheckAndBuildErrorMessageIfAny(
GlobalState::instance()->py_proto_api(), src);
if (emsg) {
throw py::value_error(*emsg);
std::optional<std::string> unknown_field_message =
check_unknown_fields::CheckRecursively(
GlobalState::instance()->py_proto_api(), src,
check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
UnknownFieldsAreDisallowed());
if (unknown_field_message) {
if (!unknown_field_message->empty()) {
throw py::value_error(*unknown_field_message);
}
// Fall back to serialize/parse.
return GenericPyProtoCast(src, policy, parent, is_const);
}

// If this is a dynamically generated proto, then we're going to need to
Expand Down
Loading

0 comments on commit b820f0d

Please sign in to comment.