summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-19 10:19:25 +0000
committerErik Johnston <erik@matrix.org>2021-01-19 10:19:25 +0000
commitbed4fa29fd03696acbfee8c6952ca8692ff6d722 (patch)
tree1448960716d841caeb779e720d81b5db0dd8986c
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentQuote pip install with brackets to avoid shell interpretation. (#9151) (diff)
downloadsynapse-bed4fa29fd03696acbfee8c6952ca8692ff6d722.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
-rw-r--r--changelog.d/9104.feature1
-rw-r--r--changelog.d/9127.feature1
-rw-r--r--changelog.d/9128.bugfix1
-rw-r--r--changelog.d/9144.misc1
-rw-r--r--changelog.d/9145.bugfix1
-rw-r--r--changelog.d/9146.misc1
-rw-r--r--changelog.d/9151.doc1
-rw-r--r--docs/postgres.md2
-rw-r--r--docs/workers.md2
-rw-r--r--synapse/app/generic_worker.py15
-rw-r--r--synapse/config/workers.py18
-rw-r--r--synapse/handlers/account_data.py144
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/oidc_handler.py2
-rw-r--r--synapse/handlers/read_marker.py5
-rw-r--r--synapse/handlers/receipts.py27
-rw-r--r--synapse/handlers/room_member.py7
-rw-r--r--synapse/http/client.py2
-rw-r--r--synapse/http/matrixfederationclient.py2
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/http/account_data.py187
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/account_data.py40
-rw-r--r--synapse/replication/slave/storage/receipts.py35
-rw-r--r--synapse/replication/tcp/handler.py19
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py22
-rw-r--r--synapse/rest/client/v2_alpha/tags.py11
-rw-r--r--synapse/rest/synapse/client/pick_idp.py4
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/databases/main/__init__.py10
-rw-r--r--synapse/storage/databases/main/account_data.py107
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py92
-rw-r--r--synapse/storage/databases/main/events_worker.py8
-rw-r--r--synapse/storage/databases/main/receipts.py108
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres32
-rw-r--r--synapse/storage/databases/main/tags.py10
-rw-r--r--synapse/storage/util/id_generators.py84
-rw-r--r--tests/rest/client/v1/test_login.py146
-rw-r--r--tests/rest/client/v1/utils.py62
-rw-r--r--tests/server.py2
-rw-r--r--tests/storage/test_id_generators.py112
-rw-r--r--tests/test_utils/html_parsers.py53
-rw-r--r--tox.ini3
46 files changed, 1056 insertions, 371 deletions
diff --git a/changelog.d/9104.feature b/changelog.d/9104.feature
new file mode 100644
index 0000000000..1c4f88bce9
--- /dev/null
+++ b/changelog.d/9104.feature
@@ -0,0 +1 @@
+Add experimental support for moving off receipts and account data persistence off master.
diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9127.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix
new file mode 100644
index 0000000000..f87b9fb9aa
--- /dev/null
+++ b/changelog.d/9128.bugfix
@@ -0,0 +1 @@
+Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login.
diff --git a/changelog.d/9144.misc b/changelog.d/9144.misc
new file mode 100644
index 0000000000..38a506b170
--- /dev/null
+++ b/changelog.d/9144.misc
@@ -0,0 +1 @@
+Enforce that replication HTTP clients are called with keyword arguments only.
diff --git a/changelog.d/9145.bugfix b/changelog.d/9145.bugfix
new file mode 100644
index 0000000000..947cf1dc25
--- /dev/null
+++ b/changelog.d/9145.bugfix
@@ -0,0 +1 @@
+Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0.
diff --git a/changelog.d/9146.misc b/changelog.d/9146.misc
new file mode 100644
index 0000000000..7af29baa30
--- /dev/null
+++ b/changelog.d/9146.misc
@@ -0,0 +1 @@
+Fix the Python 3.5 + old dependencies build in CI.
diff --git a/changelog.d/9151.doc b/changelog.d/9151.doc
new file mode 100644
index 0000000000..7535748060
--- /dev/null
+++ b/changelog.d/9151.doc
@@ -0,0 +1 @@
+Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters.
diff --git a/docs/postgres.md b/docs/postgres.md
index c30cc1fd8c..680685d04e 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -18,7 +18,7 @@ connect to a postgres database.
     virtualenv](../INSTALL.md#installing-from-source), you can install
     the library with:
 
-        ~/synapse/env/bin/pip install matrix-synapse[postgres]
+        ~/synapse/env/bin/pip install "matrix-synapse[postgres]"
 
     (substituting the path to your virtualenv for `~/synapse/env`, if
     you used a different path). You will require the postgres
diff --git a/docs/workers.md b/docs/workers.md
index cc5090f224..d01683681f 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -59,7 +59,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
 virtualenv, these can be installed with:
 
 ```sh
-pip install matrix-synapse[redis]
+pip install "matrix-synapse[redis]"
 ```
 
 Note that these dependencies are included when synapse is installed with `pip
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index cb202bda44..e60988fa4a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
@@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
                     room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)
 
                     SendToDeviceRestServlet(self).register(resource)
 
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 364583f48b..f10e33f7b8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -56,6 +56,12 @@ class WriterLocations:
     to_device = attr.ib(
         default=["master"], type=List[str], converter=_instance_to_list_converter,
     )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -127,7 +133,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing", "to_device"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -141,6 +147,16 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `to_device` messages."
             )
 
+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple
 
+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
 
+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 18cd2b62f0..0e98db22b3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 5e5fda7b2f..ba686d74b2 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -85,7 +85,7 @@ class OidcHandler:
         self._token_generator = OidcSessionTokenGenerator(hs)
         self._providers = {
             p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
-        }
+        }  # type: Dict[str, OidcProvider]
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()
 
     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a9abdf42e0..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
-        hs.get_federation_registry().register_edu_handler(
-            "m.receipt", self._received_remote_receipt
-        )
+
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
+            hs.get_federation_registry().register_edu_handler(
+                "m.receipt", self._received_remote_receipt
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
+
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return
 
-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bea028b2bf..1f64c38988 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()
 
         self.member_linearizer = Linearizer(name="member")
         self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -254,7 +255,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     direct_rooms[key].append(new_room_id)
 
                     # Save back to user's m.direct account data
-                    await self.store.add_account_data_for_user(
+                    await self.account_data_handler.add_account_data_for_user(
                         user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
@@ -264,7 +265,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )
 
     async def update_membership(
         self,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index df498c8645..37ccf5ab98 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -724,7 +724,7 @@ class SimpleHttpClient:
                 read_body_with_max_size(response, output_stream, max_size)
             )
         except BodyExceededMaxSize:
-            SynapseError(
+            raise SynapseError(
                 502,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b7103d6541..19293bf673 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -996,7 +996,7 @@ class MatrixFederationHttpClient:
             logger.warning(
                 "{%s} [%s] %s", request.txn_id, request.destination, msg,
             )
-            SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1492ac922c..288727a566 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1f89249475..317796d5e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -132,6 +135,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
index e5b720bbca..9550b82998 100644
--- a/synapse/rest/synapse/client/pick_idp.py
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
         self._server_name = hs.hostname
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
-        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding="utf-8"
+        )
         idp = parse_string(request, "idp", required=False)
 
         # if we need to pick an IdP, do so
diff --git a/synapse/server.py b/synapse/server.py
index d4c235cda5..9cdda83aa1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.admin import AdminHandler
@@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_module_api(self) -> ModuleApi:
         return ModuleApi(self, self.get_auth_handler())
 
+    @cache_in_self
+    def get_account_data_handler(self) -> AccountDataHandler:
+        return AccountDataHandler(self)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index c4de07a0a8..ae561a2da3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,9 +160,13 @@ class DataStore(
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bad8260892..68896f34af 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_account_data = (
+                self._instance_name in hs.config.worker.writers.account_data
+            )
+
+            self._account_data_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[
+                    ("room_account_data", "instance_name", "stream_id"),
+                    ("room_tags_revisions", "instance_name", "stream_id"),
+                    ("account_data", "instance_name", "stream_id"),
+                ],
+                sequence_name="account_data_sequence",
+                writers=hs.config.worker.writers.account_data,
+            )
+        else:
+            self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._account_data_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+            else:
+                self._account_data_id_gen = SlavedIdTracker(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         super().__init__(database, db_conn, hs)
 
-    @abc.abstractmethod
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             )
         )
 
-
-class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn,
-            "room_account_data",
-            "stream_id",
-            extra_tables=[("room_tags_revisions", "stream_id")],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self) -> int:
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            The maximum stream ID.
-        """
-        return self._account_data_id_gen.get_current_token()
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type)
+                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 58d3f71e45..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="to_device",
                 instance_name=self._instance_name,
-                table="device_inbox",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("device_inbox", "instance_name", "stream_id")],
                 sequence_name="device_inbox_sequence",
                 writers=hs.config.worker.writers.to_device,
             )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e5c03cc609..1b657191a9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (rotate_to_stream_ordering,),
         )
 
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e0e57f0578 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_receipt_stream_id` which can be called in the initializer.
-    """
-
+class ReceiptsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_receipts = (
+                self._instance_name in hs.config.worker.writers.receipts
+            )
+
+            self._receipts_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[("receipts_linearized", "instance_name", "stream_id")],
+                sequence_name="receipts_sequence",
+                writers=hs.config.worker.writers.receipts,
+            )
+        else:
+            self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._receipts_id_gen = StreamIdGenerator(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+            else:
+                self._receipts_id_gen = SlavedIdTracker(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+
         super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    @abc.abstractmethod
     def get_max_receipt_stream_id(self):
         """Get the current max stream ID for receipts stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._receipts_id_gen.get_current_token()
 
     @cached()
     async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-
-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
+        assert self._can_write_to_receipts
+
         res = self.db_pool.simple_select_one_txn(
             txn,
             table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                     )
                     return None
 
-        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
-        txn.call_after(
-            self._invalidate_get_users_with_receipts_in_room,
-            room_id,
-            receipt_type,
-            user_id,
-        )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
-        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
         )
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
-        txn.call_after(
-            self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type),
-        )
-
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
         Automatically does conversion between linearized and graph
         representations.
         """
+        assert self._can_write_to_receipts
+
         if not event_ids:
             return None
 
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     async def insert_graph_receipt(
         self, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def insert_graph_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "data": json_encoder.encode(data),
             },
         )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+    )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
+SELECT setval('receipts_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 74da9c49f2..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
         return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
-
-class TagsStore(TagsWorkerStore):
     async def add_tag_to_room(
         self, user_id: str, room_id: str, tag: str, content: JsonDict
     ) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
 
         def remove_tag_txn(txn, next_id):
             sql = (
@@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
             room_id: The ID of the room.
             next_id: The the revision to advance to.
         """
