| diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 54b54dcc6a..a8c303e11e 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -493,15 +493,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """Gets the stream ordering corresponding to a given timestamp.
 
         Specifically, finds the stream_ordering of the first event that was
-        received after the timestamp. This is done by a binary search on the
-        events table, since there is no index on received_ts, so is
+        received on or after the timestamp. This is done by a binary search on
+        the events table, since there is no index on received_ts, so is
         relatively slow.
 
         Args:
             ts (int): timestamp in millis
 
         Returns:
-            Deferred[int]: stream ordering of the first event received after
+            Deferred[int]: stream ordering of the first event received on/after
                 the timestamp
         """
         return self.runInteraction(
@@ -510,16 +510,24 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ts,
         )
 
-    def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
+    @staticmethod
+    def _find_first_stream_ordering_after_ts_txn(txn, ts):
         """
-        Find the stream_ordering of the first event that was received after
-        a given timestamp. This is relatively slow as there is no index on
-        received_ts but we can then use this to delete push actions before
+        Find the stream_ordering of the first event that was received on or
+        after a given timestamp. This is relatively slow as there is no index
+        on received_ts but we can then use this to delete push actions before
         this.
 
         received_ts must necessarily be in the same order as stream_ordering
         and stream_ordering is indexed, so we manually binary search using
         stream_ordering
+
+        Args:
+            txn (twisted.enterprise.adbapi.Transaction):
+            ts (int): timestamp to search for
+
+        Returns:
+            int: stream ordering
         """
         txn.execute("SELECT MAX(stream_ordering) FROM events")
         max_stream_ordering = txn.fetchone()[0]
@@ -527,23 +535,53 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         if max_stream_ordering is None:
             return 0
 
+        # We want the first stream_ordering in which received_ts is greater
+        # than or equal to ts. Call this point X.
+        #
+        # We maintain the invariants:
+        #
+        #   range_start <= X <= range_end
+        #
         range_start = 0
-        range_end = max_stream_ordering
-
+        range_end = max_stream_ordering + 1
+
+        # Given a stream_ordering, look up the timestamp at that
+        # stream_ordering.
+        #
+        # The array may be sparse (we may be missing some stream_orderings).
+        # We treat the gaps as the same as having the same value as the
+        # preceding entry, because we will pick the lowest stream_ordering
+        # which satisfies our requirement of received_ts >= ts.
+        #
+        # For example, if our array of events indexed by stream_ordering is
+        # [10, <none>, 20], we should treat this as being equivalent to
+        # [10, 10, 20].
+        #
         sql = (
             "SELECT received_ts FROM events"
-            " WHERE stream_ordering > ?"
-            " ORDER BY stream_ordering"
+            " WHERE stream_ordering <= ?"
+            " ORDER BY stream_ordering DESC"
             " LIMIT 1"
         )
 
-        while range_end - range_start > 1:
-            middle = int((range_end + range_start) / 2)
+        while range_end - range_start > 0:
+            middle = (range_end + range_start) // 2
             txn.execute(sql, (middle,))
-            middle_ts = txn.fetchone()[0]
+            row = txn.fetchone()
+            if row is None:
+                # no rows with stream_ordering<=middle
+                range_start = middle + 1
+                continue
+
+            middle_ts = row[0]
             if ts > middle_ts:
-                range_start = middle
+                # we got a timestamp lower than the one we were looking for.
+                # definitely need to look higher: X > middle.
+                range_start = middle + 1
             else:
+                # we got a timestamp higher than (or the same as) the one we
+                # were looking for. We aren't yet sure about the point we
+                # looked up, but we can be sure that X <= middle.
                 range_end = middle
 
         return range_end
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
 index 6c1aad149b..dbaaa12e23 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -127,3 +127,70 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _assert_counts(1, 1)
         yield _rotate(10)
         yield _assert_counts(1, 1)
+
+    @tests.unittest.DEBUG
+    @defer.inlineCallbacks
+    def test_find_first_stream_ordering_after_ts(self):
+        def add_event(so, ts):
+            return self.store._simple_insert("events", {
+                "stream_ordering": so,
+                "received_ts": ts,
+                "event_id": "event%i" % so,
+                "type": "",
+                "room_id": "",
+                "content": "",
+                "processed": True,
+                "outlier": False,
+                "topological_ordering": 0,
+                "depth": 0,
+            })
+
+        # start with the base case where there are no events in the table
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 0)
+
+        # now with one event
+        yield add_event(2, 10)
+        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        self.assertEqual(r, 2)
+        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        self.assertEqual(r, 3)
+
+        # add a bunch of dummy events to the events table
+        for (stream_ordering, ts) in (
+                (3, 110),
+                (4, 120),
+                (5, 120),
+                (10, 130),
+                (20, 140),
+        ):
+            yield add_event(stream_ordering, ts)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        self.assertEqual(r, 3,
+                         "First event after 110ms should be 3, was %i" % r)
+
+        # 4 and 5 are both after 12: we want 4 rather than 5
+        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        self.assertEqual(r, 4,
+                         "First event after 120ms should be 4, was %i" % r)
+
+        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        self.assertEqual(r, 10,
+                         "First event after 129ms should be 10, was %i" % r)
+
+        # check we can get the last event
+        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        self.assertEqual(r, 20,
+                         "First event after 14ms should be 20, was %i" % r)
+
+        # off the end
+        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        self.assertEqual(r, 21)
+
+        # check we can find an event at ordering zero
+        yield add_event(0, 5)
+        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        self.assertEqual(r, 0)
 |