Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/databases/main/__init__.py                               |   2
-rw-r--r--  synapse/storage/databases/main/events_worker.py                          |  29
-rw-r--r--  synapse/storage/databases/main/purge_events.py                           |   1
-rw-r--r--  synapse/storage/databases/main/pusher.py                                 |  72
-rw-r--r--  synapse/storage/databases/main/registration.py                           | 341
-rw-r--r--  synapse/storage/databases/main/roommember.py                             |  16
-rw-r--r--  synapse/storage/databases/main/session.py                                | 145
-rw-r--r--  synapse/storage/databases/main/ui_auth.py                                |  43
-rw-r--r--  synapse/storage/databases/main/user_directory.py                         |   2
-rw-r--r--  synapse/storage/roommember.py                                            |  44
-rw-r--r--  synapse/storage/schema/__init__.py                                       |   2
-rw-r--r--  synapse/storage/schema/main/delta/62/02session_store.sql                 |  23
-rw-r--r--  synapse/storage/schema/main/delta/63/01create_registration_tokens.sql    |  23
-rw-r--r--  synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql |  20
14 files changed, 730 insertions(+), 33 deletions(-)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 01b918e12e..00a644e8f7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore
 from .room import RoomStore
 from .roommember import RoomMemberStore
 from .search import SearchStore
+from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
 from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
     ServerMetricsStore,
     EventForwardExtremitiesStore,
     LockStore,
+    SessionStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 375463e4e9..9501f00f3b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
         # We now look up if we're already fetching some of the events in the DB,
         # if so we wait for those lookups to finish instead of pulling the same
         # events out of the DB multiple times.
-        already_fetching: Dict[str, defer.Deferred] = {}
+        #
+        # Note: we might get the same `ObservableDeferred` back for multiple
+        # events we're already fetching, so we deduplicate the deferreds to
+        # avoid extraneous work (if we don't do this we can end up in an n^2
+        # mode, where we wait on the same Deferred N times and then try to
+        # merge the same dict into itself N times).
+        already_fetching_ids: Set[str] = set()
+        already_fetching_deferreds: Set[
+            ObservableDeferred[Dict[str, _EventCacheEntry]]
+        ] = set()
 
         for event_id in missing_events_ids:
             deferred = self._current_event_fetches.get(event_id)
             if deferred is not None:
                 # We're already pulling the event out of the DB. Add the deferred
                 # to the collection of deferreds to wait on.
-                already_fetching[event_id] = deferred.observe()
+                already_fetching_ids.add(event_id)
+                already_fetching_deferreds.add(deferred)
 
-        missing_events_ids.difference_update(already_fetching)
+        missing_events_ids.difference_update(already_fetching_ids)
 
         if missing_events_ids:
             log_ctx = current_context()
@@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
             with PreserveLoggingContext():
                 fetching_deferred.callback(missing_events)
 
-        if already_fetching:
+        if already_fetching_deferreds:
             # Wait for the other event requests to finish and add their results
             # to ours.
             results = await make_deferred_yieldable(
                 defer.gatherResults(
-                    already_fetching.values(),
+                    (d.observe() for d in already_fetching_deferreds),
                     consumeErrors=True,
                 )
             ).addErrback(unwrapFirstError)
 
             for result in results:
-                event_entry_map.update(result)
+                # We filter out events that we haven't asked for as we might get
+                # a *lot* of superfluous events back, and there is no point
+                # going through and inserting them all (which can take time).
+                event_entry_map.update(
+                    (event_id, entry)
+                    for event_id, entry in result.items()
+                    if event_id in already_fetching_ids
+                )
 
         if not allow_rejected:
             event_entry_map = {
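For illustration, here is a standalone sketch of the deduplication above; FakeFetch and dedupe are toy stand-ins, not Synapse APIs. The point is that a set of shared fetch handles collapses N lookups for the same in-flight batch into one wait and one merge:

    from typing import Dict, Set, Tuple

    class FakeFetch:
        """Toy stand-in for an ObservableDeferred covering a batch of events."""

    def dedupe(
        missing: Set[str], in_flight: Dict[str, FakeFetch]
    ) -> Tuple[Set[str], Set[FakeFetch]]:
        already_ids: Set[str] = set()
        already_fetches: Set[FakeFetch] = set()  # a set dedupes shared handles
        for event_id in missing:
            fetch = in_flight.get(event_id)
            if fetch is not None:
                already_ids.add(event_id)
                already_fetches.add(fetch)
        return already_ids, already_fetches

    # One batch fetch covering two events is waited on once, not twice:
    shared = FakeFetch()
    ids, fetches = dedupe({"$a", "$b", "$c"}, {"$a": shared, "$b": shared})
    assert ids == {"$a", "$b"} and len(fetches) == 1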
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65dac5..bccff5e5b9 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
+                self._invalidate_get_event_cache(event_id)
 
         logger.info("[purge] done")
 
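A toy model of why this extra invalidation matters (hypothetical cache shapes, not Synapse's real internals): have_seen_event is keyed on (room_id, event_id) while the get-event cache is keyed on event_id alone, so a purge must drop both or the purged event can still be served from the second cache:

    from typing import Dict, Set, Tuple

    class ToyEventCaches:
        def __init__(self) -> None:
            self.have_seen: Set[Tuple[str, str]] = set()  # (room_id, event_id)
            self.get_event: Dict[str, dict] = {}          # event_id -> event

        def purge(self, room_id: str, event_id: str) -> None:
            # Invalidate both caches; dropping only `have_seen` would leave a
            # stale copy retrievable via `get_event`.
            self.have_seen.discard((room_id, event_id))
            self.get_event.pop(event_id, None)

    caches = ToyEventCaches()
    caches.have_seen.add(("!r:x", "$e"))
    caches.get_event["$e"] = {"type": "m.room.message"}
    caches.purge("!r:x", "$e")
    assert "$e" not in caches.get_event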
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086d4..e47caa2125 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
             self._remove_stale_pushers,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "remove_deleted_email_pushers",
+            self._remove_deleted_email_pushers,
+        )
+
     def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
@@ -388,6 +393,73 @@ class PusherWorkerStore(SQLBaseStore):
 
         return number_deleted
 
+    async def _remove_deleted_email_pushers(
+        self, progress: dict, batch_size: int
+    ) -> int:
+        """A background update that deletes all pushers for deleted email addresses.
+
+        In previous versions of Synapse, deleting an email address did not also
+        delete the pushers for that address. This background update removes those
+        orphaned pushers to prevent unwanted emails. It should only need to run
+        once (when users upgrade to v1.42.0).
+
+        Args:
+            progress: dict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        last_pusher = progress.get("last_pusher", 0)
+
+        def _delete_pushers(txn) -> int:
+            sql = """
+                SELECT p.id, p.user_name, p.app_id, p.pushkey
+                FROM pushers AS p
+                    LEFT JOIN user_threepids AS t
+                        ON t.user_id = p.user_name
+                        AND t.medium = 'email'
+                        AND t.address = p.pushkey
+                WHERE t.user_id is NULL
+                    AND p.app_id = 'm.email'
+                    AND p.id > ?
+                ORDER BY p.id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_pusher, batch_size))
+
+            last = None
+            num_deleted = 0
+            for row in txn:
+                last = row[0]
+                num_deleted += 1
+                self.db_pool.simple_delete_txn(
+                    txn,
+                    "pushers",
+                    {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
+                )
+
+            if last is not None:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_deleted_email_pushers", {"last_pusher": last}
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deleted_email_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "remove_deleted_email_pushers"
+            )
+
+        return number_deleted
+
 
 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self) -> int:
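The delete above follows the usual keyset-pagination shape for background updates: resume from the last seen primary key, process at most batch_size rows, persist the new high-water mark, and finish when a batch comes back short. A minimal sketch with a hypothetical fetch_batch callback (not a Synapse API):

    from typing import Callable, List, Tuple

    def run_background_update(
        fetch_batch: Callable[[int, int], List[Tuple[int, str]]],
        batch_size: int,
    ) -> int:
        last_id = 0  # plays the role of progress["last_pusher"]
        total = 0
        while True:
            rows = fetch_batch(last_id, batch_size)  # WHERE id > ? ... LIMIT ?
            for row_id, _pushkey in rows:
                last_id = row_id                     # persisted as progress
                total += 1
            if len(rows) < batch_size:               # short batch => finished
                return total

    data = [(i, "pusher%d" % i) for i in range(1, 8)]
    fetch = lambda last, limit: [r for r in data if r[0] > last][:limit]
    assert run_background_update(fetch, batch_size=3) == 7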
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return user_id
 
-    def get_user_id_by_threepid_txn(self, txn, medium, address):
+    def get_user_id_by_threepid_txn(
+        self, txn, medium: str, address: str
+    ) -> Optional[str]:
         """Returns user id from threepid
 
         Args:
             txn (cursor):
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
+            medium: threepid medium e.g. email
+            address: threepid address e.g. me@example.com
 
         Returns:
-            str|None: user id or None if no user id/threepid mapping exists
+            user id, or None if no user id/threepid mapping exists
         """
         ret = self.db_pool.simple_select_one_txn(
             txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             return ret["user_id"]
         return None
 
-    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+    async def user_add_threepid(
+        self,
+        user_id: str,
+        medium: str,
+        address: str,
+        validated_at: int,
+        added_at: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id):
+    async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
         return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "user_get_threepids",
         )
 
-    async def user_delete_threepid(self, user_id, medium, address) -> None:
+    async def user_delete_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
         await self.db_pool.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
@@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="update_access_token_last_validated",
         )
 
+    async def registration_token_is_valid(self, token: str) -> bool:
+        """Checks if a token can be used to authenticate a registration.
+
+        Args:
+            token: The registration token to be checked
+        Returns:
+            True if the token is valid, False otherwise.
+        """
+        res = await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+        )
+
+        # Check if the token exists
+        if res is None:
+            return False
+
+        # Check if the token has expired
+        now = self._clock.time_msec()
+        if res["expiry_time"] and res["expiry_time"] < now:
+            return False
+
+        # Check if the token has been used up
+        if (
+            res["uses_allowed"]
+            and res["pending"] + res["completed"] >= res["uses_allowed"]
+        ):
+            return False
+
+        # Otherwise, the token is valid
+        return True
+
+    async def set_registration_token_pending(self, token: str) -> None:
+        """Increment the pending registrations counter for a token.
+
+        Args:
+            token: The registration token pending use
+        """
+
+        def _set_registration_token_pending_txn(txn):
+            pending = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"pending": pending + 1},
+            )
+
+        return await self.db_pool.runInteraction(
+            "set_registration_token_pending", _set_registration_token_pending_txn
+        )
+
+    async def use_registration_token(self, token: str) -> None:
+        """Complete a use of the given registration token.
+
+        The `pending` counter will be decremented, and the `completed`
+        counter will be incremented.
+
+        Args:
+            token: The registration token to be 'used'
+        """
+
+        def _use_registration_token_txn(txn):
+            # Normally, res is Optional[Dict[str, Any]].
+            # Override type because the return type is only optional if
+            # allow_none is True, and we don't want mypy throwing errors
+            # about None not being indexable.
+            res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )  # type: ignore
+
+            # Decrement pending and increment completed
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={
+                    "completed": res["completed"] + 1,
+                    "pending": res["pending"] - 1,
+                },
+            )
+
+        return await self.db_pool.runInteraction(
+            "use_registration_token", _use_registration_token_txn
+        )
+
+    async def get_registration_tokens(
+        self, valid: Optional[bool] = None
+    ) -> List[Dict[str, Any]]:
+        """List all registration tokens. Used by the admin API.
+
+        Args:
+            valid: If True, only valid tokens are returned.
+              If False, only invalid tokens are returned.
+              Default is None: return all tokens regardless of validity.
+
+        Returns:
+            A list of dicts, each containing details of a token.
+        """
+
+        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+            if valid is None:
+                # Return all tokens regardless of validity
+                txn.execute("SELECT * FROM registration_tokens")
+
+            elif valid:
+                # Select valid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+                    "AND (expiry_time > ? OR expiry_time IS NULL)"
+                )
+                txn.execute(sql, [now])
+
+            else:
+                # Select invalid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "uses_allowed <= pending + completed OR expiry_time <= ?"
+                )
+                txn.execute(sql, [now])
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "select_registration_tokens",
+            select_registration_tokens_txn,
+            self._clock.time_msec(),
+            valid,
+        )
+
+    async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+        """Get info about the given registration token. Used by the admin API.
+
+        Args:
+            token: The token to retrieve information about.
+
+        Returns:
+            A dict, or None if token doesn't exist.
+        """
+        return await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+            desc="get_one_registration_token",
+        )
+
+    async def generate_registration_token(
+        self, length: int, chars: str
+    ) -> Optional[str]:
+        """Generate a random registration token. Used by the admin API.
+
+        Args:
+            length: The length of the token to generate.
+            chars: A string of the characters allowed in the generated token.
+
+        Returns:
+            The generated token.
+
+        Raises:
+            SynapseError if a unique registration token could not be
+            generated after a few attempts.
+        """
+        # Make a few attempts at generating a unique token of the required
+        # length before failing.
+        for _i in range(3):
+            # Generate token
+            token = "".join(random.choices(chars, k=length))
+
+            # Check if the token already exists
+            existing_token = await self.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="token",
+                allow_none=True,
+                desc="check_if_registration_token_exists",
+            )
+
+            if existing_token is None:
+                # The generated token doesn't exist yet, return it
+                return token
+
+        raise SynapseError(
+            500,
+            "Unable to generate a unique registration token. Try again with a greater length",
+            Codes.UNKNOWN,
+        )
+
+    async def create_registration_token(
+        self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+    ) -> bool:
+        """Create a new registration token. Used by the admin API.
+
+        Args:
+            token: The token to create.
+            uses_allowed: The number of times the token can be used to complete
+              a registration before it becomes invalid. A value of None indicates
+              unlimited uses.
+            expiry_time: The latest time the token is valid. Given as the
+              number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+              None indicates that the token does not expire.
+
+        Returns:
+            Whether the row was inserted or not.
+        """
+
+        def _create_registration_token_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["token"],
+                allow_none=True,
+            )
+
+            if row is not None:
+                # Token already exists
+                return False
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                "registration_tokens",
+                values={
+                    "token": token,
+                    "uses_allowed": uses_allowed,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": expiry_time,
+                },
+            )
+
+            return True
+
+        return await self.db_pool.runInteraction(
+            "create_registration_token", _create_registration_token_txn
+        )
+
+    async def update_registration_token(
+        self, token: str, updatevalues: Dict[str, Optional[int]]
+    ) -> Optional[Dict[str, Any]]:
+        """Update a registration token. Used by the admin API.
+
+        Args:
+            token: The token to update.
+            updatevalues: A dict with the fields to update. E.g.:
+              `{"uses_allowed": 3}` to update just uses_allowed, or
+              `{"uses_allowed": 3, "expiry_time": None}` to update both.
+              This is passed straight to simple_update_one.
+
+        Returns:
+            A dict with all info about the token, or None if token doesn't exist.
+        """
+
+        def _update_registration_token_txn(txn):
+            try:
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    "registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues=updatevalues,
+                )
+            except StoreError:
+                # Update failed because token does not exist
+                return None
+
+            # Get all info about the token so it can be sent in the response
+            return self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=[
+                    "token",
+                    "uses_allowed",
+                    "pending",
+                    "completed",
+                    "expiry_time",
+                ],
+                allow_none=True,
+            )
+
+        return await self.db_pool.runInteraction(
+            "update_registration_token", _update_registration_token_txn
+        )
+
+    async def delete_registration_token(self, token: str) -> bool:
+        """Delete a registration token. Used by the admin API.
+
+        Args:
+            token: The token to delete.
+
+        Returns:
+            Whether the token was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                desc="delete_registration_token",
+            )
+        except StoreError:
+            # Deletion failed because token does not exist
+            return False
+
+        return True
+
     @cached()
     async def mark_access_token_as_used(self, token_id: int) -> None:
         """
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3d4..c58a4b8690 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+    async def get_invited_rooms_for_local_user(
+        self, user_id: str
+    ) -> List[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
         sql = """
-            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
             FROM local_current_membership AS c
             INNER JOIN events AS e USING (room_id, event_id)
+            INNER JOIN rooms AS r USING (room_id)
             WHERE
                 user_id = ?
                 AND %s
@@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
         txn.execute(sql, (user_id, *args))
-        results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+        results = [RoomsForUser(*r) for r in txn]
 
         return results
 
@@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             Returns the rooms the user is in currently, along with the stream
-            ordering of the most recent join for that user and room.
+            ordering of the most recent join for that user and room, and the
+            room version of the room.
         """
         return await self.db_pool.runInteraction(
             "get_rooms_for_user_with_stream_ordering",
@@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
-    async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+    async def get_rooms_for_user(
+        self, user_id: str, on_invalidate=None
+    ) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
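RoomsForUser(*r) builds each row positionally, so the attrs field order must match the SELECT column order exactly, which is why r.room_version was appended to both the query and the class (defined later in this diff). A runnable sketch of that contract, with illustrative values:

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class RoomsForUser:
        room_id: str
        sender: str
        membership: str
        event_id: str
        stream_ordering: int
        room_version_id: str

    # Column order: room_id, sender, membership, event_id, stream_ordering,
    # room_version (matching the SELECT above).
    row = ("!room:example.com", "@alice:example.com", "invite", "$ev1", 42, "6")
    entry = RoomsForUser(*row)
    assert entry.room_version_id == "6"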
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 0000000000..172f27d109
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+    """
+    A store for generic session data.
+
+    Each kind of session should use a unique session_type string, to keep
+    different kinds of session data separate.
+
+    Sessions are automatically removed when they expire.
+    """
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Create a background job for culling expired sessions.
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+    async def create_session(
+        self, session_type: str, value: JsonDict, expiry_ms: int
+    ) -> str:
+        """
+        Creates a new session of the given type.
+
+        Args:
+            session_type: The type for this session.
+            value: The value to store.
+            expiry_ms: How long the session should last before expiring,
+                in milliseconds.
+
+        Returns:
+            The newly created session ID.
+
+        Raises:
+            StoreError if a unique session ID cannot be generated.
+        """
+        # autogen a session ID and try to create it. We may clash, so just
+        # try a few times till one goes through, giving up eventually.
+        attempts = 0
+        while attempts < 5:
+            session_id = stringutils.random_string(24)
+
+            try:
+                await self.db_pool.simple_insert(
+                    table="sessions",
+                    values={
+                        "session_id": session_id,
+                        "session_type": session_type,
+                        "value": json_encoder.encode(value),
+                        "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+                    },
+                    desc="create_session",
+                )
+
+                return session_id
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+        raise StoreError(500, "Couldn't generate a session ID.")
+
+    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+        """
+        Retrieve data stored with create_session.
+
+        Args:
+            session_type: The type for this session.
+            session_id: The session ID returned from create_session.
+
+        Returns:
+            The JSON dictionary stored via create_session.
+
+        Raises:
+            StoreError if the session has expired or cannot be found.
+        """
+
+        def _get_session(
+            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+        ) -> JsonDict:
+            # This includes the expiry time since items are only periodically
+            # deleted, not upon expiry.
+            select_sql = """
+            SELECT value FROM sessions WHERE
+            session_type = ? AND session_id = ? AND expiry_time_ms > ?
+            """
+            txn.execute(select_sql, [session_type, session_id, ts])
+            row = txn.fetchone()
+
+            if not row:
+                raise StoreError(404, "No session")
+
+            return db_to_json(row[0])
+
+        return await self.db_pool.runInteraction(
+            "get_session",
+            _get_session,
+            session_type,
+            session_id,
+            self._clock.time_msec(),
+        )
+
+    @wrap_as_background_process("delete_expired_sessions")
+    async def _delete_expired_sessions(self) -> None:
+        """Remove sessions with expiry dates that have passed."""
+
+        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "delete_expired_sessions",
+            _delete_expired_sessions_txn,
+            self._clock.time_msec(),
+        )
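A hedged usage sketch of the store above, called from some async context; the session_type string and five-minute expiry are illustrative choices, not values from this changeset:

    async def remember_state(store) -> None:
        session_id = await store.create_session(
            session_type="example_pagination",  # illustrative type name
            value={"next_batch": "token123"},
            expiry_ms=5 * 60 * 1000,            # keep for five minutes
        )
        # Raises StoreError(404, "No session") once expired or if unknown.
        state = await store.get_session("example_pagination", session_id)
        assert state == {"next_batch": "token123"}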
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
             keyvalues={},
         )
 
+        # If a registration token was used, decrement the pending counter
+        # before deleting the session.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="ui_auth_sessions_credentials",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+            retcols=["result"],
+        )
+
+        # Get the tokens used, and how much each token's pending counter
+        # needs to be decremented by.
+        token_counts: Dict[str, int] = {}
+        for r in rows:
+            # If registration was successfully completed, the result of the
+            # registration token stage for that session will be True.
+            # If a token was used to authenticate, but registration was
+            # never completed, the result will be the token used.
+            token = db_to_json(r["result"])
+            if isinstance(token, str):
+                token_counts[token] = token_counts.get(token, 0) + 1
+
+        # Update the `pending` counters.
+        if len(token_counts) > 0:
+            token_rows = self.db_pool.simple_select_many_txn(
+                txn,
+                table="registration_tokens",
+                column="token",
+                iterable=list(token_counts.keys()),
+                keyvalues={},
+                retcols=["token", "pending"],
+            )
+            for token_row in token_rows:
+                token = token_row["token"]
+                new_pending = token_row["pending"] - token_counts[token]
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    table="registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues={"pending": new_pending},
+                )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
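A standalone sketch of the pending-counter arithmetic above, using json.loads as a stand-in for db_to_json: only rows whose stored result is still a token string (a use that never completed registration) count toward the decrement:

    import json
    from typing import Dict

    rows = [
        {"result": json.dumps("secret-token")},  # authenticated, never completed
        {"result": json.dumps("secret-token")},  # same token, second session
        {"result": json.dumps(True)},            # registration completed
    ]
    token_counts: Dict[str, int] = {}
    for r in rows:
        token = json.loads(r["result"])
        if isinstance(token, str):
            token_counts[token] = token_counts.get(token, 0) + 1

    # "secret-token" has its pending counter decremented by 2; the completed
    # registration is ignored here (its counters were settled at use time).
    assert token_counts == {"secret-token": 2}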
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69ac7..65dde67ae9 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         return False
 
     async def update_profile_in_user_dir(
-        self, user_id: str, display_name: str, avatar_url: str
+        self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
     ) -> None:
         """
         Update or add a user's profile in the user directory.
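A brief usage sketch of the loosened signature; store stands for any object exposing the method above, and passing None for a missing avatar is an assumption about how callers use the Optional fields:

    async def refresh_directory_entry(store) -> None:
        # A user with no avatar no longer needs a placeholder empty string;
        # None is now an accepted value for either field.
        await store.update_profile_in_user_dir(
            "@alice:example.com", display_name="Alice", avatar_url=None
        )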
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c34fbf21bc..2500381b7b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,25 +14,41 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
+from typing import List, Optional, Tuple
+
+import attr
+
+from synapse.types import PersistedEventPosition
 
 logger = logging.getLogger(__name__)
 
 
-RoomsForUser = namedtuple(
-    "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class RoomsForUser:
+    room_id: str
+    sender: str
+    membership: str
+    event_id: str
+    stream_ordering: int
+    room_version_id: str
+
+
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class GetRoomsForUserWithStreamOrdering:
+    room_id: str
+    event_pos: PersistedEventPosition
 
-GetRoomsForUserWithStreamOrdering = namedtuple(
-    "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
-)
 
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class ProfileInfo:
+    avatar_url: Optional[str]
+    display_name: Optional[str]
 
-# We store this using a namedtuple so that we save about 3x space over using a
-# dict.
-ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 
-# "members" points to a truncated list of (user_id, event_id) tuples for users of
-# a given membership type, suitable for use in calculating heroes for a room.
-# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple("MemberSummary", ("members", "count"))
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class MemberSummary:
+    # A truncated list of (user_id, event_id) tuples for users of a given
+    # membership type, suitable for use in calculating heroes for a room.
+    members: List[Tuple[str, str]]
+    # The total number of users of a given membership type.
+    count: int
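The namedtuple-to-attrs move keeps the memory savings the old comment mentioned while adding types and immutability; a small sketch of the semantics these decorator options buy (illustrative values):

    import attr
    from typing import Optional

    @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
    class ProfileInfo:
        avatar_url: Optional[str]
        display_name: Optional[str]

    p = ProfileInfo(avatar_url=None, display_name="Alice")
    assert not hasattr(p, "__dict__")  # slots=True: no per-instance dict
    try:
        p.display_name = "Bob"         # frozen=True: instances are immutable
    except attr.exceptions.FrozenInstanceError:
        pass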
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index a5bc0ee8a5..af9cc69949 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# When updating these values, please leave a short summary of the changes below.
+
 SCHEMA_VERSION = 63
 """Represents the expectations made by the codebase about the database schema
 
diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql
new file mode 100644
index 0000000000..535fb34c10
--- /dev/null
+++ b/synapse/storage/schema/main/delta/62/02session_store.sql
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS sessions(
+    session_type TEXT NOT NULL,  -- The unique key for this type of session.
+    session_id TEXT NOT NULL,  -- The session ID passed to the client.
+    value TEXT NOT NULL, -- A JSON dictionary to persist.
+    expiry_time_ms BIGINT NOT NULL,  -- The time this session will expire (epoch time in milliseconds).
+    UNIQUE (session_type, session_id)
+);
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 0000000000..ee6cf958f4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+    token TEXT NOT NULL,  -- The token that can be used for authentication.
+    uses_allowed INT,  -- The total number of times this token can be used. NULL if no limit.
+    pending INT NOT NULL, -- The number of in-progress registrations using this token.
+    completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+    expiry_time BIGINT,  -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+    UNIQUE (token)
+);
diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
new file mode 100644
index 0000000000..611c4b95cf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We may not have deleted all pushers for emails that are no longer linked
+-- to an account, so we set up a background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6302, 'remove_deleted_email_pushers', '{}');
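For reference, the update_name in this INSERT is what ties the row to the Python handler registered earlier in this diff; the two strings must match exactly for the job to run:

    # From synapse/storage/databases/main/pusher.py above:
    self.db_pool.updates.register_background_update_handler(
        "remove_deleted_email_pushers",  # == update_name in the row above
        self._remove_deleted_email_pushers,
    )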