diff --git a/synapse/storage/chunk_ordered_table.py b/synapse/storage/chunk_ordered_table.py
index 79d0ca44ec..87a57f87b3 100644
--- a/synapse/storage/chunk_ordered_table.py
+++ b/synapse/storage/chunk_ordered_table.py
@@ -281,10 +281,32 @@ class ChunkDBOrderedListStore(OrderedListStore):
# We pick the interval to try and minimise the number of decimal
# places, i.e. we round to nearest float with `rebalance_digits` and
# use that as one side of the interval
+
order = self._get_order(node_id)
+ rebalance_digits = self.rebalance_digits
a = round(order, self.rebalance_digits)
- min_order = a - 10 ** -self.rebalance_digits
- max_order = a + 10 ** -self.rebalance_digits
+ diff = 10 ** - self.rebalance_digits
+
+ while True:
+ min_order = a - diff
+ max_order = a + diff
+
+ sql = """
+ SELECT count(chunk_id) FROM chunk_linearized
+ WHERE ordering >= ? AND ordering <= ? AND room_id = ?
+ """
+ self.txn.execute(sql, (
+ min_order - self.min_difference,
+ max_order + self.min_difference,
+ self.room_id,
+ ))
+
+ cnt, = self.txn.fetchone()
+ step = (max_order - min_order) / cnt
+ if step > 1 / self.min_difference:
+ break
+
+ diff *= 2
# Now we get all the nodes in the range. We add the minimum difference
# to the bounds to ensure that we don't accidentally move a node to be
@@ -292,6 +314,7 @@ class ChunkDBOrderedListStore(OrderedListStore):
sql = """
SELECT chunk_id FROM chunk_linearized
WHERE ordering >= ? AND ordering <= ? AND room_id = ?
+ ORDER BY ordering ASC
"""
self.txn.execute(sql, (
min_order - self.min_difference,
diff --git a/synapse/util/katriel_bodlaender.py b/synapse/util/katriel_bodlaender.py
index 16126ec936..d030e37013 100644
--- a/synapse/util/katriel_bodlaender.py
+++ b/synapse/util/katriel_bodlaender.py
@@ -112,6 +112,12 @@ class OrderedListStore(object):
pe_s = self.get_nodes_with_edges_to(s)
fe_t = self.get_nodes_with_edges_from(t)
+ for n, _ in pe_s:
+ assert n not in to_s
+
+ for n, _ in fe_t:
+ assert n not in from_t
+
l_s = len(pe_s)
l_t = len(fe_t)
@@ -145,15 +151,19 @@ class OrderedListStore(object):
if t is None:
t = self.get_next(source)
+ for node_id in to_s:
+ self._delete_ordering(node_id)
+
while to_s:
s1 = to_s.pop()
- self._delete_ordering(s1)
self._insert_after(s1, s)
s = s1
+ for node_id in from_t:
+ self._delete_ordering(node_id)
+
while from_t:
t1 = from_t.pop()
- self._delete_ordering(t1)
self._insert_before(t1, t)
t = t1
diff --git a/tests/storage/test_chunk_linearizer_table.py b/tests/storage/test_chunk_linearizer_table.py
index beb1ac9a42..9cac62061b 100644
--- a/tests/storage/test_chunk_linearizer_table.py
+++ b/tests/storage/test_chunk_linearizer_table.py
@@ -48,6 +48,7 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase):
table.add_node("A")
table._insert_after("B", "A")
table._insert_before("C", "A")
+ table._insert_after("D", "A")
sql = """
SELECT chunk_id FROM chunk_linearized
@@ -58,7 +59,7 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase):
ordered = [r for r, in txn]
- self.assertEqual(["C", "A", "B"], ordered)
+ self.assertEqual(["C", "A", "D", "B"], ordered)
yield self.store.runInteraction("test", test_txn)
@@ -183,3 +184,44 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase):
self.assertEqual(expected, ordered)
yield self.store.runInteraction("test", test_txn)
+
+ @defer.inlineCallbacks
+ def test_get_edges_to(self):
+ room_id = "foo_room4"
+
+ def test_txn(txn):
+ table = ChunkDBOrderedListStore(
+ txn, room_id, self.clock, 1, 100,
+ )
+
+ table.add_node("A")
+ table._insert_after("B", "A")
+ table._add_edge_to_graph("A", "B")
+ table._insert_before("C", "A")
+ table._add_edge_to_graph("C", "A")
+
+ nodes = table.get_nodes_with_edges_from("A")
+ self.assertEqual([n for _, n in nodes], ["B"])
+
+ nodes = table.get_nodes_with_edges_to("A")
+ self.assertEqual([n for _, n in nodes], ["C"])
+
+ yield self.store.runInteraction("test", test_txn)
+
+ @defer.inlineCallbacks
+ def test_get_next_and_prev(self):
+ room_id = "foo_room5"
+
+ def test_txn(txn):
+ table = ChunkDBOrderedListStore(
+ txn, room_id, self.clock, 1, 100,
+ )
+
+ table.add_node("A")
+ table._insert_after("B", "A")
+ table._insert_before("C", "A")
+
+ self.assertEqual(table.get_next("A"), "B")
+ self.assertEqual(table.get_prev("A"), "C")
+
+ yield self.store.runInteraction("test", test_txn)
diff --git a/tests/util/test_katriel_bodlaender.py b/tests/util/test_katriel_bodlaender.py
index 5768408604..72126bdea9 100644
--- a/tests/util/test_katriel_bodlaender.py
+++ b/tests/util/test_katriel_bodlaender.py
@@ -56,3 +56,29 @@ class KatrielBodlaenderTests(unittest.TestCase):
store.add_edge("node_4", "node_3")
self.assertEqual(list(reversed(nodes)), store.list)
+
+ def test_divergent_graph(self):
+ store = InMemoryOrderedListStore()
+
+ nodes = [
+ "node_1",
+ "node_2",
+ "node_3",
+ "node_4",
+ "node_5",
+ "node_6",
+ ]
+
+ for node in reversed(nodes):
+ store.add_node(node)
+
+ store.add_edge("node_2", "node_3")
+ store.add_edge("node_2", "node_5")
+ store.add_edge("node_1", "node_2")
+ store.add_edge("node_3", "node_4")
+ store.add_edge("node_1", "node_3")
+ store.add_edge("node_4", "node_5")
+ store.add_edge("node_5", "node_6")
+ store.add_edge("node_4", "node_6")
+
+ self.assertEqual(nodes, store.list)
|