Skip to content

Commit b2ea16c

Browse files
all_paths_to_bindings, with_endings (#27)
* not ready * with_endings * fix * impl * fix * fix * fix * not ready * update * fix --------- Co-authored-by: TANG ZHIXIONG <zhixiong.tang@momenta.ai>
1 parent fc5ef0a commit b2ea16c

File tree

4 files changed

+242
-18
lines changed

4 files changed

+242
-18
lines changed

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@
5858
# built documents.
5959
#
6060
# The short X.Y version.
61-
version = '0.2.3'
61+
version = '0.2.4'
6262
# The full version, including alpha/beta/rc tags.
63-
release = '0.2.3'
63+
release = '0.2.4'
6464

6565
# The language for content autogenerated by Sphinx. Refer to documentation
6666
# for a list of supported languages.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build"
55

66
[project]
77
name = "networkx_graph"
8-
version = "0.2.3"
8+
version = "0.2.4"
99
url = "https://github.com/cubao/networkx-graph"
1010
description = "Some customized graph algorithms"
1111
readme = "README.md"

src/main.cpp

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ struct Sequences
138138
}
139139
};
140140

141+
inline bool starts_with(const std::vector<int64_t> &nodes,
142+
const std::vector<int64_t> &prefix)
143+
{
144+
return !prefix.empty() && //
145+
prefix.size() <= nodes.size() && //
146+
std::equal(prefix.begin(), prefix.end(), nodes.begin());
147+
}
148+
149+
inline bool ends_with(const std::vector<int64_t> &nodes,
150+
const std::vector<int64_t> &suffix)
151+
{
152+
return !suffix.empty() && //
153+
suffix.size() <= nodes.size() && //
154+
std::equal(suffix.begin(), suffix.end(),
155+
&nodes[nodes.size() - suffix.size()]);
156+
}
157+
141158
inline std::array<double, 2> cheap_ruler_k(double latitude)
142159
{
143160
// https://github.com/cubao/headers/blob/8ed287a7a1e2a5cd221271b19611ba4a3f33d15c/include/cubao/crs_transform.hpp#L212
@@ -887,7 +904,8 @@ struct DiGraph
887904
const Bindings &bindings, //
888905
std::optional<double> offset = {}, //
889906
int direction = 0, //
890-
const Sinks *sinks = nullptr) const
907+
const Sinks *sinks = nullptr, //
908+
bool with_endings = false) const
891909
{
892910
if (bindings.graph != this) {
893911
return {};
@@ -908,13 +926,15 @@ struct DiGraph
908926
}
909927
std::vector<Path> forwards;
910928
if (direction >= 0) {
911-
forwards = __all_path_to_bindings(*src_idx, offset, length->second,
912-
cutoff, bindings, sinks);
929+
forwards =
930+
__all_path_to_bindings(*src_idx, offset, length->second, cutoff,
931+
bindings, sinks, false, with_endings);
913932
}
914933
std::vector<Path> backwards;
915934
if (direction <= 0) {
916-
backwards = __all_path_to_bindings(*src_idx, offset, length->second,
917-
cutoff, bindings, sinks, true);
935+
backwards =
936+
__all_path_to_bindings(*src_idx, offset, length->second, cutoff,
937+
bindings, sinks, true, with_endings);
918938
}
919939
if (round_scale_) {
920940
for (auto &r : forwards) {
@@ -1738,13 +1758,13 @@ struct DiGraph
17381758
}
17391759

17401760
std::vector<Path>
1741-
__all_path_to_bindings(int64_t source, //
1742-
std::optional<double> source_offset, //
1743-
double source_length,
1744-
double cutoff, //
1745-
const Bindings &bindings, //
1746-
const Sinks *sinks = nullptr, //
1747-
bool reverse = false) const
1761+
__all_path_to_bindings__(int64_t source, //
1762+
std::optional<double> source_offset, //
1763+
double source_length, //
1764+
double cutoff, //
1765+
const Bindings &bindings, //
1766+
const Sinks *sinks, //
1767+
bool reverse) const
17481768
{
17491769
auto &node2bindings = bindings.node2bindings;
17501770
if (source_offset) {
@@ -1886,6 +1906,91 @@ struct DiGraph
18861906
[](const auto &p1, const auto &p2) { return p1.dist < p2.dist; });
18871907
return paths;
18881908
}
1909+
1910+
std::vector<Path>
1911+
__all_path_to_bindings(int64_t source, //
1912+
std::optional<double> source_offset, //
1913+
double source_length, //
1914+
double cutoff, //
1915+
const Bindings &bindings, //
1916+
const Sinks *sinks, //
1917+
bool reverse, //
1918+
bool with_endings) const
1919+
{
1920+
auto paths =
1921+
__all_path_to_bindings__(source, source_offset, source_length, //
1922+
cutoff, bindings, sinks, reverse);
1923+
if (!with_endings) {
1924+
return paths;
1925+
}
1926+
std::vector<Path> ending_paths;
1927+
if (!reverse) {
1928+
auto all_paths = __all_paths(source, cutoff, source_offset,
1929+
lengths_, nexts_, sinks);
1930+
for (auto &path : all_paths) {
1931+
bool keep = true;
1932+
for (auto &p : paths) {
1933+
// keep if path not starts with any of paths
1934+
if (starts_with(path.nodes, p.nodes)) {
1935+
keep = false;
1936+
break;
1937+
}
1938+
}
1939+
if (keep) {
1940+
if (round_scale_) {
1941+
path.round(*round_scale_);
1942+
}
1943+
int64_t tail = path.nodes.back();
1944+
double off = *path.end_offset;
1945+
py::object obj = py::none();
1946+
path.binding =
1947+
std::make_tuple(tail, std::make_tuple(off, off, obj));
1948+
ending_paths.push_back(path);
1949+
}
1950+
}
1951+
} else {
1952+
if (source_offset) {
1953+
source_offset = CLIP(0.0, *source_offset, source_length);
1954+
source_offset = source_length - *source_offset;
1955+
}
1956+
auto all_paths = __all_paths(source, cutoff, source_offset,
1957+
lengths_, prevs_, sinks);
1958+
for (auto &p : all_paths) {
1959+
if (p.start_offset) {
1960+
p.start_offset =
1961+
lengths_.at(p.nodes.front()) - *p.start_offset;
1962+
}
1963+
if (p.end_offset) {
1964+
p.end_offset = lengths_.at(p.nodes.back()) - *p.end_offset;
1965+
}
1966+
std::reverse(p.nodes.begin(), p.nodes.end());
1967+
std::swap(p.start_offset, p.end_offset);
1968+
}
1969+
for (auto &path : all_paths) {
1970+
bool keep = true;
1971+
for (auto &p : paths) {
1972+
// keep if path not ends with any of paths
1973+
if (ends_with(path.nodes, p.nodes)) {
1974+
keep = false;
1975+
break;
1976+
}
1977+
}
1978+
if (keep) {
1979+
if (round_scale_) {
1980+
path.round(*round_scale_);
1981+
}
1982+
int64_t head = path.nodes.front();
1983+
double off = *path.start_offset;
1984+
py::object obj = py::none();
1985+
path.binding =
1986+
std::make_tuple(head, std::make_tuple(off, off, obj));
1987+
ending_paths.push_back(path);
1988+
}
1989+
}
1990+
}
1991+
paths.insert(paths.end(), ending_paths.begin(), ending_paths.end());
1992+
return paths;
1993+
}
18891994
};
18901995

18911996
struct ShortestPathWithUbodt
@@ -3170,8 +3275,9 @@ PYBIND11_MODULE(_core, m)
31703275
"cutoff"_a, //
31713276
"bindings"_a, //
31723277
"offset"_a = std::nullopt, //
3173-
"direction"_a = 0,
3174-
"sinks"_a = nullptr, //
3278+
"direction"_a = 0, //
3279+
"sinks"_a = nullptr, //
3280+
"with_endings"_a = false, //
31753281
py::call_guard<py::gil_scoped_release>())
31763282
.def("build_ubodt",
31773283
py::overload_cast<double, int, int>(&DiGraph::build_ubodt,

tests/test_basic.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def calculate_md5(filename, block_size=4096):
3232

3333

3434
def test_version():
35-
assert m.__version__ == "0.2.3"
35+
assert m.__version__ == "0.2.4"
3636

3737

3838
def test_add():
@@ -1134,6 +1134,124 @@ def test_all_paths_to_bindings():
11341134
assert len(backwards) == 2
11351135
assert len(forwards) == 0
11361136

1137+
_, forwards = G.all_paths_to_bindings(
1138+
"w1",
1139+
cutoff=4.0,
1140+
offset=9.0,
1141+
bindings=bindings,
1142+
)
1143+
assert len(forwards) == 1
1144+
_, forwards = G.all_paths_to_bindings(
1145+
"w1",
1146+
cutoff=4.0,
1147+
offset=9.0,
1148+
bindings=bindings,
1149+
with_endings=True,
1150+
)
1151+
assert len(forwards) == 2
1152+
assert forwards[0].to_dict() == {
1153+
"dist": 2.0,
1154+
"nodes": ["w1", "w3"],
1155+
"start": ("w1", 9.0),
1156+
"end": ("w3", 1.0),
1157+
"binding": ("w3", (1.0, 3.0, "obj31")),
1158+
}
1159+
assert forwards[1].to_dict() == {
1160+
"dist": 4.0,
1161+
"nodes": ["w1", "w2"],
1162+
"start": ("w1", 9.0),
1163+
"end": ("w2", 3.0),
1164+
"binding": ("w2", (3.0, 3.0, None)),
1165+
}
1166+
backwards, forwards = G.all_paths_to_bindings(
1167+
"w3",
1168+
cutoff=5.0,
1169+
offset=0.5,
1170+
bindings=bindings,
1171+
with_endings=True,
1172+
)
1173+
assert len(backwards) == 1
1174+
assert backwards[0].to_dict() == {
1175+
"dist": 5.0,
1176+
"nodes": ["w1", "w3"],
1177+
"start": ("w1", 5.5),
1178+
"end": ("w3", 0.5),
1179+
"binding": ("w1", (5.5, 5.5, None)),
1180+
}
1181+
assert len(forwards) == 1
1182+
assert forwards[0].to_dict() == {
1183+
"dist": 0.5,
1184+
"nodes": ["w3"],
1185+
"start": ("w3", 0.5),
1186+
"end": ("w3", 1.0),
1187+
"binding": ("w3", (1.0, 3.0, "obj31")),
1188+
}
1189+
1190+
backwards, forwards = G.all_paths_to_bindings(
1191+
"w3",
1192+
cutoff=5.0,
1193+
offset=2.5,
1194+
bindings=bindings,
1195+
with_endings=True,
1196+
)
1197+
assert len(backwards) == 1
1198+
assert backwards[0].to_dict() == {
1199+
"dist": 5.0,
1200+
"nodes": ["w1", "w3"],
1201+
"start": ("w1", 7.5),
1202+
"end": ("w3", 2.5),
1203+
"binding": ("w1", (7.5, 7.5, None)),
1204+
}
1205+
assert len(forwards) == 1
1206+
assert forwards[0].to_dict() == {
1207+
"dist": 2.5,
1208+
"nodes": ["w3"],
1209+
"start": ("w3", 2.5),
1210+
"end": ("w3", 5.0),
1211+
"binding": ("w3", (5.0, 6.0, "obj32")),
1212+
}
1213+
1214+
assert G.all_paths_to_bindings("w3", cutoff=5.0, offset=1, bindings=bindings)[1][
1215+
0
1216+
].binding == ("w3", (1.0, 3.0, "obj31"))
1217+
assert G.all_paths_to_bindings(
1218+
"w3", cutoff=5.0, offset=1 + 1e-15, bindings=bindings
1219+
)[1][0].binding == ("w3", (5.0, 6.0, "obj32"))
1220+
1221+
expected = {
1222+
"dist": 24.0,
1223+
"nodes": ["w3", "w4", "w6", "w7"],
1224+
"start": ("w3", 10.0),
1225+
"end": ("w7", 1.0),
1226+
"binding": ("w3", (9.0, 10.0, "obj33")),
1227+
}
1228+
backwards, forwards = G.all_paths_to_bindings(
1229+
"w7",
1230+
cutoff=30.0,
1231+
offset=1.0,
1232+
bindings=bindings,
1233+
)
1234+
assert len(backwards) == 1
1235+
assert backwards[0].to_dict() == expected
1236+
assert len(forwards) == 1
1237+
backwards, forwards = G.all_paths_to_bindings(
1238+
"w7",
1239+
cutoff=30.0,
1240+
offset=1.0,
1241+
bindings=bindings,
1242+
with_endings=True,
1243+
)
1244+
assert len(backwards) == 2
1245+
assert backwards[0].to_dict() == expected
1246+
assert backwards[1].to_dict() == {
1247+
"dist": 30.0,
1248+
"nodes": ["w2", "w5", "w7"],
1249+
"start": ("w2", 1.0),
1250+
"end": ("w7", 1.0),
1251+
"binding": ("w2", (1.0, 1.0, None)),
1252+
}
1253+
assert len(forwards) == 1
1254+
11371255

11381256
def test_shortest_zigzag_path():
11391257
G = graph1()

0 commit comments

Comments
 (0)