summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-07-07 12:11:35 +0100
committerGitHub <noreply@github.com>2020-07-07 12:11:35 +0100
commit67d7756fcfb43c2b01a83da10b4f36635fa7b441 (patch)
treef193847474b4ef687e522add3d011f66b9407b16
parentAdd libwebp dependency to Dockerfile (#7791) (diff)
downloadsynapse-67d7756fcfb43c2b01a83da10b4f36635fa7b441.tar.xz
Refactor getting replication updates from database v2. (#7740)
-rw-r--r--changelog.d/7740.misc1
-rw-r--r--synapse/handlers/typing.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py56
-rw-r--r--synapse/storage/data_stores/main/cache.py36
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py54
-rw-r--r--synapse/storage/data_stores/main/devices.py70
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py65
-rw-r--r--synapse/storage/data_stores/main/group_server.py52
-rw-r--r--synapse/storage/data_stores/main/pusher.py108
-rw-r--r--synapse/storage/data_stores/main/room.py41
-rw-r--r--synapse/storage/data_stores/main/tags.py45
11 files changed, 336 insertions, 195 deletions
diff --git a/changelog.d/7740.misc b/changelog.d/7740.misc
new file mode 100644
index 0000000000..f93149502e
--- /dev/null
+++ b/changelog.d/7740.misc
@@ -0,0 +1 @@
+Refactor getting replication updates from database.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6c7abaa578..879c4c07c6 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -294,6 +294,9 @@ class TypingHandler(object):
         rows.sort()
 
         limited = False
+        # We, unusually, use a strict limit here as we have all the rows in
+        # memory rather than pulling them out of the database with a `LIMIT ?`
+        # clause.
         if len(rows) > limit:
             rows = rows[:limit]
             current_id = rows[-1][0]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f196eff072..9076bbe9f1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -198,26 +198,6 @@ def current_token_without_instance(
     return lambda instance_name: current_token()
 
 
-def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> UpdateFunction:
-    """Wraps a db query function which returns a list of rows to make it
-    suitable for use as an `update_function` for the Stream class
-    """
-
-    async def update_function(instance_name, from_token, upto_token, limit):
-        rows = await query_function(from_token, upto_token, limit)
-        updates = [(row[0], row[1:]) for row in rows]
-        limited = False
-        if len(updates) >= limit:
-            upto_token = updates[-1][0]
-            limited = True
-
-        return updates, upto_token, limited
-
-    return update_function
-
-
 def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
@@ -393,7 +373,7 @@ class PushersStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_pushers_stream_token),
-            db_query_to_update_function(store.get_all_updated_pushers_rows),
+            store.get_all_updated_pushers_rows,
         )
 
 
@@ -421,26 +401,12 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        self.store = hs.get_datastore()
+        store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_cache_stream_token,
-            self._update_function,
-        )
-
-    async def _update_function(
-        self, instance_name: str, from_token: int, upto_token: int, limit: int
-    ):
-        rows = await self.store.get_all_updated_caches(
-            instance_name, from_token, upto_token, limit
+            store.get_cache_stream_token,
+            store.get_all_updated_caches,
         )
