Diffstat (limited to 'synapse/storage')
 synapse/storage/data_stores/main/__init__.py      |  17
 synapse/storage/data_stores/main/account_data.py  |  62
 synapse/storage/data_stores/main/appservice.py    |   4
 synapse/storage/data_stores/main/cache.py         | 102
 synapse/storage/data_stores/main/events_worker.py |  35
 synapse/storage/data_stores/main/group_server.py  |  92
 synapse/storage/data_stores/main/profile.py       |   2
 synapse/storage/data_stores/main/push_rule.py     |  14
 synapse/storage/data_stores/main/search.py        |  96
 synapse/storage/util/id_generators.py             |  11
 10 files changed, 333 insertions(+), 102 deletions(-)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 5df9dce79d..4b4763c701 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
-    ChainedIdGenerator,
     IdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
@@ -125,19 +124,6 @@ class DataStore(
         self._clock = hs.get_clock()
         self.database_engine = database.engine
 
-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -164,9 +150,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
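
The two event-stream generators removed above are re-created in events_worker.py further down, gated on whether the process is the event writer. For readers unfamiliar with the pattern, here is a minimal self-contained sketch of what a StreamIdGenerator-style class provides (an illustration of the idea, not Synapse's actual implementation):

```python
import threading
from collections import deque
from contextlib import contextmanager


class ToyStreamIdGenerator:
    """Illustration of the stream-id pattern; not Synapse's actual class."""

    def __init__(self, current_max: int = 0):
        self._lock = threading.Lock()
        self._current_max = current_max
        self._unfinished = deque()  # ids handed out but not yet persisted

    @contextmanager
    def get_next(self):
        with self._lock:
            self._current_max += 1
            next_id = self._current_max
            self._unfinished.append(next_id)
        try:
            yield next_id
        finally:
            with self._lock:
                self._unfinished.remove(next_id)

    def get_current_token(self) -> int:
        # Readers never observe gaps: if an earlier id is still being
        # persisted, report the position just before it.
        with self._lock:
            return self._unfinished[0] - 1 if self._unfinished else self._current_max


gen = ToyStreamIdGenerator()
with gen.get_next() as stream_id:
    assert gen.get_current_token() == stream_id - 1  # not yet visible
assert gen.get_current_token() == 1
```
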
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):
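
Splitting the old combined method in two lets the global and per-room account-data streams page through their updates independently. A hedged sketch of a consumer loop over one of the new coroutines (the batch size and loop body are illustrative, not part of this commit):

```python
BATCH_SIZE = 100  # illustrative


async def catch_up_global_account_data(store, from_token: int, upto_token: int):
    """Page through global account-data changes until upto_token is reached."""
    while from_token < upto_token:
        rows = await store.get_updated_global_account_data(
            from_token, upto_token, BATCH_SIZE
        )
        if not rows:
            break
        for stream_id, user_id, account_data_type in rows:
            ...  # feed each row to the replication stream
        from_token = rows[-1][0]  # resume after the last stream_id seen
```
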
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index efbc06c796..7a1fe8cdd2 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
 
 
 def _make_exclusive_regex(services_cache):
-    # We precompie a regex constructed from all the regexes that the AS's
+    # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
         regex.pattern
         for service in services_cache
-        for regex in service.get_exlusive_user_regexes()
+        for regex in service.get_exclusive_user_regexes()
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 342a87a46b..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,8 +16,13 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Tuple
 
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
@@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "caches":
+        if stream_name == "events":
+            for row in rows:
+                self._process_event_stream_row(token, row)
+        elif stream_name == "backfill":
+            for row in rows:
+                self._invalidate_caches_for_event(
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
+                    backfilled=True,
+                )
+        elif stream_name == "caches":
             if self._cache_id_gen:
                 self._cache_id_gen.advance(instance_name, token)
 
@@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self._invalidate_caches_for_event(
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            self._curr_state_delta_stream_cache.entity_has_changed(
+                row.data.room_id, token
+            )
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def _invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
+        self._invalidate_get_event_cache(event_id)
+
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+        if not backfilled:
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
+
+        if etype == EventTypes.Member:
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+        if relates_to:
+            self.get_relations_for_event.invalidate_many((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
+
+    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        cache_func = getattr(self, cache_name, None)
+        if not cache_func:
+            return
+
+        cache_func.invalidate(keys)
+        await self.db.runInteraction(
+            "invalidate_cache_and_stream",
+            self._send_invalidation_to_replication,
+            cache_func.__name__,
+            keys,
+        )
+
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
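
The new awaitable variant takes the cache name as a string rather than the cached function itself, resolving it with getattr(). A hedged usage sketch (the surrounding function is hypothetical; "is_user_erased" is just an example of a cached method):

```python
async def mark_user_erased(store, user_id: str) -> None:
    # ... update the erasure table here ...

    # Drop the local cache entry and stream the invalidation to workers.
    # Because the name is resolved with getattr() and missing attributes
    # return early, a misspelt cache name is silently a no-op.
    await store.invalidate_cache_and_stream("is_user_erased", (user_id,))
```
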
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 970c31bd05..9130b74eb5 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
 from synapse.util.iterutils import batch_iter
@@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker_app is None:
+            # We are the process in charge of generating stream ids for events,
+            # so instantiate ID generators based on the database
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                extra_tables=[("local_invites", "stream_id")],
+            )
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            )
+        else:
+            # Another process is in charge of persisting events and generating
+            # stream IDs: rely on the replication streams to let us know which
+            # IDs we can process.
+            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+            self._backfill_id_gen = SlavedIdTracker(
+                db_conn, "events", "stream_ordering", step=-1
+            )
+
         self._get_event_cache = Cache(
             "*getEvent*",
             keylen=3,
@@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_received_ts(self, event_id):
         """Get received_ts (when it was persisted) for the event.
 
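
Backfilled events receive ever more negative stream orderings, while the backfill replication stream carries the position as a positive token, hence the advance(-token) above. A toy tracker showing the step=-1 semantics (simplified; the real SlavedIdTracker loads its starting position from the database):

```python
class ToyIdTracker:
    """Simplified stand-in for SlavedIdTracker, for illustration only."""

    def __init__(self, current: int, step: int = 1):
        self._current = current
        self._step = step  # -1 for backfill: ids move towards -infinity

    def advance(self, token: int) -> None:
        # Move forward in stream order: max() for a forward stream,
        # min() for a reversed (step=-1) one.
        if self._step > 0:
            self._current = max(self._current, token)
        else:
            self._current = min(self._current, token)

    def get_current_token(self) -> int:
        return self._current


backfill = ToyIdTracker(current=0, step=-1)
backfill.advance(-5)  # replication delivered token 5; the caller negates it
assert backfill.get_current_token() == -5
```
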
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c250..fb1361f1c1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_invited_users_in_group",
         )
 
-    def get_rooms_in_group(self, group_id, include_private=False):
+    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+        """Retrieve the rooms that belong to a given group. Does not return rooms that
+        lack members.
+
+        Args:
+            group_id: The ID of the group to query for rooms
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+            form of:
+
+            {
+              "room_id": "!a_room_id:example.com",  # The ID of the room
+              "is_public": False                    # Whether this is a public room or not
+            }
+        """
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
-        if not include_private:
-            keyvalues["is_public"] = True
+        def _get_rooms_in_group_txn(txn):
+            sql = """
+            SELECT room_id, is_public FROM group_rooms
+                WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
+            """
+            args = [group_id, False]
 
-        return self.db.simple_select_list(
-            table="group_rooms",
-            keyvalues=keyvalues,
-            retcols=("room_id", "is_public"),
-            desc="get_rooms_in_group",
-        )
+            if not include_private:
+                sql += " AND is_public = ?"
+                args += [True]
+
+            txn.execute(sql, args)
+
+            return [
+                {"room_id": room_id, "is_public": is_public}
+                for room_id, is_public in txn
+            ]
 
-    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+    def get_rooms_for_summary_by_category(
+        self, group_id: str, include_private: bool = False,
+    ):
         """Get the rooms and categories that should be included in a summary request
 
-        Returns ([rooms], [categories])
+        Args:
+            group_id: The ID of the group to query the summary for
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[Tuple[List, Dict]]: A tuple containing:
+
+                * A list of dictionaries with the keys:
+                    * "room_id": str, the room ID
+                    * "is_public": bool, whether the room is public
+                    * "category_id": str|None, the category ID if set, else None
+                    * "order": int, the sort order of rooms
+
+                * A dictionary with the key:
+                    * category_id (str): a dictionary with the keys:
+                        * "is_public": bool, whether the category is public
+                        * "profile": str, the category profile
+                        * "order": int, the sort order of rooms in this category
         """
 
         def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
                 SELECT room_id, is_public, category_id, room_order
                 FROM group_summary_rooms
                 WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
             """
 
             if not include_private:
                 sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
+                txn.execute(sql, (group_id, False, True))
             else:
-                txn.execute(sql, (group_id,))
+                txn.execute(sql, (group_id, False))
 
             rooms = [
                 {
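
Note that both new queries bind False to the subquery's second placeholder, making the join condition (room_version <> '') = FALSE, i.e. it matches rows whose room_version is the empty string. A hedged usage sketch of the rewritten method (the calling function is hypothetical; Deferreds are awaitable inside Synapse's coroutines):

```python
async def print_group_rooms(store, group_id: str) -> None:
    # Private rooms are omitted unless include_private=True.
    rooms = await store.get_rooms_in_group(group_id, include_private=False)
    for room in rooms:
        visibility = "public" if room["is_public"] else "private"
        print("%s (%s)" % (room["room_id"], visibility))
```
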
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c1a..bfc9369f0b 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
         return self.db.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            values={
+            updatevalues={
                 "displayname": displayname,
                 "avatar_url": avatar_url,
                 "last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b3faafa0a4..ef8f40959f 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
 
 import abc
 import logging
+from typing import Union
 
 from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -64,6 +68,7 @@ class PushRulesWorkerStore(
     ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    EventsWorkerStore,
     SQLBaseStore,
 ):
     """This is an abstract base class where subclasses must implement
@@ -77,6 +82,15 @@ class PushRulesWorkerStore(
     def __init__(self, database: Database, db_conn, hs):
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker.worker_app is None:
+            self._push_rules_stream_id_gen = ChainedIdGenerator(
+                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+        else:
+            self._push_rules_stream_id_gen = SlavedIdTracker(
+                db_conn, "push_rules_stream", "stream_id"
+            )
+
         push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
             "push_rules_stream",
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ee75b92344..13f49d8060 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if not self.hs.config.enable_search:
+            return
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = (
+                (
+                    entry.event_id,
+                    entry.room_id,
+                    entry.key,
+                    entry.value,
+                    entry.stream_ordering,
+                    entry.origin_server_ts,
+                )
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = (
+                (entry.event_id, entry.room_id, entry.key, entry.value)
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
 
         return num_rows
 
-    def store_search_entries_txn(self, txn, entries):
-        """Add entries to the search table
-
-        Args:
-            txn (cursor):
-            entries (iterable[SearchEntry]):
-                entries to be added to the table
-        """
-        if not self.hs.config.enable_search:
-            return
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-
-            args = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    entry.value,
-                    entry.stream_ordering,
-                    entry.origin_server_ts,
-                )
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args = (
-                (entry.event_id, entry.room_id, entry.key, entry.value)
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
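
Moving store_search_entries_txn into a worker store lets processes that persist events index them without inheriting the background-update machinery. A hedged usage sketch (the wrapper function is illustrative; the SearchEntry fields match the attributes the txn function reads):

```python
async def index_event(store, event, stream_ordering: int) -> None:
    entry = SearchEntry(
        key="content.body",
        value=event.content["body"],
        event_id=event.event_id,
        room_id=event.room_id,
        stream_ordering=stream_ordering,
        origin_server_ts=event.origin_server_ts,
    )
    await store.db.runInteraction(
        "store_search_entries", store.store_search_entries_txn, [entry]
    )
```
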
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 86d04ea9ac..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -166,6 +166,7 @@ class ChainedIdGenerator(object):
 
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
+        self._table = table
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
@@ -204,6 +205,16 @@ class ChainedIdGenerator(object):
 
             return self._current_max, self.chained_generator.get_current_token()
 
+    def advance(self, token: int):
+        """Stub implementation for advancing the token when receiving updates
+        over replication; raises an exception as this instance should be the
+        only source of updates.
+        """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r", self._table
+        )
+
 
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.