+        assert self._can_write_to_account_data
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
                 # which stream_id ends up in the table, as long as it is higher
                 # than the id that the client has.
                 pass
+
+
+class TagsStore(TagsWorkerStore):
+    pass
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..39a3ab1162 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import deque
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
-        stream_name: A name for the stream.
+        stream_name: A name for the stream, for use in the `stream_positions`
+            table. (Does not need to be the same as the replication stream name)
         instance_name: The name of this instance.
-        table: Database table associated with stream.
-        instance_column: Column that stores the row's writer's instance name
-        id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Tuple of table
+            name, column name that stores the writer's instance name, and
+            column name that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
         writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
-        table: str,
-        instance_column: str,
-        id_column: str,
+        tables: List[Tuple[str, str, str]],
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
-        self._sequence_gen.check_consistency(
-            db_conn, table=table, id_column=id_column, positive=positive
-        )
+        for table, _, id_column in tables:
+            self._sequence_gen.check_consistency(
+                db_conn, table=table, id_column=id_column, positive=positive
+            )
 
         # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, table, instance_column, id_column)
+        self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, table: str, instance_column: str, id_column: str
+        self, db_conn, tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
-            sql = """
-                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
-                FROM %(table)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "agg": "MAX" if self._positive else "-MIN",
-            }
-            cur.execute(sql)
-            (stream_id,) = cur.fetchone()
-            self._persisted_upto_position = stream_id
+            max_stream_id = 1
+            for table, _, id_column in tables:
+                sql = """
+                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    FROM %(table)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "agg": "MAX" if self._positive else "-MIN",
+                }
+                cur.execute(sql)
+                (stream_id,) = cur.fetchone()
+
+                max_stream_id = max(max_stream_id, stream_id)
+
+            self._persisted_upto_position = max_stream_id
         else:
             # If we have a min_stream_id then we pull out everything greater
             # than it from the DB so that we can prefill