-        updates = [(row[0], row[1:]) for row in rows]
-        limited = False
-        if len(updates) >= limit:
-            upto_token = updates[-1][0]
-            limited = True
-
-        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_current_public_room_stream_id),
-            db_query_to_update_function(store.get_all_new_public_rooms),
+            store.get_all_new_public_rooms,
         )
 
 
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
-            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+            store.get_all_device_list_changes_for_remotes,
         )
 
 
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_to_device_stream_token),
-            db_query_to_update_function(store.get_all_new_device_messages),
+            store.get_all_new_device_messages,
         )
 
 
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_max_account_data_stream_id),
-            db_query_to_update_function(store.get_all_updated_tags),
+            store.get_all_updated_tags,
         )
 
 
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_group_stream_token),
-            db_query_to_update_function(store.get_all_groups_changes),
+            store.get_all_groups_changes,
         )
 
 
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
-            db_query_to_update_function(
-                store.get_all_user_signature_changes_for_remotes
-            ),
+            store.get_all_user_signature_changes_for_remotes,
         )
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d30766e543..f39f556c20 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,7 +16,7 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -46,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ):
-        """Fetches cache invalidation rows between the two given IDs written
-        by the given instance. Returns at most `limit` rows.
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for caches replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
@@ -66,7 +83,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, instance_name, limit))
-            return txn.fetchall()
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
 
         return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 9a1178fb39..d313b9705f 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
+    async def get_all_new_device_messages(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for to device replication stream.
+
         Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
         Returns:
-            A deferred list of rows from the device inbox
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
-        if last_pos == current_pos:
-            return defer.succeed([])
+
+        if last_id == current_id:
+            return [], current_id, False
 
         def get_all_new_device_messages_txn(txn):
             # We limit like this as we might have multiple rows per stream_id, and
             # we want to make sure we always get all entries for any stream_id
             # we return.
-            upper_pos = min(current_pos, last_pos + limit)
+            upper_pos = min(current_id, last_id + limit)
             sql = (
                 "SELECT max(stream_id), user_id"
                 " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " GROUP BY user_id"
             )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
+            txn.execute(sql, (last_id, upper_pos))
+            updates = [(row[0], row[1:]) for row in txn]
 
             sql = (
                 "SELECT max(stream_id), destination"
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " GROUP BY destination"
             )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
+            txn.execute(sql, (last_id, upper_pos))
+            updates.extend((row[0], row[1:]) for row in txn)
 
             # Order by ascending stream ordering
-            rows.sort()
+            updates.sort()
 
-            return rows
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
 
-        return self.db.runInteraction(
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0ff0542453..343cf9a2d5 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
             return set()
 
     async def get_all_device_list_changes_for_remotes(
-        self, from_key: int, to_key: int, limit: int,
-    ) -> List[Tuple[int, str]]:
-        """Return a list of `(stream_id, entity)` which is the combined list of
-        changes to devices and which destinations need to be poked. Entity is
-        either a user ID (starting with '@') or a remote destination.
-        """
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for device lists replication stream.
 
-        # This query Does The Right Thing where it'll correctly apply the
-        # bounds to the inner queries.
-        sql = """
-            SELECT stream_id, entity FROM (
-                SELECT stream_id, user_id AS entity FROM device_lists_stream
-                UNION ALL
-                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
-            ) AS e
-            WHERE ? < stream_id AND stream_id <= ?
-            LIMIT ?
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
-        return await self.db.execute(
+        if last_id == current_id:
+            return [], current_id, False
+
+        def _get_all_device_list_changes_for_remotes(txn):
+            # This query Does The Right Thing where it'll correctly apply the
+            # bounds to the inner queries.
+            sql = """
+                SELECT stream_id, entity FROM (
+                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                    UNION ALL
+                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                ) AS e
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_device_list_changes_for_remotes",
-            None,
-            sql,
-            from_key,
-            to_key,
-            limit,
+            _get_all_device_list_changes_for_remotes,
         )
 
     @cached(max_entries=10000)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 1a0842d4b0..6c3cff82e1 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 from canonicaljson import encode_canonical_json, json
 
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return result
 
-    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
-        """Return a list of changes from the user signature stream to notify remotes.
+    async def get_all_user_signature_changes_for_remotes(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for groups replication stream.
+
         Note that the user signature stream represents when a user signs their
         device with their user-signing key, which is not published to other
         users or servers, so no `destination` is needed in the returned
         list. However, this is needed to poke workers.
 
         Args:
-            from_key (int): the stream ID to start at (exclusive)
-            to_key (int): the stream ID to end at (inclusive)
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
 
         Returns:
-            Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
-        """
-        sql = """
-            SELECT stream_id, from_user_id AS user_id
-            FROM user_signature_stream
-            WHERE ? < stream_id AND stream_id <= ?
-            ORDER BY stream_id ASC
-            LIMIT ?
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
-        return self.db.execute(
+
+        if last_id == current_id:
+            return [], current_id, False
+
+        def _get_all_user_signature_changes_for_remotes_txn(txn):
+            sql = """
+                SELECT stream_id, from_user_id AS user_id
+                FROM user_signature_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
+
+            updates = [(row[0], (row[1:])) for row in txn]
+
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_user_signature_changes_for_remotes",
-            None,
-            sql,
-            from_key,
-            to_key,
-            limit,
+            _get_all_user_signature_changes_for_remotes_txn,
         )
 
 
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index fb1361f1c1..4fb9f9850c 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Tuple
+
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_groups_changes_for_user", _get_groups_changes_for_user_txn
         )
 
-    def get_all_groups_changes(self, from_token, to_token, limit):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
-            from_token
-        )
+    async def get_all_groups_changes(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for groups replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+
+        last_id = int(last_id)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+
         if not has_changed:
-            return defer.succeed([])
+            return [], current_id, False
 
         def _get_all_groups_changes_txn(txn):
             sql = """
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
-            return [
-                (stream_id, group_id, user_id, gtype, json.loads(content_json))
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [
+                (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
                 for stream_id, group_id, user_id, gtype, content_json in txn
             ]
 
-        return self.db.runInteraction(
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_groups_changes", _get_all_groups_changes_txn
         )
 
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 547b9d69cb..5461016240 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, List, Tuple
 
 from canonicaljson import encode_canonical_json, json
 
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
         rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
         return rows
 
-    def get_all_updated_pushers(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed(([], []))
-
-        def get_all_updated_pushers_txn(txn):
-            sql = (
-                "SELECT id, user_name, access_token, profile_tag, kind,"
-                " app_id, app_display_name, device_display_name, pushkey, ts,"
-                " lang, data"
-                " FROM pushers"
-                " WHERE ? < id AND id <= ?"
-                " ORDER BY id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            updated = txn.fetchall()
-
-            sql = (
-                "SELECT stream_id, user_id, app_id, pushkey"
-                " FROM deleted_pushers"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            deleted = txn.fetchall()
+    async def get_all_updated_pushers_rows(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for pushers replication stream.
 
-            return updated, deleted
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
 
-        return self.db.runInteraction(
-            "get_all_updated_pushers", get_all_updated_pushers_txn
-        )
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
 
-    def get_all_updated_pushers_rows(self, last_id, current_id, limit):
-        """Get all the pushers that have changed between the given tokens.
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
 
-        Returns:
-            Deferred(list(tuple)): each tuple consists of:
-                stream_id (str)
-                user_id (str)
-                app_id (str)
-                pushkey (str)
-                was_deleted (bool): whether the pusher was added/updated (False)
-                    or deleted (True)
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return [], current_id, False
 
         def get_all_updated_pushers_rows_txn(txn):
-            sql = (
-                "SELECT id, user_name, app_id, pushkey"
-                " FROM pushers"
-                " WHERE ? < id AND id <= ?"
-                " ORDER BY id ASC LIMIT ?"
-            )
+            sql = """
+                SELECT id, user_name, app_id, pushkey
+                FROM pushers
+                WHERE ? < id AND id <= ?
+                ORDER BY id ASC LIMIT ?
+            """
             txn.execute(sql, (last_id, current_id, limit))
-            results = [list(row) + [False] for row in txn]
-
-            sql = (
-                "SELECT stream_id, user_id, app_id, pushkey"
-                " FROM deleted_pushers"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
+            updates = [
+                (stream_id, (user_name, app_id, pushkey, False))
+                for stream_id, user_name, app_id, pushkey in txn
+            ]
+
+            sql = """
+                SELECT stream_id, user_id, app_id, pushkey
+                FROM deleted_pushers
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
             txn.execute(sql, (last_id, current_id, limit))
+            updates.extend(
+                (stream_id, (user_name, app_id, pushkey, True))
+                for stream_id, user_name, app_id, pushkey in txn
+            )
+
+            updates.sort()  # Sort so that they're ordered by stream id
 
-            results.extend(list(row) + [True] for row in txn)
-            results.sort()  # Sort so that they're ordered by stream id
+            limited = False
+            upper_bound = current_id
+            if len(updates) >= limit:
+                limited = True
+                upper_bound = updates[-1][0]
 
-            return results
+            return updates, upper_bound, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 13e366536a..c473cf158f 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
 
         return total_media_quarantined
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+    async def get_all_new_public_rooms(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for public rooms replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
+        """
+        if last_id == current_id:
+            return [], current_id, False
+
         def get_all_new_public_rooms(txn):
             sql = """
                 SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
 
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
 
-        if prev_id == current_id:
-            return defer.succeed([])
+            return updates, upto_token, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index f8c776be3f..290317fd94 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         return deferred
 
-    @defer.inlineCallbacks
-    def get_all_updated_tags(self, last_id, current_id, limit):
-        """Get all the client tags that have changed on the server
+    async def get_all_updated_tags(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for tags replication stream.
+
         Args:
-            last_id(int): The position to fetch from.
-            current_id(int): The position to fetch up to.
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
         Returns:
-            A deferred list of tuples of stream_id int, user_id string,
-            room_id string, tag string and content string.
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
+
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         def get_all_updated_tags_txn(txn):
             sql = (
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        tag_ids = yield self.db.runInteraction(
+        tag_ids = await self.db.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
-                results.append((stream_id, user_id, room_id, tag_json))
+                results.append((stream_id, (user_id, room_id, tag_json)))
 
             return results
 
         batch_size = 50
         results = []
         for i in range(0, len(tag_ids), batch_size):
-            tags = yield self.db.runInteraction(
+            tags = await self.db.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
                 tag_ids[i : i + batch_size],
             )
             results.extend(tags)
 
-        return results
+        limited = False
+        upto_token = current_id
+        if len(results) >= limit:
+            upto_token = results[-1][0]
+            limited = True
+
+        return results, upto_token, limited
 
     @defer.inlineCallbacks
     def get_updated_tags(self, user_id, stream_id):