summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/chunk_ordered_table.py298
-rw-r--r--synapse/storage/events.py9
-rw-r--r--synapse/storage/schema/delta/49/event_chunks.sql49
-rw-r--r--synapse/util/katriel_bodlaender.py298
-rw-r--r--tests/storage/test_chunk_linearizer_table.py181
-rw-r--r--tests/util/test_katriel_bodlaender.py58
6 files changed, 893 insertions, 0 deletions
diff --git a/synapse/storage/chunk_ordered_table.py b/synapse/storage/chunk_ordered_table.py
new file mode 100644
index 0000000000..33089c2c60
--- /dev/null
+++ b/synapse/storage/chunk_ordered_table.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import logging
+
+from synapse.storage._base import SQLBaseStore
+from synapse.util.katriel_bodlaender import OrderedListStore
+from synapse.util.metrics import Measure
+
+import synapse.metrics
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+rebalance_counter = metrics.register_counter("rebalances")
+
+
+logger = logging.getLogger(__name__)
+
+
+class ChunkDBOrderedListStore(OrderedListStore):
+    """Used as the list store for room chunks, efficiently maintaining them in
+    topological order on updates.
+
+    The class is designed for use inside transactions and so takes a
+    transaction object in the constructor. This means that it needs to be
+    re-instantiated in each transaction, so all state needs to be stored
+    in the database.
+
+    Internally the ordering is implemented using floats, and the average is
+    taken when a node is inserted inbetween other nodes. To avoid presicion
+    errors a minimum difference between sucessive orderings is attempted to be
+    kept; whenever the difference is too small we attempt to rebalance. See
+    the `_rebalance` function for implementation details.
+
+    Note that OrderedListStore orders nodes such that source of an edge
+    comes before the target. This is counter intuitive when edges represent
+    causality, so for the purposes of ordering algorithm we invert the edge
+    directions, i.e. if chunk A has a prev chunk of B then we say that the
+    edge is from B to A. This ensures that newer chunks get inserted at the
+    end (rather than the start).
+
+    Args:
+        txn
+        room_id (str)
+        clock
+        rebalance_digits (int): When a rebalance is triggered we rebalance
+            in a range around the node, where the bounds are rounded to this
+            number of digits.
+        min_difference (int): A rebalance is triggered when the difference
+            between two successive orderings are less than the reverse of
+            this.
+    """
+    def __init__(self,
+                 txn, room_id, clock,
+                 rebalance_digits=3,
+                 min_difference=1000000):
+        self.txn = txn
+        self.room_id = room_id
+        self.clock = clock
+
+        self.rebalance_digits = rebalance_digits
+        self.min_difference = 1. / min_difference
+
+    def is_before(self, a, b):
+        """Implements OrderedListStore"""
+        return self._get_order(a) < self._get_order(b)
+
+    def get_prev(self, node_id):
+        """Implements OrderedListStore"""
+        order = self._get_order(node_id)
+
+        sql = """
+            SELECT chunk_id FROM chunk_linearized
+            WHERE ordering < ? AND room_id = ?
+            ORDER BY ordering DESC
+            LIMIT 1
+        """
+
+        self.txn.execute(sql, (order, self.room_id,))
+
+        row = self.txn.fetchone()
+        if row:
+            return row[0]
+        return None
+
+    def get_next(self, node_id):
+        """Implements OrderedListStore"""
+        order = self._get_order(node_id)
+
+        sql = """
+            SELECT chunk_id FROM chunk_linearized
+            WHERE ordering > ? AND room_id = ?
+            ORDER BY ordering ASC
+            LIMIT 1
+        """
+
+        self.txn.execute(sql, (order, self.room_id,))
+
+        row = self.txn.fetchone()
+        if row:
+            return row[0]
+        return None
+
+    def insert_before(self, node_id, target_id):
+        """Implements OrderedListStore"""
+
+        rebalance = False  # Set to true if we need to trigger a rebalance
+
+        if target_id:
+            target_order = self._get_order(target_id)
+            before_id = self.get_prev(target_id)
+
+            if before_id:
+                before_order = self._get_order(before_id)
+                new_order = (target_order + before_order) / 2.
+
+                rebalance = math.fabs(target_order - before_order) < self.min_difference
+            else:
+                new_order = math.floor(target_order) - 1
+        else:
+            # If target_id is None then we insert at the end.
+            self.txn.execute("""
+                SELECT COALESCE(MAX(ordering), 0) + 1
+                FROM chunk_linearized
+                WHERE room_id = ?
+            """, (self.room_id,))
+
+            new_order, = self.txn.fetchone()
+
+        self._insert(node_id, new_order)
+
+        if rebalance:
+            self._rebalance(node_id)
+
+    def insert_after(self, node_id, target_id):
+        """Implements OrderedListStore"""
+
+        rebalance = False  # Set to true if we need to trigger a rebalance
+
+        if target_id:
+            target_order = self._get_order(target_id)
+            after_id = self.get_next(target_id)
+            if after_id:
+                after_order = self._get_order(after_id)
+                new_order = (target_order + after_order) / 2.
+
+                rebalance = math.fabs(target_order - after_order) < self.min_difference
+            else:
+                new_order = math.ceil(target_order) + 1
+        else:
+            # If target_id is None then we insert at the start.
+            self.txn.execute("""
+                SELECT COALESCE(MIN(ordering), 0) - 1
+                FROM chunk_linearized
+                WHERE room_id = ?
+            """, (self.room_id,))
+
+            new_order, = self.txn.fetchone()
+
+        self._insert(node_id, new_order)
+
+        if rebalance:
+            self._rebalance(node_id)
+
+    def get_nodes_with_edges_to(self, node_id):
+        """Implements OrderedListStore"""
+
+        # Note that we use the inverse relation here
+        sql = """
+            SELECT l.ordering, l.chunk_id FROM chunk_graph AS g
+            INNER JOIN chunk_linearized AS l ON g.prev_id = l.chunk_id
+            WHERE g.chunk_id = ?
+        """
+        self.txn.execute(sql, (node_id,))
+        return self.txn.fetchall()
+
+    def get_nodes_with_edges_from(self, node_id):
+        """Implements OrderedListStore"""
+
+        # Note that we use the inverse relation here
+        sql = """
+            SELECT l.ordering, l.chunk_id FROM chunk_graph AS g
+            INNER JOIN chunk_linearized AS l ON g.chunk_id = l.chunk_id
+            WHERE g.prev_id = ?
+        """
+        self.txn.execute(sql, (node_id,))
+        return self.txn.fetchall()
+
+    def _delete_ordering(self, node_id):
+        """Implements OrderedListStore"""
+
+        SQLBaseStore._simple_delete_txn(
+            self.txn,
+            table="chunk_linearized",
+            keyvalues={"chunk_id": node_id},
+        )
+
+    def _add_edge_to_graph(self, source_id, target_id):
+        """Implements OrderedListStore"""
+
+        # Note that we use the inverse relation
+        SQLBaseStore._simple_insert_txn(
+            self.txn,
+            table="chunk_graph",
+            values={"chunk_id": target_id, "prev_id": source_id}
+        )
+
+    def _insert(self, node_id, order):
+        """Inserts the node with the given ordering.
+        """
+        SQLBaseStore._simple_insert_txn(
+            self.txn,
+            table="chunk_linearized",
+            values={
+                "chunk_id": node_id,
+                "room_id": self.room_id,
+                "ordering": order,
+            }
+        )
+
+    def _get_order(self, node_id):
+        """Get the ordering of the given node.
+        """
+
+        return SQLBaseStore._simple_select_one_onecol_txn(
+            self.txn,
+            table="chunk_linearized",
+            keyvalues={"chunk_id": node_id},
+            retcol="ordering"
+        )
+
+    def _rebalance(self, node_id):
+        """Rebalances the list around the given node to ensure that the
+        ordering floats don't get too small.
+
+        This works by finding a range that includes the given node, and
+        recalculating the ordering floats such that they're equidistant in
+        that range.
+        """
+
+        logger.info("Rebalancing room %s, chunk %s", self.room_id, node_id)
+
+        with Measure(self.clock, "chunk_rebalance"):
+            # We pick the interval to try and minimise the number of decimal
+            # places, i.e. we round to nearest float with `rebalance_digits` and
+            # use that as the middle of the interval
+            order = self._get_order(node_id)
+            a = round(order, self.rebalance_digits)
+            if order > a:
+                min_order = a
+                max_order = a + 10 ** -self.rebalance_digits
+            else:
+                min_order = a - 10 ** -self.rebalance_digits
+                max_order = a
+
+            # Now we get all the nodes in the range. We add the minimum difference
+            # to the bounds to ensure that we don't accidentally move a node to be
+            # within the minimum difference of a node outside the range.
+            sql = """
+                SELECT chunk_id FROM chunk_linearized
+                WHERE ordering >= ? AND ordering <= ? AND room_id = ?
+            """
+            self.txn.execute(sql, (
+                min_order - self.min_difference,
+                max_order + self.min_difference,
+                self.room_id,
+            ))
+
+            chunk_ids = [c for c, in self.txn]
+
+            sql = """
+                UPDATE chunk_linearized
+                SET ordering = ?
+                WHERE chunk_id = ?
+            """
+
+            step = (max_order - min_order) / len(chunk_ids)
+            self.txn.executemany(
+                sql,
+                (
+                    ((idx * step + min_order), chunk_id)
+                    for idx, chunk_id in enumerate(chunk_ids)
+                )
+            )
+
+            rebalance_counter.inc()
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 05cde96afc..70b9041eee 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -232,6 +232,15 @@ class EventsStore(EventsWorkerStore):
             psql_only=True,
         )
 
