 changelog.d/9104.feature | 1
 synapse/app/generic_worker.py | 15
 synapse/config/workers.py | 18
 synapse/handlers/account_data.py | 144
 synapse/handlers/read_marker.py | 5
 synapse/handlers/receipts.py | 27
 synapse/handlers/room_member.py | 7
 synapse/replication/http/__init__.py | 2
 synapse/replication/http/account_data.py | 187
 synapse/replication/slave/storage/_base.py | 10
 synapse/replication/slave/storage/account_data.py | 40
 synapse/replication/slave/storage/receipts.py | 35
 synapse/replication/tcp/handler.py | 19
 synapse/rest/client/v2_alpha/account_data.py | 22
 synapse/rest/client/v2_alpha/tags.py | 11
 synapse/server.py | 5
 synapse/storage/databases/main/__init__.py | 10
 synapse/storage/databases/main/account_data.py | 107
 synapse/storage/databases/main/deviceinbox.py | 4
 synapse/storage/databases/main/event_push_actions.py | 92
 synapse/storage/databases/main/events_worker.py | 8
 synapse/storage/databases/main/receipts.py | 108
 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql | 20
 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres | 32
 synapse/storage/databases/main/tags.py | 10
 synapse/storage/util/id_generators.py | 84
 tests/storage/test_id_generators.py | 112
 27 files changed, 855 insertions(+), 280 deletions(-)
diff --git a/changelog.d/9104.feature b/changelog.d/9104.feature
new file mode 100644
index 0000000000..1c4f88bce9
--- /dev/null
+++ b/changelog.d/9104.feature
@@ -0,0 +1 @@
+Add experimental support for moving receipts and account data persistence off master.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index cb202bda44..e60988fa4a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
@@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
                     room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)
 
                     SendToDeviceRestServlet(self).register(resource)
 
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 364583f48b..f10e33f7b8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -56,6 +56,12 @@ class WriterLocations:
     to_device = attr.ib(
         default=["master"], type=List[str], converter=_instance_to_list_converter,
     )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -127,7 +133,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing", "to_device"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -141,6 +147,16 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `to_device` messages."
             )
 
+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
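
A minimal homeserver.yaml sketch for using the new writer options, assuming a single worker named `account_data_worker` (the worker name, host and port below are illustrative, not taken from this change); each of these streams must name exactly one writer, and any writer other than "master" must also appear in `instance_map`:

    stream_writers:
        account_data: account_data_worker
        receipts: account_data_worker

    instance_map:
        account_data_worker:
            host: localhost   # where the worker's replication listener is reachable
            port: 8035        # illustrative port
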
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple
 
+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
 
+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add the account data for.
+            room_id: The room to add the account data to.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the account_data.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()
 
     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a9abdf42e0..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
-        hs.get_federation_registry().register_edu_handler(
-            "m.receipt", self._received_remote_receipt
-        )
+
+        # We only need to poke the federation sender explicitly if it's on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
+            hs.get_federation_registry().register_edu_handler(
+                "m.receipt", self._received_remote_receipt
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
+
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return
 
-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cb5a29bc7e..e001e418f9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     direct_rooms[key].append(new_room_id)
 
                     # Save back to user's m.direct account data
-                    await self.store.add_account_data_for_user(
+                    await self.account_data_handler.add_account_data_for_user(
                         user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
@@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )
 
     async def update_membership(
         self,
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
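
A short usage sketch for these endpoints, assuming the caller is a worker that is not the configured writer (the instance name and values are illustrative). The caller builds a client with `make_client` and awaits it with the path and payload arguments accepted by `_serialize_payload`; the writer receives the POST via `ReplicationRestResource` and runs `_handle_request`, which persists the data and returns the new stream ID:

    # Hypothetical caller on a non-writer worker; `hs` is the HomeServer.
    client = ReplicationUserAccountDataRestServlet.make_client(hs)

    response = await client(
        instance_name="account_data_worker",      # the configured account_data writer
        user_id="@alice:example.org",             # becomes a path argument
        account_data_type="m.ignored_user_list",  # becomes a path argument
        content={"ignored_users": {}},            # wrapped as {"content": ...} in the payload
    )
    max_stream_id = response["max_stream_id"]
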
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1f89249475..317796d5e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -132,6 +135,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
diff --git a/synapse/server.py b/synapse/server.py
index d4c235cda5..9cdda83aa1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.admin import AdminHandler
@@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_module_api(self) -> ModuleApi:
         return ModuleApi(self, self.get_auth_handler())
 
+    @cache_in_self
+    def get_account_data_handler(self) -> AccountDataHandler:
+        return AccountDataHandler(self)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index c4de07a0a8..ae561a2da3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,9 +160,13 @@ class DataStore(
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bad8260892..68896f34af 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_account_data = (
+                self._instance_name in hs.config.worker.writers.account_data
+            )
+
+            self._account_data_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[
+                    ("room_account_data", "instance_name", "stream_id"),
+                    ("room_tags_revisions", "instance_name", "stream_id"),
+                    ("account_data", "instance_name", "stream_id"),
+                ],
+                sequence_name="account_data_sequence",
+                writers=hs.config.worker.writers.account_data,
+            )
+        else:
+            self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._account_data_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+            else:
+                self._account_data_id_gen = SlavedIdTracker(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         super().__init__(database, db_conn, hs)
 
-    @abc.abstractmethod
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             )
         )
 
-
-class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn,
-            "room_account_data",
-            "stream_id",
-            extra_tables=[("room_tags_revisions", "stream_id")],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self) -> int:
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            The maximum stream ID.
-        """
-        return self._account_data_id_gen.get_current_token()
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type)
+                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 58d3f71e45..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="to_device",
                 instance_name=self._instance_name,
-                table="device_inbox",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("device_inbox", "instance_name", "stream_id")],
                 sequence_name="device_inbox_sequence",
                 writers=hs.config.worker.writers.to_device,
             )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e5c03cc609..1b657191a9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (rotate_to_stream_ordering,),
         )
 
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We do, however, keep a month's worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transaction
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # so we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_ordering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e0e57f0578 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_receipt_stream_id` which can be called in the initializer.
-    """
-
+class ReceiptsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_receipts = (
+                self._instance_name in hs.config.worker.writers.receipts
+            )
+
+            self._receipts_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="receipts",
+                instance_name=self._instance_name,
+                tables=[("receipts_linearized", "instance_name", "stream_id")],
+                sequence_name="receipts_sequence",
+                writers=hs.config.worker.writers.receipts,
+            )
+        else:
+            self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._receipts_id_gen = StreamIdGenerator(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+            else:
+                self._receipts_id_gen = SlavedIdTracker(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+
         super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    @abc.abstractmethod
     def get_max_receipt_stream_id(self):
         """Get the current max stream ID for receipts stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._receipts_id_gen.get_current_token()
 
     @cached()
     async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-
-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
+        assert self._can_write_to_receipts
+
         res = self.db_pool.simple_select_one_txn(
             txn,
             table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                     )
                     return None
 
-        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
-        txn.call_after(
-            self._invalidate_get_users_with_receipts_in_room,
-            room_id,
-            receipt_type,
-            user_id,
-        )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
-        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
         )
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
-        txn.call_after(
-            self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type),
-        )
-
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
         Automatically does conversion between linearized and graph
         representations.
         """
+        assert self._can_write_to_receipts
+
         if not event_ids:
             return None
 
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     async def insert_graph_receipt(
         self, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def insert_graph_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "data": json_encoder.encode(data),
             },
         )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
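+    # The write methods now live on ReceiptsWorkerStore so that they can be
+    # run on whichever instance is configured to write receipts.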
+    pass
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
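+-- Track which instance persisted each row, needed once these tables can be
+-- written to from a worker other than the master.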
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+    )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
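+-- Bring the new sequence up to date with the existing receipts.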
+SELECT setval('receipts_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 74da9c49f2..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
         return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
-
-class TagsStore(TagsWorkerStore):
     async def add_tag_to_room(
         self, user_id: str, room_id: str, tag: str, content: JsonDict
     ) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
 
         def remove_tag_txn(txn, next_id):
             sql = (
@@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
             room_id: The ID of the room.
             next_id: The revision to advance to.
         """
+        assert self._can_write_to_account_data
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
                 # which stream_id ends up in the table, as long as it is higher
                 # than the id that the client has.
                 pass
+
+
+class TagsStore(TagsWorkerStore):
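+    # The write methods now live on TagsWorkerStore so that they can be run
+    # on whichever instance is configured to write account data.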
+    pass
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..39a3ab1162 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import deque
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
-        stream_name: A name for the stream.
+        stream_name: A name for the stream, for use in the `stream_positions`
+            table. (Does not need to be the same as the replication stream name)
         instance_name: The name of this instance.
-        table: Database table associated with stream.
-        instance_column: Column that stores the row's writer's instance name
-        id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Tuple of table
+            name, column name that stores the writer's instance name, and
+            column name that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
         writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
-        table: str,
-        instance_column: str,
-        id_column: str,
+        tables: List[Tuple[str, str, str]],
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
-        self._sequence_gen.check_consistency(
-            db_conn, table=table, id_column=id_column, positive=positive
-        )
+        for table, _, id_column in tables:
+            self._sequence_gen.check_consistency(
+                db_conn, table=table, id_column=id_column, positive=positive
+            )
 
         # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, table, instance_column, id_column)
+        self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, table: str, instance_column: str, id_column: str
+        self, db_conn, tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
-            sql = """
-                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
-                FROM %(table)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "agg": "MAX" if self._positive else "-MIN",
-            }
-            cur.execute(sql)
-            (stream_id,) = cur.fetchone()
-            self._persisted_upto_position = stream_id
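+            # The tables all share one ID generator, so take the maximum
+            # stream ID seen across any of them.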
+            max_stream_id = 1
+            for table, _, id_column in tables:
+                sql = """
+                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    FROM %(table)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "agg": "MAX" if self._positive else "-MIN",
+                }
+                cur.execute(sql)
+                (stream_id,) = cur.fetchone()
+
+                max_stream_id = max(max_stream_id, stream_id)
+
+            self._persisted_upto_position = max_stream_id
         else:
             # If we have a min_stream_id then we pull out everything greater
             # than it from the DB so that we can prefill
@@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
             # stream positions table before restart (or the stream position
             # table otherwise got out of date).
 
-            sql = """
-                SELECT %(instance)s, %(id)s FROM %(table)s
-                WHERE ? %(cmp)s %(id)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "instance": instance_column,
-                "cmp": "<=" if self._positive else ">=",
-            }
-            cur.execute(sql, (min_stream_id * self._return_factor,))
-
             self._persisted_upto_position = min_stream_id
 
+            rows = []
+            for table, instance_column, id_column in tables:
+                sql = """
+                    SELECT %(instance)s, %(id)s FROM %(table)s
+                    WHERE ? %(cmp)s %(id)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "instance": instance_column,
+                    "cmp": "<=" if self._positive else ">=",
+                }
+                cur.execute(sql, (min_stream_id * self._return_factor,))
+
+                rows.extend(cur)
+
+            # Sort so that we handle rows in order for each instance.
+            rows.sort()
+
             with self._lock:
-                for (instance, stream_id,) in cur:
+                for (instance, stream_id,) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..3e2fd4da01 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
             )
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
                 positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar1 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+        txn.execute(
+            """
+            CREATE TABLE foobar2 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                stream_name="test_stream",
+                instance_name=instance_name,
+                tables=[
+                    ("foobar1", "instance_name", "stream_id"),
+                    ("foobar2", "instance_name", "stream_id"),
+                ],
+                sequence_name="foobar_seq",
+                writers=writers,
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def _insert_rows(
+        self,
+        table: str,
+        instance_name: str,
+        number: int,
+        update_stream_table: bool = True,
+    ):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+                    (instance_name,),
+                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                        """,
+                        (instance_name,),
+                    )
+
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def test_load_existing_stream(self):
+        """Test creating ID gens with multiple tables that have rows from after
+        the position in the `stream_positions` table.
+        """
+        self._insert_rows("foobar1", "first", 3)
+        self._insert_rows("foobar2", "second", 3)
+        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)