Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make stl.h list|set|map_caster more user friendly. #30045

Merged
merged 16 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 176 additions & 32 deletions include/pybind11/stl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "detail/common.h"

#include <deque>
#include <initializer_list>
#include <list>
#include <map>
#include <ostream>
Expand All @@ -36,6 +37,89 @@
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)

//
// Begin: Equivalent of
// https://github.com/google/clif/blob/ae4eee1de07cdf115c0c9bf9fec9ff28efce6f6c/clif/python/runtime.cc#L388-L438
/*
The three `PyObjectTypeIsConvertibleTo*()` functions below are
the result of converging the behaviors of pybind11 and PyCLIF
(http://github.com/google/clif).

Originally PyCLIF was extremely far on the permissive side of the spectrum,
while pybind11 was very far on the strict side. Originally PyCLIF accepted any
Python iterable as input for a C++ `vector`/`set`/`map` argument, as long as
the elements were convertible. The obvious (in hindsight) problem was that
any empty Python iterable could be passed to any of these C++ types, e.g. `{}`
was accpeted for C++ `vector`/`set` arguments, or `[]` for C++ `map` arguments.

The functions below strike a practical permissive-vs-strict compromise,
informed by tens of thousands of use cases in the wild. A main objective is
to prevent accidents and improve readability:

- Python literals must match the C++ types.

- For C++ `set`: The potentially reducing conversion from a Python sequence
(e.g. Python `list` or `tuple`) to a C++ `set` must be explicit, by going
through a Python `set`.

- However, a Python `set` can still be passed to a C++ `vector`. The rationale
is that this conversion is not reducing. Implicit conversions of this kind
are also fairly commonly used, therefore enforcing explicit conversions
would have an unfavorable cost : benefit ratio; more sloppily speaking,
such an enforcement would be more annoying than helpful.
*/

inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj,
std::initializer_list<const char *> tp_names) {
if (PyType_Check(obj)) {
return false;
}
const char *obj_tp_name = Py_TYPE(obj)->tp_name;
for (const auto *tp_name : tp_names) {
if (std::strcmp(obj_tp_name, tp_name) == 0) {
return true;
}
}
return false;
}

inline bool PyObjectTypeIsConvertibleToStdVector(PyObject *obj) {
if (PySequence_Check(obj) != 0) {
return !PyUnicode_Check(obj) && !PyBytes_Check(obj);
}
return (PyGen_Check(obj) != 0) || (PyAnySet_Check(obj) != 0)
|| PyObjectIsInstanceWithOneOfTpNames(
obj, {"dict_keys", "dict_values", "dict_items", "map", "zip"});
}

inline bool PyObjectTypeIsConvertibleToStdSet(PyObject *obj) {
return (PyAnySet_Check(obj) != 0) || PyObjectIsInstanceWithOneOfTpNames(obj, {"dict_keys"});
}

inline bool PyObjectTypeIsConvertibleToStdMap(PyObject *obj) {
if (PyDict_Check(obj)) {
return true;
}
// Implicit requirement in the conditions below:
// A type with `.__getitem__()` & `.items()` methods must implement these
// to be compatible with https://docs.python.org/3/c-api/mapping.html
if (PyMapping_Check(obj) == 0) {
return false;
}
PyObject *items = PyObject_GetAttrString(obj, "items");
if (items == nullptr) {
PyErr_Clear();
return false;
}
bool is_convertible = (PyCallable_Check(items) != 0);
Py_DECREF(items);
return is_convertible;
}

//
// End: Equivalent of clif/python/runtime.cc
//

/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
template <typename T, typename U>
Expand Down Expand Up @@ -67,24 +151,40 @@ struct set_caster {
}
void reserve_maybe(const anyset &, void *) {}

