Skip to content

Commit 9680e85

Browse files
author
Braden Dubois
committed
Added 'bind_set(...)' to std_binds
Based off github.com/beatmax's solution.
1 parent a09cf61 commit 9680e85

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

include/pybind11/stl_bind.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,82 @@ class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, A
569569
return cl;
570570
}
571571

572+
//
573+
// std::set
574+
//
575+
template <typename Set, typename holder_type = std::unique_ptr<Set>, typename... Args>
576+
class_<Set, holder_type> bind_set(handle scope, std::string const &name, Args &&...args) {
577+
using Class_ = class_<Set, holder_type>;
578+
using T = typename Set::value_type;
579+
using ItType = typename Set::iterator;
580+
581+
auto vtype_info = detail::get_type_info(typeid(T));
582+
bool local = !vtype_info || vtype_info->module_local;
583+
584+
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
585+
cl.def(init<>());
586+
cl.def(init<const Set &>(), "Copy constructor");
587+
cl.def(init([](iterable it) {
588+
auto s = std::unique_ptr<Set>(new Set());
589+
for (handle h : it)
590+
s->insert(h.cast<T>());
591+
return s.release();
592+
}));
593+
cl.def(self == self);
594+
cl.def(self != self);
595+
cl.def(
596+
"remove",
597+
[](Set &s, const T &x) {
598+
auto p = s.find(x);
599+
if (p != s.end())
600+
s.erase(p);
601+
else
602+
throw value_error();
603+
},
604+
arg("x"),
605+
"Remove the item from the set whose value is x. "
606+
"It is an error if there is no such item.");
607+
cl.def(
608+
"__contains__",
609+
[](const Set &s, const T &x) { return s.find(x) != s.end(); },
610+
arg("x"),
611+
"Return true if the container contains ``x``.");
612+
cl.def(
613+
"add",
614+
[](Set &s, const T &value) { s.insert(value); },
615+
arg("x"),
616+
"Add an item to the set.");
617+
cl.def("clear", [](Set &s) { s.clear(); }, "Clear the contents.");
618+
cl.def(
619+
"__iter__",
620+
[](Set &s) {
621+
return make_iterator<return_value_policy::copy, ItType, ItType, T>(s.begin(), s.end());
622+
},
623+
keep_alive<0, 1>() /* Essential: keep set alive while iterator exists */
624+
);
625+
cl.def(
626+
"__repr__",
627+
[name](Set &s) {
628+
std::ostringstream os;
629+
os << name << '{';
630+
for (auto it = s.begin(); it != s.end(); ++it) {
631+
if (it != s.begin())
632+
os << ", ";
633+
os << *it;
634+
}
635+
os << '}';
636+
return os.str();
637+
},
638+
"Return the canonical string representation of this set.");
639+
cl.def(
640+
"__bool__",
641+
[](const Set &s) -> bool { return !s.empty(); },
642+
"Check whether the set is nonempty");
643+
cl.def("__len__", &Set::size);
644+
645+
return cl;
646+
}
647+
572648
//
573649
// std::map, std::unordered_map
574650
//

tests/test_stl_binders.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <deque>
1616
#include <map>
17+
#include <set>
1718
#include <unordered_map>
1819
#include <vector>
1920

@@ -183,6 +184,9 @@ TEST_SUBMODULE(stl_binders, m) {
183184
py::bind_vector<std::vector<El>>(m, "VectorEl");
184185
py::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl");
185186

187+
// test_set_int
188+
py::bind_set<std::set<int>>(m, "SetInt");
189+
186190
// test_map_string_double
187191
py::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
188192
py::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");

tests/test_stl_binders.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,29 @@ def test_vector_custom():
149149
assert str(vv_b) == "VectorEl[El{1}, El{2}]"
150150

151151

152+
def test_set_int():
153+
s_a = m.SetInt()
154+
s_b = m.SetInt()
155+
156+
assert len(s_a) == 0
157+
assert s_a == s_b
158+
159+
s_a.add(1)
160+
161+
assert 1 in s_a
162+
assert str(s_a) == "SetInt{1}"
163+
assert s_a != s_b
164+
165+
for i in range(5):
166+
s_a.add(i)
167+
168+
assert sorted(s_a) == [0, 1, 2, 3, 4]
169+
170+
s_a.clear()
171+
assert len(s_a) == 0
172+
assert str(s_a) == "SetInt{}"
173+
174+
152175
def test_map_string_double():
153176
mm = m.MapStringDouble()
154177
mm["a"] = 1

0 commit comments

Comments
 (0)