+        self.register_background_index_update(
+            "events_chunk_index",
+            index_name="events_chunk_index",
+            table="events",
+            columns=["room_id", "chunk_id", "topological_ordering", "stream_ordering"],
+            unique=True,
+            psql_only=True,
+        )
+
         self._event_persist_queue = _EventPeristenceQueue()
 
         self._state_resolution_handler = hs.get_state_resolution_handler()
diff --git a/synapse/storage/schema/delta/49/event_chunks.sql b/synapse/storage/schema/delta/49/event_chunks.sql
new file mode 100644
index 0000000000..6b428b4ef8
--- /dev/null
+++ b/synapse/storage/schema/delta/49/event_chunks.sql
@@ -0,0 +1,49 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN chunk_id BIGINT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('events_chunk_index', '{}');
+
+-- Stores how chunks of graph relate to each other
+CREATE TABLE chunk_graph (
+    chunk_id BIGINT NOT NULL,
+    prev_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX chunk_graph_id ON chunk_graph (chunk_id, prev_id);
+CREATE INDEX chunk_graph_prev_id ON chunk_graph (prev_id);
+
+-- The extremities in each chunk. Note that these are pointing to events that
+-- we don't have, rather than boundary between chunks.
+CREATE TABLE chunk_backwards_extremities (
+    chunk_id BIGINT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE INDEX chunk_backwards_extremities_id ON chunk_backwards_extremities(chunk_id, event_id);
+CREATE INDEX chunk_backwards_extremities_event_id ON chunk_backwards_extremities(event_id);
+
+-- Maintains an absolute ordering of chunks. Gets updated when we see new
+-- edges between chunks.
+CREATE TABLE chunk_linearized (
+    chunk_id BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    ordering DOUBLE PRECISION NOT NULL
+);
+
+CREATE UNIQUE INDEX chunk_linearized_id ON chunk_linearized (chunk_id);
+CREATE INDEX chunk_linearized_ordering ON chunk_linearized (room_id, ordering);
diff --git a/synapse/util/katriel_bodlaender.py b/synapse/util/katriel_bodlaender.py
new file mode 100644
index 0000000000..b924a4cfdf
--- /dev/null
+++ b/synapse/util/katriel_bodlaender.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains an implementation of the Katriel-Bodlaender algorithm,
+which is used to do online topological ordering of graphs.
+
+Note that the ordering derived from the graph has the first node one with no
+incoming edges at the start, and the last node one with no outgoing edges.
+
+This ordering is therefore opposite to what one might expect when considering
+the room DAG, as newer messages would be added to the start rather than the
+end.
+
+***We therefore invert the direction of edges when using the algorithm***
+
+See https://www.sciencedirect.com/science/article/pii/S0304397507006573
+"""
+
+from abc import ABCMeta, abstractmethod
+
+
+class OrderedListStore(object):
+    """An abstract base class that is used to store a topological ordering of
+    a graph. Suitable for use with the Katriel-Bodlaender algorithm.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def is_before(self, first_node, second_node):
+        """Returns whether the first node is before the second node.
+
+        Args:
+            first_node (str)
+            second_node (str)
+
+        Returns:
+            bool: True if first_node is before second_node
+        """
+        pass
+
+    @abstractmethod
+    def get_prev(self, node_id):
+        """Gets the node immediately before the given node
+
+        Args:
+            node_id (str)
+
+        Returns:
+            str|None: A node ID or None if no preceding node exists
+        """
+        pass
+
+    @abstractmethod
+    def get_next(self, node_id):
+        """Gets the node immediately after the given node
+
+        Args:
+            node_id (str)
+
+        Returns:
+            str|None: A node ID or None if no proceding node exists
+        """
+        pass
+
+    @abstractmethod
+    def insert_before(self, node_id, target_id):
+        """Inserts node immediately before target node.
+
+        If target_id is None then the node is inserted at the end of the list
+
+        Args:
+            node_id (str)
+            target_id (str|None)
+        """
+        pass
+
+    @abstractmethod
+    def insert_after(self, node_id, target_id):
+        """Inserts node immediately after target node.
+
+        If target_id is None then the node is inserted at the start of the list
+
+        Args:
+            node_id (str)
+            target_id (str|None)
+        """
+        pass
+
+    @abstractmethod
+    def get_nodes_with_edges_to(self, node_id):
+        """Get all nodes with edges to the given node
+
+        Args:
+            node_id (str)
+
+        Returns:
+            list[tuple[float, str]]: Returns a list of tuple of an ordering
+            term and the node ID. The ordering term can be used to sort the
+            returned list.
+            The ordering is valid until subsequent calls to insert_* functions
+        """
+        pass
+
+    @abstractmethod
+    def get_nodes_with_edges_from(self, node_id):
+        """Get all nodes with edges from the given node
+
+        Args:
+            node_id (str)
+
+        Returns:
+            list[tuple[float, str]]: Returns a list of tuple of an ordering
+            term and the node ID. The ordering term can be used to sort the
+            returned list.
+            The ordering is valid until subsequent calls to insert_* functions
+        """
+        pass
+
+    @abstractmethod
+    def _delete_ordering(self, node_id):
+        """Deletes the given node from the ordered list (but not the graph).
+
+        Used when we want to reinsert it into a different position
+
+        Args:
+            node_id (str)
+        """
+        pass
+
+    @abstractmethod
+    def _add_edge_to_graph(self, source_id, target_id):
+        """Adds an edge to the graph from source to target.
+
+        Does not update ordering.
+
+        Args:
+            source_id (str)
+            target_id (str)
+        """
+        pass
+
+    def add_node(self, node_id):
+        """Adds a node to the graph.
+
+        Args:
+            node_id (str)
+        """
+        self.insert_before(node_id, None)
+
+    def add_edge(self, source, target):
+        """Adds a new edge is added to the graph and updates the ordering.
+
+        See module level docs.
+
+        Note that both the source and target nodes must have been inserted into
+        the store (at an arbitrary position) already.
+
+        Args:
+            source (str): The source node of the new edge
+            target (str): The target node of the new edge
+        """
+
+        # The following is the Katriel-Bodlaender algorithm.
+
+        to_s = []
+        from_t = []
+        to_s_neighbours = []
+        from_t_neighbours = []
+        to_s_indegree = 0
+        from_t_outdegree = 0
+        s = source
+        t = target
+
+        while s and t and not self.is_before(s, t):
+            m_s = to_s_indegree
+            m_t = from_t_outdegree
+
+            pe_s = self.get_nodes_with_edges_to(s)
+            fe_t = self.get_nodes_with_edges_from(t)
+
+            l_s = len(pe_s)
+            l_t = len(fe_t)
+
+            if m_s + l_s <= m_t + l_t:
+                to_s.append(s)
+                to_s_neighbours.extend(pe_s)
+                to_s_indegree += l_s
+
+                if to_s_neighbours:
+                    to_s_neighbours.sort()
+                    _, s = to_s_neighbours.pop()
+                else:
+                    s = None
+
+            if m_s + l_s >= m_t + l_t:
+                from_t.append(t)
+                from_t_neighbours.extend(fe_t)
+                from_t_outdegree += l_t
+
+                if from_t_neighbours:
+                    from_t_neighbours.sort(reverse=True)
+                    _, t = from_t_neighbours.pop()
+                else:
+                    t = None
+
+        if s is None:
+            s = self.get_prev(target)
+
+        if t is None:
+            t = self.get_next(source)
+
+        while to_s:
+            s1 = to_s.pop()
+            self._delete_ordering(s1)
+            self.insert_after(s1, s)
+            s = s1
+
+        while from_t:
+            t1 = from_t.pop()
+            self._delete_ordering(t1)
+            self.insert_before(t1, t)
+            t = t1
+
+        self._add_edge_to_graph(source, target)
+
+
+class InMemoryOrderedListStore(OrderedListStore):
+    """An in memory OrderedListStore
+    """
+
+    def __init__(self):
+        # The ordered list of nodes
+        self.list = []
+
+        # Map from node to set of nodes that it references
+        self.edges_from = {}
+
+        # Map from node to set of nodes that it is referenced by
+        self.edges_to = {}
+
+    def is_before(self, first_node, second_node):
+        return self.list.index(first_node) < self.list.index(second_node)
+
+    def get_prev(self, node_id):
+        idx = self.list.index(node_id) - 1
+        if idx >= 0:
+            return self.list[idx]
+        else:
+            return None
+
+    def get_next(self, node_id):
+        idx = self.list.index(node_id) + 1
+        if idx < len(self.list):
+            return self.list[idx]
+        else:
+            return None
+
+    def insert_before(self, node_id, target_id):
+        if target_id is not None:
+            idx = self.list.index(target_id)
+            self.list.insert(idx, node_id)
+        else:
+            self.list.append(node_id)
+
+    def insert_after(self, node_id, target_id):
+        if target_id is not None:
+            idx = self.list.index(target_id) + 1
+            self.list.insert(idx, node_id)
+        else:
+            self.list.insert(0, node_id)
+
+    def _delete_ordering(self, node_id):
+        self.list.remove(node_id)
+
+    def get_nodes_with_edges_to(self, node_id):
+        to_nodes = self.edges_to.get(node_id, [])
+        return [(self.list.index(nid), nid) for nid in to_nodes]
+
+    def get_nodes_with_edges_from(self, node_id):
+        from_nodes = self.edges_from.get(node_id, [])
+        return [(self.list.index(nid), nid) for nid in from_nodes]
+
+    def _add_edge_to_graph(self, source_id, target_id):
+        self.edges_from.setdefault(source_id, set()).add(target_id)
+        self.edges_to.setdefault(target_id, set()).add(source_id)
diff --git a/tests/storage/test_chunk_linearizer_table.py b/tests/storage/test_chunk_linearizer_table.py
new file mode 100644
index 0000000000..5ca436c555
--- /dev/null
+++ b/tests/storage/test_chunk_linearizer_table.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import random
+import tests.unittest
+import tests.utils
+
+from synapse.storage.chunk_ordered_table import ChunkDBOrderedListStore
+
+
+class ChunkLinearizerStoreTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(ChunkLinearizerStoreTestCase, self).__init__(*args, **kwargs)
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def test_simple_insert_fetch(self):
+        room_id = "foo_room1"
+
+        def test_txn(txn):
+            table = ChunkDBOrderedListStore(
+                txn, room_id, self.clock, 1, 100,
+            )
+
+            table.add_node("A")
+            table.insert_after("B", "A")
+            table.insert_before("C", "A")
+
+            sql = """
+                SELECT chunk_id FROM chunk_linearized
+                WHERE room_id = ?
+                ORDER BY ordering ASC
+            """
+            txn.execute(sql, (room_id,))
+
+            ordered = [r for r, in txn]
+
+            self.assertEqual(["C", "A", "B"], ordered)
+
+        yield self.store.runInteraction("test", test_txn)
+
+    @defer.inlineCallbacks
+    def test_many_insert_fetch(self):
+        room_id = "foo_room2"
+
+        def test_txn(txn):
+            table = ChunkDBOrderedListStore(
+                txn, room_id, self.clock, 1, 20,
+            )
+
+            nodes = [(i, "node_%d" % (i,)) for i in xrange(1, 1000)]
+            expected = [n for _, n in nodes]
+
+            already_inserted = []
+
+            random.shuffle(nodes)
+            while nodes:
+                i, node_id = nodes.pop()
+                if not already_inserted:
+                    table.add_node(node_id)
+                else:
+                    for j, target_id in already_inserted:
+                        if j > i:
+                            break
+
+                    if j < i:
+                        table.insert_after(node_id, target_id)
+                    else:
+                        table.insert_before(node_id, target_id)
+
+                already_inserted.append((i, node_id))
+                already_inserted.sort()
+
+            sql = """
+                SELECT chunk_id FROM chunk_linearized
+                WHERE room_id = ?
+                ORDER BY ordering ASC
+            """
+            txn.execute(sql, (room_id,))
+
+            ordered = [r for r, in txn]
+
+            self.assertEqual(expected, ordered)
+
+        yield self.store.runInteraction("test", test_txn)
+
+    @defer.inlineCallbacks
+    def test_prepend_and_append(self):
+        room_id = "foo_room3"
+
+        def test_txn(txn):
+            table = ChunkDBOrderedListStore(
+                txn, room_id, self.clock, 1, 20,
+            )
+
+            table.add_node("a")
+
+            expected = ["a"]
+
+            for i in xrange(1, 1000):
+                node_id = "node_id_before_%d" % i
+                table.insert_before(node_id, expected[0])
+                expected.insert(0, node_id)
+
+            for i in xrange(1, 1000):
+                node_id = "node_id_after_%d" % i
+                table.insert_after(node_id, expected[-1])
+                expected.append(node_id)
+
+            sql = """
+                SELECT chunk_id FROM chunk_linearized
+                WHERE room_id = ?
+                ORDER BY ordering ASC
+            """
+            txn.execute(sql, (room_id,))
+
+            ordered = [r for r, in txn]
+
+            self.assertEqual(expected, ordered)
+
+        yield self.store.runInteraction("test", test_txn)
+
+    @defer.inlineCallbacks
+    def test_worst_case(self):
+        room_id = "foo_room3"
+
+        def test_txn(txn):
+            table = ChunkDBOrderedListStore(
+                txn, room_id, self.clock, 1, 100,
+            )
+
+            table.add_node("a")
+
+            prev_node = "a"
+
+            expected_prefix = ["a"]
+            expected_suffix = []
+
+            for i in xrange(1, 100):
+                node_id = "node_id_%d" % i
+                if i % 2 == 0:
+                    table.insert_before(node_id, prev_node)
+                    expected_prefix.append(node_id)
+                else:
+                    table.insert_after(node_id, prev_node)
+                    expected_suffix.append(node_id)
+                prev_node = node_id
+
+            sql = """
+                SELECT chunk_id FROM chunk_linearized
+                WHERE room_id = ?
+                ORDER BY ordering ASC
+            """
+            txn.execute(sql, (room_id,))
+
+            ordered = [r for r, in txn]
+
+            expected = expected_prefix + list(reversed(expected_suffix))
+
+            self.assertEqual(expected, ordered)
+
+        yield self.store.runInteraction("test", test_txn)
diff --git a/tests/util/test_katriel_bodlaender.py b/tests/util/test_katriel_bodlaender.py
new file mode 100644
index 0000000000..5768408604
--- /dev/null
+++ b/tests/util/test_katriel_bodlaender.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.katriel_bodlaender import InMemoryOrderedListStore
+
+from tests import unittest
+
+
+class KatrielBodlaenderTests(unittest.TestCase):
+    def test_simple_graph(self):
+        store = InMemoryOrderedListStore()
+
+        nodes = [
+            "node_1",
+            "node_2",
+            "node_3",
+            "node_4",
+        ]
+
+        for node in nodes:
+            store.add_node(node)
+
+        store.add_edge("node_2", "node_3")
+        store.add_edge("node_1", "node_2")
+        store.add_edge("node_3", "node_4")
+
+        self.assertEqual(nodes, store.list)
+
+    def test_reverse_graph(self):
+        store = InMemoryOrderedListStore()
+
+        nodes = [
+            "node_1",
+            "node_2",
+            "node_3",
+            "node_4",
+        ]
+
+        for node in nodes:
+            store.add_node(node)
+
+        store.add_edge("node_3", "node_2")
+        store.add_edge("node_2", "node_1")
+        store.add_edge("node_4", "node_3")
+
+        self.assertEqual(list(reversed(nodes)), store.list)