@@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
             # stream positions table before restart (or the stream position
             # table otherwise got out of date).
 
-            sql = """
-                SELECT %(instance)s, %(id)s FROM %(table)s
-                WHERE ? %(cmp)s %(id)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "instance": instance_column,
-                "cmp": "<=" if self._positive else ">=",
-            }
-            cur.execute(sql, (min_stream_id * self._return_factor,))
-
             self._persisted_upto_position = min_stream_id
 
+            rows = []
+            for table, instance_column, id_column in tables:
+                sql = """
+                    SELECT %(instance)s, %(id)s FROM %(table)s
+                    WHERE ? %(cmp)s %(id)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "instance": instance_column,
+                    "cmp": "<=" if self._positive else ">=",
+                }
+                cur.execute(sql, (min_stream_id * self._return_factor,))
+
+                rows.extend(cur)
+
+            # Sort so that we handle rows in order for each instance.
+            rows.sort()
+
             with self._lock:
-                for (instance, stream_id,) in cur:
+                for (instance, stream_id,) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 73a009efd1..2d25490374 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,9 +15,8 @@
 
 import time
 import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from urllib.parse import parse_qs, urlencode, urlparse
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
 
 from mock import Mock
 
@@ -38,6 +37,7 @@ from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
@@ -69,6 +69,12 @@ TEST_SAML_METADATA = """
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
 
+# a (valid) url with some annoying characters in.  %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             },
         }
 
+        # default OIDC provider
         config["oidc_config"] = TEST_OIDC_CONFIG
 
+        # additional OIDC providers
+        config["oidc_providers"] = [
+            {
+                "idp_id": "idp1",
+                "idp_name": "IDP1",
+                "discover": False,
+                "issuer": "https://issuer1",
+                "client_id": "test-client-id",
+                "client_secret": "test-client-secret",
+                "scopes": ["profile"],
+                "authorization_endpoint": "https://issuer1/auth",
+                "token_endpoint": "https://issuer1/token",
+                "userinfo_endpoint": "https://issuer1/userinfo",
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.sub }}"}
+                },
+            }
+        ]
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
         d = super().create_resource_dict()
         d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
         return d
 
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
-        client_redirect_url = "https://x?<abc>"
-
         # first hit the redirect url, which should redirect to our idp picker
         channel = self.make_request(
             "GET",
-            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
         )
         self.assertEqual(channel.code, 302, channel.result)
         uri = channel.headers.getRawHeaders("Location")[0]
@@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
-        class FormPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                # the values of the hidden inputs: map from name to value
-                self.hiddens = {}  # type: Dict[str, Optional[str]]
-
-                # the values of the radio buttons
-                self.radios = []  # type: List[Optional[str]]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "input":
-                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
-                        self.radios.append(attr_dict["value"])
-                    elif attr_dict["type"] == "hidden":
-                        input_name = attr_dict["name"]
-                        assert input_name
-                        self.hiddens[input_name] = attr_dict["value"]
-
-            def error(_, message):
-                self.fail(message)
-
-        p = FormPageParser()
+        p = TestHtmlParser()
         p.feed(channel.result["body"].decode("utf-8"))
         p.close()
 
-        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
 
-        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
-        client_redirect_url = "https://x?<abc>"
 
         channel = self.make_request(
             "GET",
-            "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=cas",
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         service_uri = cas_uri_params["service"][0]
         _, service_uri_query = service_uri.split("?", 1)
         service_uri_params = urllib.parse.parse_qs(service_uri_query)
-        self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+        self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
 
     def test_multi_sso_redirect_to_saml(self):
         """If SAML is chosen, should redirect to the SAML server"""
-        client_redirect_url = "https://x?<abc>"
-
         channel = self.make_request(
             "GET",
             "/_synapse/client/pick_idp?redirectUrl="
-            + client_redirect_url
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # the RelayState is used to carry the client redirect url
         saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
         relay_state_param = saml_uri_params["RelayState"][0]
-        self.assertEqual(relay_state_param, client_redirect_url)
+        self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
 
-    def test_multi_sso_redirect_to_oidc(self):
+    def test_login_via_oidc(self):
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
-        client_redirect_url = "https://x?<abc>"
 
+        # pick the default OIDC provider
         channel = self.make_request(
             "GET",
             "/_synapse/client/pick_idp?redirectUrl="
-            + client_redirect_url
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
         self.assertEqual(
             self._get_value_from_macaroon(macaroon, "client_redirect_url"),
-            client_redirect_url,
+            TEST_CLIENT_REDIRECT_URL,
         )
 
+        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertTrue(
+            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+        )
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+
+        # ... which should contain our redirect link
+        self.assertEqual(len(p.links), 1)
+        path, query = p.links[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        login_token = params[2][1]
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
     def test_multi_sso_redirect_to_unknown(self):
         """An unknown IdP should cause a 400"""
         channel = self.make_request(
@@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
         # whitelist this client URI so we redirect straight to it rather than
         # serving a confirmation page
-        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+        config["sso"] = {"client_whitelist": ["https://x"]}
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
@@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
     def test_username_picker(self):
         """Test the happy path of a username picker flow."""
-        client_redirect_url = "https://whitelisted.client"
 
         # do the start of the login flow
         channel = self.helper.auth_via_oidc(
-            {"sub": "tester", "displayname": "Jonny"}, client_redirect_url
+            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
         )
 
         # that should redirect to the username picker
@@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         session = username_mapping_sessions[session_id]
         self.assertEqual(session.remote_user_id, "tester")
         self.assertEqual(session.display_name, "Jonny")
-        self.assertEqual(session.client_redirect_url, client_redirect_url)
+        self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
 
         # the expiry time should be about 15 minutes away
         expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
@@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
-        # ensure that the returned location starts with the requested redirect URL
-        self.assertEqual(
-            location_headers[0][: len(client_redirect_url)], client_redirect_url
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
         )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
 
         # fish the login token out of the returned redirect uri
-        parts = urlparse(location_headers[0])
-        query = parse_qs(parts.query)
-        login_token = query["loginToken"][0]
+        login_token = params[2][1]
 
         # finally, submit the matrix login token to the login API, which gives us our
         # matrix access token, mxid, and device id.
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c6647dbe08..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -20,8 +20,7 @@ import json
 import re
 import time
 import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
+from typing import Any, Dict, Mapping, MutableMapping, Optional
 
 from mock import patch
 
@@ -35,6 +34,7 @@ from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -440,10 +440,36 @@ class RestHelper:
         # param that synapse passes to the IdP via query params, as well as the cookie
         # that synapse passes to the client.
 
-        oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
         assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
             "unexpected SSO URI " + oauth_uri_path
         )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+               from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+               sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
@@ -456,9 +482,9 @@ class RestHelper:
         expected_requests = [
             # first we get a hit to the token endpoint, which we tell to return
             # a dummy OIDC access token
-            ("https://issuer.test/token", {"access_token": "TEST"}),
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", user_info_dict),
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):
@@ -542,25 +568,7 @@ class RestHelper:
         channel.extract_cookies(cookies)
 
         # parse the confirmation page to fish out the link.
-        class ConfirmationPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                self.links = []  # type: List[str]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "a":
-                    href = attr_dict["href"]
-                    if href:
-                        self.links.append(href)
-
-            def error(_, message):
-                raise AssertionError(message)
-
-        p = ConfirmationPageParser()
+        p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
         assert len(p.links) == 1, "not exactly one link in confirmation page"
@@ -570,6 +578,8 @@ class RestHelper:
 
 # an 'oidc_config' suitable for login_via_oidc.
 TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = {
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
     "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": "https://issuer.test/token",
-    "userinfo_endpoint": "https://issuer.test/userinfo",
+    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
 }
diff --git a/tests/server.py b/tests/server.py
index 5a1b66270f..5a85d5fe7f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -74,7 +74,7 @@ class FakeChannel:
         return int(self.result["code"])
 
     @property
-    def headers(self):
+    def headers(self) -> Headers:
         if not self.result:
             raise Exception("No result yet.")
         h = Headers()
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..3e2fd4da01 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
             )
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
                 positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar1 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+        txn.execute(
+            """
+            CREATE TABLE foobar2 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                stream_name="test_stream",
+                instance_name=instance_name,
+                tables=[
+                    ("foobar1", "instance_name", "stream_id"),
+                    ("foobar2", "instance_name", "stream_id"),
+                ],
+                sequence_name="foobar_seq",
+                writers=writers,
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def _insert_rows(
+        self,
+        table: str,
+        instance_name: str,
+        number: int,
+        update_stream_table: bool = True,
+    ):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+                    (instance_name,),
+                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                        """,
+                        (instance_name,),
+                    )
+
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def test_load_existing_stream(self):
+        """Test creating ID gens with multiple tables that have rows from after
+        the position in `stream_positions` table.
+        """
+        self._insert_rows("foobar1", "first", 3)
+        self._insert_rows("foobar2", "second", 3)
+        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+    """A generic HTML page parser which extracts useful things from the HTML"""
+
+    def __init__(self):
+        super().__init__()
+
+        # a list of links found in the doc
+        self.links = []  # type: List[str]
+
+        # the values of any hidden <input>s: map from name to value
+        self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+        # the values of any radio buttons: map from name to list of values
+        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+
+    def handle_starttag(
+        self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+    ) -> None:
+        attr_dict = dict(attrs)
+        if tag == "a":
+            href = attr_dict["href"]
+            if href:
+                self.links.append(href)
+        elif tag == "input":
+            input_name = attr_dict.get("name")
+            if attr_dict["type"] == "radio":
+                assert input_name
+                self.radios.setdefault(input_name, []).append(attr_dict["value"])
+            elif attr_dict["type"] == "hidden":
+                assert input_name
+                self.hiddens[input_name] = attr_dict["value"]
+
+    def error(_, message):
+        raise AssertionError(message)
diff --git a/tox.ini b/tox.ini
index 297136fcc5..5210e7b860 100644
--- a/tox.ini
+++ b/tox.ini
@@ -103,6 +103,9 @@ usedevelop=true
 [testenv:py35-old]
 skip_install=True
 deps =
+    # Ensure a version of setuptools that supports Python 3.5 is installed.
+    setuptools < 51.0.0
+
     # Old automat version for Twisted
     Automat == 0.3.0