public:
bool load(handle src, bool convert) {
if (!isinstance<anyset>(src)) {
return false;
}
auto s = reinterpret_borrow<anyset>(src);
value.clear();
reserve_maybe(s, &value);
for (auto entry : s) {
bool convert_iterable(const iterable &itbl, bool convert) {
for (auto it : itbl) {
key_conv conv;
if (!conv.load(entry, convert)) {
if (!conv.load(it, convert)) {
return false;
}
value.insert(cast_op<Key &&>(std::move(conv)));
}
return true;
}

bool convert_anyset(anyset s, bool convert) {
value.clear();
reserve_maybe(s, &value);
return convert_iterable(s, convert);
}

public:
bool load(handle src, bool convert) {
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
return false;
}
if (isinstance<anyset>(src)) {
value.clear();
return convert_anyset(reinterpret_borrow<anyset>(src), convert);
}
if (!convert) {
return false;
}
assert(isinstance<iterable>(src));
value.clear();
return convert_iterable(reinterpret_borrow<iterable>(src), convert);
}

template <typename T>
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
return_value_policy_pack rvpp_local = rvpp;
Expand Down Expand Up @@ -117,12 +217,7 @@ struct map_caster {
}
void reserve_maybe(const dict &, void *) {}

public:
bool load(handle src, bool convert) {
if (!isinstance<dict>(src)) {
return false;
}
auto d = reinterpret_borrow<dict>(src);
bool convert_elements(const dict &d, bool convert) {
value.clear();
reserve_maybe(d, &value);
for (auto it : d) {
Expand All @@ -136,6 +231,25 @@ struct map_caster {
return true;
}

public:
bool load(handle src, bool convert) {
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
return false;
}
if (isinstance<dict>(src)) {
return convert_elements(reinterpret_borrow<dict>(src), convert);
}
if (!convert) {
return false;
}
auto items = reinterpret_steal<object>(PyMapping_Items(src.ptr()));
if (!items) {
throw error_already_set();
}
assert(isinstance<iterable>(items));
return convert_elements(dict(reinterpret_borrow<iterable>(items)), convert);
}

template <typename T>
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
dict d;
Expand Down Expand Up @@ -168,13 +282,35 @@ struct list_caster {
using value_conv = make_caster<Value>;

bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src)) {
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
return false;
}
if (isinstance<sequence>(src)) {
return convert_elements(src, convert);
}
if (!convert) {
return false;
}
auto s = reinterpret_borrow<sequence>(src);
// Designed to be behavior-equivalent to passing tuple(src) from Python:
// The conversion to a tuple will first exhaust the generator object, to ensure that
// the generator is not left in an unpredictable (to the caller) partially-consumed
// state.
assert(isinstance<iterable>(src));
return convert_elements(tuple(reinterpret_borrow<iterable>(src)), convert);
}

private:
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
void reserve_maybe(const sequence &s, Type *) {
value.reserve(s.size());
}
void reserve_maybe(const sequence &, void *) {}

bool convert_elements(handle seq, bool convert) {
auto s = reinterpret_borrow<sequence>(seq);
value.clear();
reserve_maybe(s, &value);
for (auto it : s) {
for (auto it : seq) {
value_conv conv;
if (!conv.load(it, convert)) {
return false;
Expand All @@ -184,13 +320,6 @@ struct list_caster {
return true;
}

private:
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
void reserve_maybe(const sequence &s, Type *) {
value.reserve(s.size());
}
void reserve_maybe(const sequence &, void *) {}

public:
template <typename T>
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
Expand Down Expand Up @@ -240,12 +369,8 @@ struct array_caster {
return size == Size;
}

public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src)) {
return false;
}
auto l = reinterpret_borrow<sequence>(src);
bool convert_elements(handle seq, bool convert) {
auto l = reinterpret_borrow<sequence>(seq);
if (!require_size(l.size())) {
return false;
}
Expand All @@ -260,6 +385,25 @@ struct array_caster {
return true;
}

public:
bool load(handle src, bool convert) {
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
return false;
}
if (isinstance<sequence>(src)) {
return convert_elements(src, convert);
}
if (!convert) {
return false;
}
// Designed to be behavior-equivalent to passing tuple(src) from Python:
// The conversion to a tuple will first exhaust the generator object, to ensure that
// the generator is not left in an unpredictable (to the caller) partially-consumed
// state.
assert(isinstance<iterable>(src));
return convert_elements(tuple(reinterpret_borrow<iterable>(src)), convert);
}

template <typename T>
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
list l(src.size());
Expand Down
34 changes: 34 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ struct type_caster<ReferenceSensitiveOptional<T>>
} // namespace detail
} // namespace PYBIND11_NAMESPACE

int pass_std_vector_int(const std::vector<int> &v) {
int zum = 100;
for (const int i : v) {
zum += 2 * i;
}
return zum;
}

TEST_SUBMODULE(stl, m) {
// test_vector
m.def("cast_vector", []() { return std::vector<int>{1}; });
Expand Down Expand Up @@ -549,6 +557,32 @@ TEST_SUBMODULE(stl, m) {
// Without explicitly specifying `take_ownership`, this function leaks.
py::return_value_policy::take_ownership);

m.def("pass_std_vector_int", pass_std_vector_int);
m.def("pass_std_vector_pair_int", [](const std::vector<std::pair<int, int>> &v) {
int zum = 0;
for (const auto &ij : v) {
zum += ij.first * 100 + ij.second;
}
return zum;
});
m.def("pass_std_array_int_2", [](const std::array<int, 2> &a) {
return pass_std_vector_int(std::vector<int>(a.begin(), a.end())) + 1;
});
m.def("pass_std_set_int", [](const std::set<int> &s) {
int zum = 200;
for (const int i : s) {
zum += 3 * i;
}
return zum;
});
m.def("pass_std_map_int", [](const std::map<int, int> &m) {
int zum = 500;
for (const auto &p : m) {
zum += p.first * 1000 + p.second;
}
return zum;
});

// test return_value_policy::_return_as_bytes
m.def(
"invalid_utf8_string_array_as_bytes",
Expand Down
Loading