summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py629
1 files changed, 419 insertions, 210 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..4acefc8a96 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,102 +14,84 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import itertools
+import heapq
 import logging
 from collections import namedtuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
+
+import attr
+
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
-from twisted.internet import defer
+if TYPE_CHECKING:
+    import synapse.server
 
 logger = logging.getLogger(__name__)
 
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
 
-MAX_EVENTS_BEHIND = 10000
 
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream",
-    ("user_id", "room_id", "data_type", "data"),  # str  # str  # str  # dict
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = TypeVar("StreamRow", bound=Tuple)
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * instance_name: the writer of the stream
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
+#
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
-    NAME = None  # The name of the stream
-    ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
-    _LIMITED = True  # Whether the update function takes a limit
+    NAME = None  # type: str  # The name of the stream
+    # The type of the row. Used by the default impl of parse_row.
+    ROW_TYPE = None  # type: Any
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -123,102 +105,138 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
-        # The token from which we last asked for updates
-        self.last_token = self.current_token()
-
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
+    def __init__(
+        self,
+        local_instance_name: str,
+        current_token_function: Callable[[str], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
+
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
+
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
+
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
+        Args:
+            local_instance_name: The instance name of the current process
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
         """
-        self.upto_token = self.current_token()
+        self.local_instance_name = local_instance_name
+        self.current_token = current_token_function
+        self.update_function = update_function
+
+        # The token from which we last asked for updates
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token(self.local_instance_name)
 
-    @defer.inlineCallbacks
-    def get_updates(self):
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = yield self.get_updates_since(self.last_token)
+        current_token = self.current_token(self.local_instance_name)
+        updates, current_token, limited = await self.get_updates_since(
+            self.local_instance_name, self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
-    @defer.inlineCallbacks
-    def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, instance_name: str, from_token: Token, upto_token: Token
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        if from_token in ("NOW", "now"):
-            return [], self.upto_token
-
-        current_token = self.upto_token
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False
 
-        if self._LIMITED:
-            rows = yield self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        updates, upto_token, limited = await self.update_function(
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+        )
+        return updates, upto_token, limited
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = yield self.update_function(from_token, current_token)
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> UpdateFunction:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(instance_name, from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
 
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
+        return updates, upto_token, limited
 
-        return updates, current_token
+    return update_function
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
 
-        Returns:
-            int
-        """
-        raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    client = ReplicationGetStreamUpdates.make_client(hs)
 
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
+    async def update_function(
+        instance_name: str, from_token: int, upto_token: int, limit: int
+    ) -> StreamUpdateResult:
+        result = await client(
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+        )
+        return result["updates"], result["upto_token"], result["limited"]
+
+    return update_function
 
 
 class BackfillStream(Stream):
@@ -226,94 +244,170 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token
-        self.update_function = store.get_all_new_backfill_event_rows
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_backfill_token),
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token
-        self.update_function = presence_handler.get_all_presence_updates
+        if hs.config.worker_app is None:
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
+        )
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token
-        self.update_function = typing_handler.get_all_typing_updates
+        if hs.config.worker_app is None:
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
+        )
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id
-        self.update_function = store.get_all_updated_receipts
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_max_receipt_stream_id),
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            hs.get_instance_name(), self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
+        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token
-        self.update_function = store.get_all_updated_pushers_rows
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_pushers_stream_token),
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -321,120 +415,235 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
 
-        self.current_token = store.get_cache_stream_token
-        self.update_function = store.get_all_updated_caches
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
+        )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
 
-        super(CachesStream, self).__init__(hs)
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id
-        self.update_function = store.get_all_new_public_rooms
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_public_room_stream_id),
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_device_list_changes_for_remotes
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_device_stream_token),
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token
-        self.update_function = store.get_all_new_device_messages
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_to_device_stream_token),
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id
-        self.update_function = store.get_all_updated_tags
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_max_account_data_stream_id),
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
+        )
 
-        self.current_token = self.store.get_max_account_data_stream_id
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
 
-        super(AccountDataStream, self).__init__(hs)
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
+        )
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        global_results, room_results = yield self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
+            for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type, content)
-            for stream_id, user_id, account_data_type, content in global_results
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # We need to return a sorted list, so merge them together.
+        #
+        # Note: We order only by the stream ID to work around a bug where the
+        # same stream ID could appear in both `global_rows` and `room_rows`,
+        # leading to a comparison between the data tuples. The comparison could
+        # fail due to attempting to compare the `room_id` which results in a
+        # `TypeError` from comparing a `str` vs `None`.
+        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_group_stream_token),
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
-        self.current_token = store.get_group_stream_token
-        self.update_function = store.get_all_groups_changes
 
-        super(GroupServerStream, self).__init__(hs)
+class UserSignatureStream(Stream):
+    """A user has signed their own device with their user-signing key
+    """
+
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+
+    NAME = "user_signature"
+    ROW_TYPE = UserSignatureStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_device_stream_token),
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )