summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7384.bugfix1
-rw-r--r--changelog.d/7457.feature1
-rw-r--r--changelog.d/7465.bugfix1
-rw-r--r--changelog.d/7491.misc1
-rw-r--r--changelog.d/7505.misc1
-rw-r--r--changelog.d/7511.bugfix1
-rw-r--r--changelog.d/7513.misc1
-rw-r--r--changelog.d/7514.doc1
-rw-r--r--changelog.d/7516.misc1
-rw-r--r--changelog.d/7518.misc1
-rw-r--r--docs/reverse_proxy.md146
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/appservice/__init__.py2
-rw-r--r--synapse/event_auth.py78
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/oidc_handler.py76
-rw-r--r--synapse/handlers/room_member.py284
-rw-r--r--synapse/handlers/room_member_worker.py28
-rw-r--r--synapse/replication/slave/storage/events.py89
-rw-r--r--synapse/replication/slave/storage/push_rule.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py68
-rw-r--r--synapse/rest/client/v1/login.py31
-rw-r--r--synapse/rest/client/v2_alpha/auth.py19
-rw-r--r--synapse/storage/data_stores/main/__init__.py17
-rw-r--r--synapse/storage/data_stores/main/account_data.py62
-rw-r--r--synapse/storage/data_stores/main/appservice.py4
-rw-r--r--synapse/storage/data_stores/main/cache.py102
-rw-r--r--synapse/storage/data_stores/main/events_worker.py35
-rw-r--r--synapse/storage/data_stores/main/group_server.py92
-rw-r--r--synapse/storage/data_stores/main/profile.py2
-rw-r--r--synapse/storage/data_stores/main/push_rule.py14
-rw-r--r--synapse/storage/data_stores/main/search.py96
-rw-r--r--synapse/storage/util/id_generators.py11
-rw-r--r--tests/handlers/test_oidc.py15
-rw-r--r--tests/replication/slave/storage/test_events.py16
-rw-r--r--tests/replication/tcp/streams/test_account_data.py117
-rw-r--r--tests/replication/tcp/test_commands.py4
-rw-r--r--tox.ini5
38 files changed, 931 insertions, 506 deletions
diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix
new file mode 100644
index 0000000000..f49c600173
--- /dev/null
+++ b/changelog.d/7384.bugfix
@@ -0,0 +1 @@
+Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
diff --git a/changelog.d/7457.feature b/changelog.d/7457.feature
new file mode 100644
index 0000000000..7ad767bf71
--- /dev/null
+++ b/changelog.d/7457.feature
@@ -0,0 +1 @@
+Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).
diff --git a/changelog.d/7465.bugfix b/changelog.d/7465.bugfix
new file mode 100644
index 0000000000..1cbe50caa5
--- /dev/null
+++ b/changelog.d/7465.bugfix
@@ -0,0 +1 @@
+Prevent rooms with 0 members or with invalid version strings from breaking group queries.
\ No newline at end of file
diff --git a/changelog.d/7491.misc b/changelog.d/7491.misc
new file mode 100644
index 0000000000..50eb226db7
--- /dev/null
+++ b/changelog.d/7491.misc
@@ -0,0 +1 @@
+Move event stream handling out of slave store.
diff --git a/changelog.d/7505.misc b/changelog.d/7505.misc
new file mode 100644
index 0000000000..26114a3744
--- /dev/null
+++ b/changelog.d/7505.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.event_auth`.
diff --git a/changelog.d/7511.bugfix b/changelog.d/7511.bugfix
new file mode 100644
index 0000000000..cf8bc69c6f
--- /dev/null
+++ b/changelog.d/7511.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug that broke the update remote profile background process.
diff --git a/changelog.d/7513.misc b/changelog.d/7513.misc
new file mode 100644
index 0000000000..2ea7373e29
--- /dev/null
+++ b/changelog.d/7513.misc
@@ -0,0 +1 @@
+Add type hints to room member handler.
diff --git a/changelog.d/7514.doc b/changelog.d/7514.doc
new file mode 100644
index 0000000000..981168c7e8
--- /dev/null
+++ b/changelog.d/7514.doc
@@ -0,0 +1 @@
+Improve the formatting of `reverse_proxy.md`.
diff --git a/changelog.d/7516.misc b/changelog.d/7516.misc
new file mode 100644
index 0000000000..94b0fd49b2
--- /dev/null
+++ b/changelog.d/7516.misc
@@ -0,0 +1 @@
+Add a worker store for search insertion, required for moving event persistence off master.
diff --git a/changelog.d/7518.misc b/changelog.d/7518.misc
new file mode 100644
index 0000000000..f6e143fe1c
--- /dev/null
+++ b/changelog.d/7518.misc
@@ -0,0 +1 @@
+Fix typing annotations in `tests.replication`.
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 82bd5d1cdf..cbb8269568 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -34,97 +34,107 @@ the reverse proxy and the homeserver.
 
 ### nginx
 
-        server {
-            listen 443 ssl;
-            listen [::]:443 ssl;
-            server_name matrix.example.com;
-
-            location /_matrix {
-                proxy_pass http://localhost:8008;
-                proxy_set_header X-Forwarded-For $remote_addr;
-                # Nginx by default only allows file uploads up to 1M in size
-                # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
-                client_max_body_size 10M;
-            }
-        }
-
-        server {
-            listen 8448 ssl default_server;
-            listen [::]:8448 ssl default_server;
-            server_name example.com;
-
-            location / {
-                proxy_pass http://localhost:8008;
-                proxy_set_header X-Forwarded-For $remote_addr;
-            }
-        }
-
-> **NOTE**: Do not add a `/` after the port in `proxy_pass`, otherwise nginx will
+```
+server {
+    listen 443 ssl;
+    listen [::]:443 ssl;
+    server_name matrix.example.com;
+
+    location /_matrix {
+        proxy_pass http://localhost:8008;
+        proxy_set_header X-Forwarded-For $remote_addr;
+        # Nginx by default only allows file uploads up to 1M in size
+        # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
+        client_max_body_size 10M;
+    }
+}
+
+server {
+    listen 8448 ssl default_server;
+    listen [::]:8448 ssl default_server;
+    server_name example.com;
+
+    location / {
+        proxy_pass http://localhost:8008;
+        proxy_set_header X-Forwarded-For $remote_addr;
+    }
+}
+```
+
+**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
 canonicalise/normalise the URI.
 
 ### Caddy 1
 
-        matrix.example.com {
-          proxy /_matrix http://localhost:8008 {
-            transparent
-          }
-        }
+```
+matrix.example.com {
+  proxy /_matrix http://localhost:8008 {
+    transparent
+  }
+}
 
-        example.com:8448 {
-          proxy / http://localhost:8008 {
-            transparent
-          }
-        }
+example.com:8448 {
+  proxy / http://localhost:8008 {
+    transparent
+  }
+}
+```
 
 ### Caddy 2
 
-        matrix.example.com {
-          reverse_proxy /_matrix/* http://localhost:8008
-        }
+```
+matrix.example.com {
+  reverse_proxy /_matrix/* http://localhost:8008
+}
 
-        example.com:8448 {
-          reverse_proxy http://localhost:8008
-        }
+example.com:8448 {
+  reverse_proxy http://localhost:8008
+}
+```
 
 ### Apache
 
-        <VirtualHost *:443>
-            SSLEngine on
-            ServerName matrix.example.com;
+```
+<VirtualHost *:443>
+    SSLEngine on
+    ServerName matrix.example.com;
 
-            AllowEncodedSlashes NoDecode
-            ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
-            ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
-        </VirtualHost>
+    AllowEncodedSlashes NoDecode
+    ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+    ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
 
-        <VirtualHost *:8448>
-            SSLEngine on
-            ServerName example.com;
+<VirtualHost *:8448>
+    SSLEngine on
+    ServerName example.com;
 
-            AllowEncodedSlashes NoDecode
-            ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
-            ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
-        </VirtualHost>
+    AllowEncodedSlashes NoDecode
+    ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+    ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
+```
 
-> **NOTE**: ensure the  `nocanon` options are included.
+**NOTE**: ensure the  `nocanon` options are included.
 
 ### HAProxy
 
-        frontend https
-          bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
+```
+frontend https
+  bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
 
-          # Matrix client traffic
-          acl matrix-host hdr(host) -i matrix.example.com
-          acl matrix-path path_beg /_matrix
+  # Matrix client traffic
+  acl matrix-host hdr(host) -i matrix.example.com
+  acl matrix-path path_beg /_matrix
 
-          use_backend matrix if matrix-host matrix-path
+  use_backend matrix if matrix-host matrix-path
 
-        frontend matrix-federation
-          bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
-          default_backend matrix
+frontend matrix-federation
+  bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
+  default_backend matrix
 
-        backend matrix
-          server matrix 127.0.0.1:8008
+backend matrix
+  server matrix 127.0.0.1:8008
+```
 
 ## Homeserver Configuration
 
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 2e3add7ac5..ab801108ca 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.search import SearchWorkerStore
 from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.types import ReadReceipt
@@ -451,6 +452,7 @@ class GenericWorkerSlavedStore(
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
+    SearchWorkerStore,
     BaseSlavedStore,
 ):
     def __init__(self, database, db_conn, hs):
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index aea3985a5f..1b13e84425 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -270,7 +270,7 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exlusive_user_regexes(self):
+    def get_exclusive_user_regexes(self):
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 5a5b568a95..c582355146 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -29,18 +29,19 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.types import UserID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
 
 def check(
     room_version_obj: RoomVersion,
-    event,
-    auth_events,
-    do_sig_check=True,
-    do_size_check=True,
-):
+    event: EventBase,
+    auth_events: StateMap[EventBase],
+    do_sig_check: bool = True,
+    do_size_check: bool = True,
+) -> None:
     """ Checks if this event is correctly authed.
 
     Args:
@@ -189,7 +190,7 @@ def check(
     logger.debug("Allowing! %s", event)
 
 
-def _check_size_limits(event):
+def _check_size_limits(event: EventBase) -> None:
     def too_big(field):
         raise EventSizeError("%s too large" % (field,))
 
@@ -207,13 +208,18 @@ def _check_size_limits(event):
         too_big("event")
 
 
-def _can_federate(event, auth_events):
+def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
+    # There should always be a creation event, but if not don't federate.
+    if not creation_event:
+        return False
 
     return creation_event.content.get("m.federate", True) is True
 
 
-def _is_membership_change_allowed(event, auth_events):
+def _is_membership_change_allowed(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     membership = event.content["membership"]
 
     # Check if this is the room creator joining:
@@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events):
         raise AuthError(500, "Unknown membership %s" % membership)
 
 
-def _check_event_sender_in_room(event, auth_events):
+def _check_event_sender_in_room(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     key = (EventTypes.Member, event.user_id)
     member_event = auth_events.get(key)
 
-    return _check_joined_room(member_event, event.user_id, event.room_id)
+    _check_joined_room(member_event, event.user_id, event.room_id)
 
 
-def _check_joined_room(member, user_id, room_id):
+def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
     if not member or member.membership != Membership.JOIN:
         raise AuthError(
             403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
         )
 
 
-def get_send_level(etype, state_key, power_levels_event):
+def get_send_level(
+    etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+) -> int:
     """Get the power level required to send an event of a given type
 
     The federation spec [1] refers to this as "Required Power Level".
@@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event):
     https://matrix.org/docs/spec/server_server/unstable.html#definitions
 
     Args:
-        etype (str): type of event
-        state_key (str|None): state_key of state event, or None if it is not
+        etype: type of event
+        state_key: state_key of state event, or None if it is not
             a state event.
-        power_levels_event (synapse.events.EventBase|None): power levels event
+        power_levels_event: power levels event
             in force at this point in the room
     Returns:
-        int: power level required to send this event.
+        power level required to send this event.
     """
 
     if power_levels_event:
@@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event):
     return int(send_level)
 
 
-def _can_send_event(event, auth_events):
+def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     power_levels_event = _get_power_level_event(auth_events)
 
     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -410,7 +420,9 @@ def _can_send_event(event, auth_events):
     return True
 
 
-def check_redaction(room_version_obj: RoomVersion, event, auth_events):
+def check_redaction(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> bool:
     """Check whether the event sender is allowed to redact the target event.
 
     Returns:
@@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events):
     raise AuthError(403, "You don't have permission to redact events")
 
 
-def _check_power_levels(room_version_obj, event, auth_events):
+def _check_power_levels(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> None:
     user_list = event.content.get("users", {})
     # Validate users
     for k, v in user_list.items():
@@ -473,7 +487,7 @@ def _check_power_levels(room_version_obj, event, auth_events):
         ("redact", None),
         ("kick", None),
         ("invite", None),
-    ]
+    ]  # type: List[Tuple[str, Optional[str]]]
 
     old_list = current_state.content.get("users", {})
     for user in set(list(old_list) + list(user_list)):
@@ -503,12 +517,12 @@ def _check_power_levels(room_version_obj, event, auth_events):
             new_loc = new_loc.get(dir, {})
 
         if level_to_check in old_loc:
-            old_level = int(old_loc[level_to_check])
+            old_level = int(old_loc[level_to_check])  # type: Optional[int]
         else:
             old_level = None
 
         if level_to_check in new_loc:
-            new_level = int(new_loc[level_to_check])
+            new_level = int(new_loc[level_to_check])  # type: Optional[int]
         else:
             new_level = None
 
@@ -534,21 +548,21 @@ def _check_power_levels(room_version_obj, event, auth_events):
             )
 
 
-def _get_power_level_event(auth_events):
+def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
     return auth_events.get((EventTypes.PowerLevels, ""))
 
 
-def get_user_power_level(user_id, auth_events):
+def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
     """Get a user's power level
 
     Args:
-        user_id (str): user's id to look up in power_levels
-        auth_events (dict[(str, str), synapse.events.EventBase]):
+        user_id: user's id to look up in power_levels
+        auth_events:
             state in force at this point in the room (or rather, a subset of
             it including at least the create event and power levels event.
 
     Returns:
-        int: the user's power level in this room.
+        the user's power level in this room.
     """
     power_level_event = _get_power_level_event(auth_events)
     if power_level_event:
@@ -574,7 +588,7 @@ def get_user_power_level(user_id, auth_events):
             return 0
 
 
-def _get_named_level(auth_events, name, default):
+def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
     power_level_event = _get_power_level_event(auth_events)
 
     if not power_level_event:
@@ -587,7 +601,7 @@ def _get_named_level(auth_events, name, default):
         return default
 
 
-def _verify_third_party_invite(event, auth_events):
+def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
     """
     Validates that the invite event is authorized by a previous third-party invite.
 
@@ -662,7 +676,7 @@ def get_public_keys(invite_event):
     return public_keys
 
 
-def auth_types_for_event(event) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 524281d2f1..75b39e878c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+        self._sso_enabled = (
+            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+        )
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 178f263439..4ba8c7fda5 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -311,7 +311,7 @@ class OidcHandler:
         ``ClientAuth`` to authenticate with the client with its ID and secret.
 
         Args:
-            code: The autorization code we got from the callback.
+            code: The authorization code we got from the callback.
 
         Returns:
             A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
         return UserInfo(claims)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> None:
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
-        It redirects the browser to the authorization endpoint with a few
+        It returns a redirect to the authorization endpoint with a few
         parameters:
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
           - ``state``: a random string
           - ``nonce``: a random string
 
-        In addition to redirecting the client, we are setting a cookie with
+        In addition generating a redirect URL, we are setting a cookie with
         a signed macaroon token containing the state, the nonce and the
         client_redirect_url params. Those are then checked when the client
         comes back from the provider.
 
-
         Args:
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
                 when everything is done
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            The redirect URL to the authorization endpoint.
+
         """
 
         state = generate_token()
         nonce = generate_token()
 
         cookie = self._generate_oidc_session_token(
-            state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+            state=state,
+            nonce=nonce,
+            client_redirect_url=client_redirect_url.decode(),
+            ui_auth_session_id=ui_auth_session_id,
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
 
         metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
-        uri = prepare_grant_uri(
+        return prepare_grant_uri(
             authorization_endpoint,
             client_id=self._client_auth.client_id,
             response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
             state=state,
             nonce=nonce,
         )
-        request.redirect(uri)
-        finish_request(request)
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:
 
         # Deserialize the session token and verify it.
         try:
-            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+            (
+                nonce,
+                client_redirect_url,
+                ui_auth_session_id,
+            ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._render_error(request, "invalid_session", str(e))
@@ -678,15 +691,21 @@ class OidcHandler:
             return
 
         # and finally complete the login
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url
-        )
+        if ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, ui_auth_session_id, request
+            )
+        else:
+            await self._auth_handler.complete_sso_login(
+                user_id, request, client_redirect_url
+            )
 
     def _generate_oidc_session_token(
         self,
         state: str,
         nonce: str,
         client_redirect_url: str,
+        ui_auth_session_id: Optional[str],
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
             nonce: The ``nonce`` parameter passed to the OIDC provider.
             client_redirect_url: The URL the client gave when it initiated the
                 flow.
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
 
@@ -718,12 +739,19 @@ class OidcHandler:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (client_redirect_url,)
         )
+        if ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+            )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
         return macaroon.serialize()
 
-    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+    def _verify_oidc_session_token(
+        self, session: str, state: str
+    ) -> Tuple[str, str, Optional[str]]:
         """Verifies and extract an OIDC session token.
 
         This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
             state: The state the OIDC provider gave back
 
         Returns:
-            The nonce and the client_redirect_url for this session
+            The nonce, client_redirect_url, and ui_auth_session_id for this session
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -744,17 +772,27 @@ class OidcHandler:
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
         v.satisfy_general(self._verify_expiry)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce` and `client_redirect_url` from the token
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
 
-        return nonce, client_redirect_url
+        return nonce, client_redirect_url, ui_auth_session_id
 
     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(prefix):
                 return caveat.caveat_id[len(prefix) :]
-        raise Exception("No %s caveat in macaroon" % (key,))
+        raise ValueError("No %s caveat in macaroon" % (key,))
 
     def _verify_expiry(self, caveat: str) -> bool:
         prefix = "time < "
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4ddeba4c97..e51e1c32fe 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,13 +17,16 @@
 
 import abc
 import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from six.moves import http_client
 
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -74,84 +77,84 @@ class RoomMemberHandler(object):
         self.base_handler = BaseHandler(hs)
 
     @abc.abstractmethod
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Try and join a room that this server is not in
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers that can be used
-                to join via.
-            room_id (str): Room that we are trying to join
-            user (UserID): User who is trying to join
-            content (dict): A dict that should be used as the content of the
-                join event.
-
-        Returns:
-            Deferred
+            requester
+            remote_room_hosts: List of servers that can be used to join via.
+            room_id: Room that we are trying to join
+            user: User who is trying to join
+            content: A dict that should be used as the content of the join event.
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Attempt to reject an invite for a room this server is not in. If we
         fail to do so we locally mark the invite as rejected.
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers to use to try and
-                reject invite
-            room_id (str)
-            target (UserID): The user rejecting the invite
-            content (dict): The content for the rejection event
+            requester
+            remote_room_hosts: List of servers to use to try and reject invite
+            room_id
+            target: The user rejecting the invite
+            content: The content for the rejection event
 
         Returns:
-            Deferred[dict]: A dictionary to be returned to the client, may
+            A dictionary to be returned to the client, may
             include event_id etc, or nothing if we locally rejected
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has joined the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     async def _local_membership_update(
         self,
-        requester,
-        target,
-        room_id,
-        membership,
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        membership: str,
         prev_event_ids: Collection[str],
-        txn_id=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        txn_id: Optional[str] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> EventBase:
         user_id = target.to_string()
 
         if content is None:
@@ -214,16 +217,13 @@ class RoomMemberHandler(object):
 
     async def copy_room_tags_and_direct_to_room(
         self, old_room_id, new_room_id, user_id
-    ):
+    ) -> None:
         """Copies the tags and direct room state from one room to another.
 
         Args:
-            old_room_id (str)
-            new_room_id (str)
-            user_id (str)
-
-        Returns:
-            Deferred[None]
+            old_room_id: The room ID of the old room.
+            new_room_id: The room ID of the new room.
+            user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
         user_account_data, _ = await self.store.get_account_data_for_user(user_id)
@@ -253,17 +253,17 @@ class RoomMemberHandler(object):
 
     async def update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         key = (room_id,)
 
         with (await self.member_linearizer.queue(key)):
@@ -284,17 +284,17 @@ class RoomMemberHandler(object):
 
     async def _update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -468,12 +468,11 @@ class RoomMemberHandler(object):
                 else:
                     # send the rejection to the inviter's HS.
                     remote_room_hosts = remote_room_hosts + [inviter.domain]
-                    res = await self._remote_reject_invite(
+                    return await self._remote_reject_invite(
                         requester, remote_room_hosts, room_id, target, content,
                     )
-                    return res
 
-        res = await self._local_membership_update(
+        return await self._local_membership_update(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -484,9 +483,10 @@ class RoomMemberHandler(object):
             content=content,
             require_consent=require_consent,
         )
-        return res
 
-    async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+    async def transfer_room_state_on_room_upgrade(
+        self, old_room_id: str, room_id: str
+    ) -> None:
         """Upon our server becoming aware of an upgraded room, either by upgrading a room
         ourselves or joining one, we can transfer over information from the previous room.
 
@@ -494,12 +494,8 @@ class RoomMemberHandler(object):
         well as migrating the room directory state.
 
         Args:
-            old_room_id (str): The ID of the old room
-
-            room_id (str): The ID of the new room
-
-        Returns:
-            Deferred
+            old_room_id: The ID of the old room
+            room_id: The ID of the new room
         """
         logger.info("Transferring room state from %s to %s", old_room_id, room_id)
 
@@ -526,17 +522,16 @@ class RoomMemberHandler(object):
             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
 
-    async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+    async def copy_user_state_on_room_upgrade(
+        self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Copy user-specific information when they join a new room when that new room is the
         result of a room upgrade
 
         Args:
-            old_room_id (str): The ID of upgraded room
-            new_room_id (str): The ID of the new room
-            user_ids (Iterable[str]): User IDs to copy state for
-
-        Returns:
-            Deferred
+            old_room_id: The ID of upgraded room
+            new_room_id: The ID of the new room
+            user_ids: User IDs to copy state for
         """
 
         logger.debug(
@@ -566,17 +561,23 @@ class RoomMemberHandler(object):
                 )
                 continue
 
-    async def send_membership_event(self, requester, event, context, ratelimit=True):
+    async def send_membership_event(
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+    ):
         """
         Change the membership status of a user in a room.
 
         Args:
-            requester (Requester): The local user who requested the membership
+            requester: The local user who requested the membership
                 event. If None, certain checks, like whether this homeserver can
                 act as the sender, will be skipped.
-            event (SynapseEvent): The membership event.
+            event: The membership event.
             context: The context of the event.
-            ratelimit (bool): Whether to rate limit this request.
+            ratelimit: Whether to rate limit this request.
         Raises:
             SynapseError if there was a problem changing the membership.
         """
@@ -636,7 +637,9 @@ class RoomMemberHandler(object):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids):
+    async def _can_guest_join(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -653,12 +656,14 @@ class RoomMemberHandler(object):
             and guest_access.content["guest_access"] == "can_join"
         )
 
-    async def lookup_room_alias(self, room_alias):
+    async def lookup_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Tuple[RoomID, List[str]]:
         """
         Get the room ID associated with a room alias.
 
         Args:
-            room_alias (RoomAlias): The alias to look up.
+            room_alias: The alias to look up.
         Returns:
             A tuple of:
                 The room ID as a RoomID object.
@@ -682,24 +687,25 @@ class RoomMemberHandler(object):
 
         return RoomID.from_string(room_id), servers
 
-    async def _get_inviter(self, user_id, room_id):
+    async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
         invite = await self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
             return UserID.from_string(invite.sender)
+        return None
 
     async def do_3pid_invite(
         self,
-        room_id,
-        inviter,
-        medium,
-        address,
-        id_server,
-        requester,
-        txn_id,
-        id_access_token=None,
-    ):
+        room_id: str,
+        inviter: UserID,
+        medium: str,
+        address: str,
+        id_server: str,
+        requester: Requester,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         if self.config.block_non_admin_invites:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
             if not is_requester_admin:
@@ -748,15 +754,15 @@ class RoomMemberHandler(object):
 
     async def _make_and_store_3pid_invite(
         self,
-        requester,
-        id_server,
-        medium,
-        address,
-        room_id,
-        user,
-        txn_id,
-        id_access_token=None,
-    ):
+        requester: Requester,
+        id_server: str,
+        medium: str,
+        address: str,
+        room_id: str,
+        user: UserID,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         room_state = await self.state_handler.get_current_state(room_id)
 
         inviter_display_name = ""
@@ -830,7 +836,9 @@ class RoomMemberHandler(object):
             txn_id=txn_id,
         )
 
-    async def _is_host_in_room(self, current_state_ids):
+    async def _is_host_in_room(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -852,7 +860,7 @@ class RoomMemberHandler(object):
 
         return False
 
-    async def _is_server_notice_room(self, room_id):
+    async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
         user_ids = await self.store.get_users_in_room(room_id)
@@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
-    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(
+        self, room_id: str, remote_room_hosts: List[str]
+    ) -> Optional[bool]:
         """
         Check if complexity of a remote room is too great.
 
         Args:
-            room_id (str)
-            remote_room_hosts (list[str])
+            room_id
+            remote_room_hosts
 
         Returns: bool of whether the complexity is too great, or None
             if unable to be fetched
@@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             return complexity["v1"] > max_complexity
         return None
 
-    async def _is_local_room_too_complex(self, room_id):
+    async def _is_local_room_too_complex(self, room_id: str) -> bool:
         """
         Check if the complexity of a local room is too great.
 
         Args:
-            room_id (str)
-
-        Returns: bool
+            room_id: The room ID to check for complexity.
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
         complexity = await self.store.get_room_complexity(room_id)
 
         return complexity["v1"] > max_complexity
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> None:
         """Implements RoomMemberHandler._remote_join
         """
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             )
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         fed_handler = self.federation_handler
@@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             await self.store.locally_reject_invite(target.to_string(), room_id)
             return {}
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return user_joined_room(self.distributor, target, room_id)
+        user_joined_room(self.distributor, target, room_id)
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return user_left_room(self.distributor, target, room_id)
+        user_left_room(self.distributor, target, room_id)
 
-    async def forget(self, user, room_id):
+    async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
 
         member = await self.state_handler.get_current_state(
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 0fc54349ab..5c776cc0be 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.room_member import RoomMemberHandler
@@ -22,6 +23,7 @@ from synapse.replication.http.membership import (
     ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
     ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
 )
+from synapse.types import Requester, UserID
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
         self._notify_change_client = ReplJoinedLeft.make_client(hs)
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Implements RoomMemberHandler._remote_join
         """
         if len(remote_room_hosts) == 0:
@@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         return ret
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         return await self._remote_reject_client(
@@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="joined"
         )
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="left"
         )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b313720a4b..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
-)
 from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
 from synapse.storage.data_stores.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -62,11 +56,6 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
-        )
-
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -92,81 +81,3 @@ class SlavedEventStore(
 
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
-
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 5d5816d7eb..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
 
-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
         )
 
-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
             for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
+        )
+
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index de7eca21f8..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
     """Implementation for /login/sso/redirect for the OIDC login flow."""
 
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
     def __init__(self, hs):
         self._oidc_handler = hs.get_oidc_handler()
 
-    async def on_GET(self, request):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e96..7bca1326d5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
 
         # SSO configuration.
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
         self._cas_enabled = hs.config.cas_enabled
         if self._cas_enabled:
             self._cas_handler = hs.get_cas_handler()
             self._cas_server_url = hs.config.cas_server_url
             self._cas_service_url = hs.config.cas_service_url
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+        self._oidc_enabled = hs.config.oidc_enabled
+        if self._oidc_enabled:
+            self._oidc_handler = hs.get_oidc_handler()
+            self._cas_server_url = hs.config.cas_server_url
+            self._cas_service_url = hs.config.cas_service_url
 
     async def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
                 )
 
             elif self._saml_enabled:
-                client_redirect_url = ""
+                client_redirect_url = b""
                 sso_redirect_url = self._saml_handler.handle_redirect_request(
                     client_redirect_url, session
                 )
 
+            elif self._oidc_enabled:
+                client_redirect_url = b""
+                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+                    request, client_redirect_url, session
+                )
+
             else:
                 raise SynapseError(400, "Homeserver not configured for SSO.")
 
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 5df9dce79d..4b4763c701 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
-    ChainedIdGenerator,
     IdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
@@ -125,19 +124,6 @@ class DataStore(
         self._clock = hs.get_clock()
         self.database_engine = database.engine
 
-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -164,9 +150,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index efbc06c796..7a1fe8cdd2 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
 
 
 def _make_exclusive_regex(services_cache):
-    # We precompie a regex constructed from all the regexes that the AS's
+    # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
         regex.pattern
         for service in services_cache
-        for regex in service.get_exlusive_user_regexes()
+        for regex in service.get_exclusive_user_regexes()
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 342a87a46b..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,8 +16,13 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Tuple
 
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
@@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "caches":
+        if stream_name == "events":
+            for row in rows:
+                self._process_event_stream_row(token, row)
+        elif stream_name == "backfill":
+            for row in rows:
+                self._invalidate_caches_for_event(
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
+                    backfilled=True,
+                )
+        elif stream_name == "caches":
             if self._cache_id_gen:
                 self._cache_id_gen.advance(instance_name, token)
 
@@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self._invalidate_caches_for_event(
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            self._curr_state_delta_stream_cache.entity_has_changed(
+                row.data.room_id, token
+            )
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def _invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
+        self._invalidate_get_event_cache(event_id)
+
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+        if not backfilled:
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
+
+        if etype == EventTypes.Member:
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+        if relates_to:
+            self.get_relations_for_event.invalidate_many((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
+
+    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        cache_func = getattr(self, cache_name, None)
+        if not cache_func:
+            return
+
+        cache_func.invalidate(keys)
+        await self.db.runInteraction(
+            "invalidate_cache_and_stream",
+            self._send_invalidation_to_replication,
+            cache_func.__name__,
+            keys,
+        )
+
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 970c31bd05..9130b74eb5 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
 from synapse.util.iterutils import batch_iter
@@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker_app is None:
+            # We are the process in charge of generating stream ids for events,
+            # so instantiate ID generators based on the database
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                extra_tables=[("local_invites", "stream_id")],
+            )
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            )
+        else:
+            # Another process is in charge of persisting events and generating
+            # stream IDs: rely on the replication streams to let us know which
+            # IDs we can process.
+            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+            self._backfill_id_gen = SlavedIdTracker(
+                db_conn, "events", "stream_ordering", step=-1
+            )
+
         self._get_event_cache = Cache(
             "*getEvent*",
             keylen=3,
@@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_received_ts(self, event_id):
         """Get received_ts (when it was persisted) for the event.
 
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c250..fb1361f1c1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_invited_users_in_group",
         )
 
-    def get_rooms_in_group(self, group_id, include_private=False):
+    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+        """Retrieve the rooms that belong to a given group. Does not return rooms that
+        lack members.
+
+        Args:
+            group_id: The ID of the group to query for rooms
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+            form of:
+
+            {
+              "room_id": "!a_room_id:example.com",  # The ID of the room
+              "is_public": False                    # Whether this is a public room or not
+            }
+        """
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
-        if not include_private:
-            keyvalues["is_public"] = True
+        def _get_rooms_in_group_txn(txn):
+            sql = """
+            SELECT room_id, is_public FROM group_rooms
+                WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
+            """
+            args = [group_id, False]
 
-        return self.db.simple_select_list(
-            table="group_rooms",
-            keyvalues=keyvalues,
-            retcols=("room_id", "is_public"),
-            desc="get_rooms_in_group",
-        )
+            if not include_private:
+                sql += " AND is_public = ?"
+                args += [True]
+
+            txn.execute(sql, args)
+
+            return [
+                {"room_id": room_id, "is_public": is_public}
+                for room_id, is_public in txn
+            ]
 
-    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+    def get_rooms_for_summary_by_category(
+        self, group_id: str, include_private: bool = False,
+    ):
         """Get the rooms and categories that should be included in a summary request
 
-        Returns ([rooms], [categories])
+        Args:
+            group_id: The ID of the group to query the summary for
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[Tuple[List, Dict]]: A tuple containing:
+
+                * A list of dictionaries with the keys:
+                    * "room_id": str, the room ID
+                    * "is_public": bool, whether the room is public
+                    * "category_id": str|None, the category ID if set, else None
+                    * "order": int, the sort order of rooms
+
+                * A dictionary with the key:
+                    * category_id (str): a dictionary with the keys:
+                        * "is_public": bool, whether the category is public
+                        * "profile": str, the category profile
+                        * "order": int, the sort order of rooms in this category
         """
 
         def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
                 SELECT room_id, is_public, category_id, room_order
                 FROM group_summary_rooms
                 WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
             """
 
             if not include_private:
                 sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
+                txn.execute(sql, (group_id, False, True))
             else:
-                txn.execute(sql, (group_id,))
+                txn.execute(sql, (group_id, False))
 
             rooms = [
                 {
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c1a..bfc9369f0b 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
         return self.db.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            values={
+            updatevalues={
                 "displayname": displayname,
                 "avatar_url": avatar_url,
                 "last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b3faafa0a4..ef8f40959f 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
 
 import abc
 import logging
+from typing import Union
 
 from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -64,6 +68,7 @@ class PushRulesWorkerStore(
     ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    EventsWorkerStore,
     SQLBaseStore,
 ):
     """This is an abstract base class where subclasses must implement
@@ -77,6 +82,15 @@ class PushRulesWorkerStore(
     def __init__(self, database: Database, db_conn, hs):
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker.worker_app is None:
+            self._push_rules_stream_id_gen = ChainedIdGenerator(
+                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+        else:
+            self._push_rules_stream_id_gen = SlavedIdTracker(
+                db_conn, "push_rules_stream", "stream_id"
+            )
+
         push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
             "push_rules_stream",
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ee75b92344..13f49d8060 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if not self.hs.config.enable_search:
+            return
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = (
+                (
+                    entry.event_id,
+                    entry.room_id,
+                    entry.key,
+                    entry.value,
+                    entry.stream_ordering,
+                    entry.origin_server_ts,
+                )
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = (
+                (entry.event_id, entry.room_id, entry.key, entry.value)
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
 
         return num_rows
 
-    def store_search_entries_txn(self, txn, entries):
-        """Add entries to the search table
-
-        Args:
-            txn (cursor):
-            entries (iterable[SearchEntry]):
-                entries to be added to the table
-        """
-        if not self.hs.config.enable_search:
-            return
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-
-            args = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    entry.value,
-                    entry.stream_ordering,
-                    entry.origin_server_ts,
-                )
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args = (
-                (entry.event_id, entry.room_id, entry.key, entry.value)
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 86d04ea9ac..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -166,6 +166,7 @@ class ChainedIdGenerator(object):
 
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
+        self._table = table
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
@@ -204,6 +205,16 @@ class ChainedIdGenerator(object):
 
             return self._current_max, self.chained_generator.get_current_token()
 
+    def advance(self, token: int):
+        """Stub implementation for advancing the token when receiving updates
+        over replication; raises an exception as this instance should be the
+        only source of updates.
+        """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r", self._table
+        )
+
 
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 61963aa90d..1bb25ab684 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @defer.inlineCallbacks
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
-        req = Mock(spec=["addCookie", "redirect", "finish"])
-        yield defer.ensureDeferred(
+        req = Mock(spec=["addCookie"])
+        url = yield defer.ensureDeferred(
             self.handler.handle_redirect_request(req, b"http://client/redirect")
         )
-        url = req.redirect.call_args[0][0]
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
 
@@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
         session = self.handler._generate_oidc_session_token(
-            state=state, nonce=nonce, client_redirect_url=client_redirect_url,
+            state=state,
+            nonce=nonce,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=None,
         )
         request.getCookie.return_value = session
 
@@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # Mismatching session
         session = self.handler._generate_oidc_session_token(
-            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+            state="state",
+            nonce="nonce",
+            client_redirect_url="http://client/redirect",
+            ui_auth_session_id=None,
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0fee8a71c4..1a88c7fb80 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
-from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.roommember import RoomsForUser
 
+from tests.server import FakeTransport
+
 from ._base import BaseSlavedStoreTestCase
 
 USER_ID = "@feeling:test"
@@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         # limit the replication rate
         repl_transport = self._server_transport
+        assert isinstance(repl_transport, FakeTransport)
         repl_transport.autoflush = False
 
         # build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         type="m.room.message",
         key=None,
         internal={},
-        state=None,
         depth=None,
         prev_events=[],
         auth_events=[],
@@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
 
         self.event_id += 1
-
-        if state is not None:
-            state_ids = {key: e.event_id for key, e in state.items()}
-            context = EventContext.with_state(
-                state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
-            )
-        else:
-            state_handler = self.hs.get_state_handler()
-            context = self.get_success(state_handler.compute_event_context(event))
+        state_handler = self.hs.get_state_handler()
+        context = self.get_success(state_handler.compute_event_context(event))
 
         self.master_store.add_push_actions_to_staging(
             event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+    _STREAM_UPDATE_TARGET_ROW_COUNT,
+    AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+    def test_update_function_room_account_data_limit(self):
+        """Test replication with many room account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(
+                store.add_account_data_to_room("test_user", "test_room", update, {})
+            )
+            updates.append(update)
+
+        # also one global update
+        self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertEqual(row.room_id, "test_room")
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.global")
+        self.assertIsNone(row.room_id)
+
+        self.assertEqual([], received_rows)
+
+    def test_update_function_global_account_data_limit(self):
+        """Test replication with many global account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(store.add_account_data_for_user("test_user", update, {}))
+            updates.append(update)
+
+        # also one per-room update
+        self.get_success(
+            store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+        )
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertIsNone(row.room_id)
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.per_room")
+        self.assertEqual(row.room_id, "test_room")
+
+        self.assertEqual([], received_rows)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 7ddfd0a733..60c10a441a 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
     def test_parse_rdata(self):
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
-        self.assertIsInstance(cmd, RdataCommand)
+        assert isinstance(cmd, RdataCommand)
         self.assertEqual(cmd.stream_name, "events")
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
     def test_parse_rdata_batch(self):
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
-        self.assertIsInstance(cmd, RdataCommand)
+        assert isinstance(cmd, RdataCommand)
         self.assertEqual(cmd.stream_name, "presence")
         self.assertEqual(cmd.instance_name, "master")
         self.assertIsNone(cmd.token)
diff --git a/tox.ini b/tox.ini
index 203c648008..3bb4d45e2a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -180,6 +180,7 @@ commands = mypy \
             synapse/api \
             synapse/appservice \
             synapse/config \
+            synapse/event_auth.py \
             synapse/events/spamcheck.py \
             synapse/federation \
             synapse/handlers/auth.py \
@@ -187,6 +188,8 @@ commands = mypy \
             synapse/handlers/directory.py \
             synapse/handlers/oidc_handler.py \
             synapse/handlers/presence.py \
+            synapse/handlers/room_member.py \
+            synapse/handlers/room_member_worker.py \
             synapse/handlers/saml_handler.py \
             synapse/handlers/sync.py \
             synapse/handlers/ui_auth \
@@ -204,7 +207,7 @@ commands = mypy \
             synapse/storage/util \
             synapse/streams \
             synapse/util/caches/stream_change_cache.py \
-            tests/replication/tcp/streams \
+            tests/replication \
             tests/test_utils \
             tests/rest/client/v2_alpha/test_auth.py \
             tests/util/test_stream_change_cache.py