Skip to content

Commit

Permalink
Import #161
Browse files Browse the repository at this point in the history
Manual import. The Google-internal repo is the source of truth for pybind11_protobuf. Sorry we didn't get to automating imports from GitHub PRs.

PiperOrigin-RevId: 644215761
  • Loading branch information
StefanBruens authored and copybara-github committed Jun 18, 2024
1 parent af8ee51 commit 2199159
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 13 deletions.
38 changes: 26 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# Options

option(BUILD_TESTS "Build tests." OFF)
option(ENABLE_PYPROTO_API "Enable usage of proto_api." OFF)

# ============================================================================
# Find Python
Expand Down Expand Up @@ -60,18 +61,26 @@ target_include_directories(
# ============================================================================
# pybind11_native_proto_caster shared library
add_library(
pybind11_native_proto_caster SHARED
pybind11_native_proto_caster STATIC
# bazel: pybind_library: native_proto_caster
pybind11_protobuf/native_proto_caster.h
# bazel: pybind_library: enum_type_caster
pybind11_protobuf/enum_type_caster.h
# bazel: pybind_library: proto_cast_util
pybind11_protobuf/proto_cast_util.cc
pybind11_protobuf/proto_cast_util.h
pybind11_protobuf/proto_caster_impl.h
pybind11_protobuf/proto_caster_impl.h)

target_sources(
# bazel: cc_library::check_unknown_fields
pybind11_protobuf/check_unknown_fields.cc
pybind11_protobuf/check_unknown_fields.h)
pybind11_native_proto_caster #
PRIVATE pybind11_protobuf/check_unknown_fields.cc
pybind11_protobuf/check_unknown_fields.h)

if(ENABLE_PYPROTO_API)
target_compile_definitions(pybind11_native_proto_caster
PRIVATE PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
endif()

target_link_libraries(
pybind11_native_proto_caster
Expand All @@ -92,16 +101,24 @@ target_include_directories(
# ============================================================================
# pybind11_wrapped_proto_caster shared library
add_library(
pybind11_wrapped_proto_caster SHARED
pybind11_wrapped_proto_caster STATIC
# bazel: pybind_library: wrapped_proto_caster
pybind11_protobuf/wrapped_proto_caster.h
# bazel: pybind_library: proto_cast_util
pybind11_protobuf/proto_cast_util.cc
pybind11_protobuf/proto_cast_util.h
pybind11_protobuf/proto_caster_impl.h
# bazel: cc_library: check_unknown_fields
pybind11_protobuf/check_unknown_fields.cc
pybind11_protobuf/check_unknown_fields.h)
pybind11_protobuf/proto_caster_impl.h)

target_sources(
# bazel: cc_library::check_unknown_fields
pybind11_wrapped_proto_caster
PRIVATE pybind11_protobuf/check_unknown_fields.cc
pybind11_protobuf/check_unknown_fields.h)

if(ENABLE_PYPROTO_API)
target_compile_definitions(pybind11_wrapped_proto_caster
PRIVATE PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
endif()

target_link_libraries(
pybind11_wrapped_proto_caster
Expand All @@ -119,9 +136,6 @@ target_include_directories(
PRIVATE ${PROJECT_SOURCE_DIR} ${protobuf_INCLUDE_DIRS} ${protobuf_SOURCE_DIR}
${pybind11_INCLUDE_DIRS})

# TODO set defines PYBIND11_PROTOBUF_ENABLE_PYPROTO_API see: bazel:
# pybind_library: proto_cast_util

# bazel equivs. checklist
#
# bazel: pybind_library: enum_type_caster - enum_type_caster.h
Expand Down
6 changes: 6 additions & 0 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::string MakeAllowListKey(
unknown_field_parent_message_fqn);
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)

/// Recurses through the message Descriptor class looking for valid extensions.
/// Stores the result to `memoized`.
bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor* descriptor,
Expand Down Expand Up @@ -173,6 +175,8 @@ std::string HasUnknownFields::BuildErrorMessage() const {
return emsg;
}

#endif

} // namespace

void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
Expand All @@ -181,6 +185,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
unknown_field_parent_message_fqn));
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
Expand All @@ -195,5 +200,6 @@ std::optional<std::string> CheckRecursively(
}
return search.BuildErrorMessage();
}
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

} // namespace pybind11_protobuf::check_unknown_fields
5 changes: 5 additions & 0 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
#include <optional>

#include "google/protobuf/message.h"
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

#include "absl/strings/string_view.h"

namespace pybind11_protobuf::check_unknown_fields {
Expand Down Expand Up @@ -45,9 +48,11 @@ class ExtensionsWithUnknownFieldsPolicy {
void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
18 changes: 17 additions & 1 deletion pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
#include "python/google/protobuf/proto_api.h"
#else
namespace google::protobuf::python {
struct PyProto_API;
}
#endif
#include "pybind11_protobuf/check_unknown_fields.h"

#if defined(GOOGLE_PROTOBUF_VERSION)
Expand All @@ -46,7 +52,6 @@ using ::google::protobuf::FileDescriptorProto;
using ::google::protobuf::Message;
using ::google::protobuf::MessageFactory;
using ::google::protobuf::python::PyProto_API;
using ::google::protobuf::python::PyProtoAPICapsuleName;

namespace pybind11_protobuf {

Expand Down Expand Up @@ -266,6 +271,7 @@ GlobalState::GlobalState() {
//
// By default (3) is used, however if the define is set *and* the version
// matches, then pybind11_protobuf will assume that this will work.
using ::google::protobuf::python::PyProtoAPICapsuleName;
py_proto_api_ =
static_cast<PyProto_API*>(PyCapsule_Import(PyProtoAPICapsuleName(), 0));
if (py_proto_api_ == nullptr) {
Expand Down Expand Up @@ -355,6 +361,7 @@ py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) {
module_name + "?");
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
std::pair<py::object, Message*> GlobalState::PyFastCppProtoMessageInstance(
const Descriptor* descriptor) {
assert(descriptor != nullptr);
Expand Down Expand Up @@ -395,6 +402,7 @@ std::pair<py::object, Message*> GlobalState::PyFastCppProtoMessageInstance(
}
return {std::move(result), message};
}
#endif

// Create C++ DescriptorPools based on Python DescriptorPools.
// The Python pool will provide message definitions when they are needed.
Expand Down Expand Up @@ -534,6 +542,7 @@ class PythonDescriptorPoolWrapper {
private:
bool CopyToFileDescriptorProto(py::handle py_file_descriptor,
FileDescriptorProto* output) {
#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
if (GlobalState::instance()->py_proto_api()) {
try {
py::object c_proto = py::reinterpret_steal<py::object>(
Expand All @@ -552,6 +561,7 @@ class PythonDescriptorPoolWrapper {
PyErr_Print();
}
}
#endif

return output->ParsePartialFromString(
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
Expand Down Expand Up @@ -750,6 +760,7 @@ py::handle GenericPyProtoCast(Message* src, py::return_value_policy policy,
return py_proto.release();
}

#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
assert(policy != pybind11::return_value_policy::automatic);
Expand Down Expand Up @@ -823,6 +834,7 @@ py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy,
throw py::cast_error(message + ReturnValuePolicyName(policy));
}
}
#endif

py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
py::handle parent, bool is_const) {
Expand All @@ -833,6 +845,9 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// 1. The binary does not have a py_proto_api instance, or
// 2. a) the proto is from the default pool and
// b) the binary is not using fast_cpp_protos.
#if !defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API)
return GenericPyProtoCast(src, policy, parent, is_const);
#else
if (GlobalState::instance()->py_proto_api() == nullptr ||
(src->GetDescriptor()->file()->pool() ==
DescriptorPool::generated_pool() &&
Expand Down Expand Up @@ -861,6 +876,7 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
// construct a mapping between C++ pool() and python pool(), and then
// use the PyProto_API to make it work.
return GenericFastCppProtoCast(src, policy, parent, is_const);
#endif
}

} // namespace pybind11_protobuf

0 comments on commit 2199159

Please sign in to comment.