summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/data_stores/main/__init__.py34
-rw-r--r--synapse/storage/data_stores/main/appservice.py14
-rw-r--r--synapse/storage/data_stores/main/client_ips.py4
-rw-r--r--synapse/storage/data_stores/main/devices.py13
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py33
-rw-r--r--synapse/storage/data_stores/main/event_federation.py188
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py34
-rw-r--r--synapse/storage/data_stores/main/events.py18
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py2
-rw-r--r--synapse/storage/data_stores/main/events_worker.py90
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py32
-rw-r--r--synapse/storage/data_stores/main/push_rule.py8
-rw-r--r--synapse/storage/data_stores/main/pusher.py156
-rw-r--r--synapse/storage/data_stores/main/receipts.py4
-rw-r--r--synapse/storage/data_stores/main/room.py37
-rw-r--r--synapse/storage/data_stores/main/roommember.py4
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres39
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite23
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md24
-rw-r--r--synapse/storage/data_stores/main/state.py8
-rw-r--r--synapse/storage/data_stores/main/state_deltas.py4
-rw-r--r--synapse/storage/data_stores/main/stream.py8
-rw-r--r--synapse/storage/data_stores/main/user_erasure_store.py4
-rw-r--r--synapse/storage/data_stores/state/store.py4
-rw-r--r--synapse/storage/database.py159
-rw-r--r--synapse/storage/engines/__init__.py28
-rw-r--r--synapse/storage/engines/_base.py87
-rw-r--r--synapse/storage/engines/postgres.py22
-rw-r--r--synapse/storage/engines/sqlite.py13
-rw-r--r--synapse/storage/persist_events.py8
-rw-r--r--synapse/storage/prepare_database.py21
-rw-r--r--synapse/storage/types.py65
34 files changed, 883 insertions, 309 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index da3b99f93d..13de5f1f62 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             members_changed (iterable[str]): The user_ids of members that have
                 changed
         """
-        for host in set(get_domain_from_id(u) for u in members_changed):
+        for host in {get_domain_from_id(u) for u in members_changed}:
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
             self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
 
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bd547f35cf..eb1a7e5002 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -189,7 +189,7 @@ class BackgroundUpdater(object):
                 keyvalues=None,
                 retcols=("update_name", "depends_on"),
             )
-            in_flight = set(update["update_name"] for update in updates)
+            in_flight = {update["update_name"] for update in updates}
             for update in updates:
                 if update["depends_on"] not in in_flight:
                     self._background_update_queue.append(update["update_name"])
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 2700cca822..acca079f23 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
 import time
 
 from synapse.api.constants import PresenceState
+from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -117,16 +118,6 @@ class DataStore(
         self._clock = hs.get_clock()
         self.database_engine = database.engine
 
-        all_users_native = are_all_users_on_domain(
-            db_conn.cursor(), database.engine, hs.hostname
-        )
-        if not all_users_native:
-            raise Exception(
-                "Found users in database not native to %s!\n"
-                "You cannot changed a synapse server_name after it's been configured"
-                % (hs.hostname,)
-            )
-
         self._stream_id_gen = StreamIdGenerator(
             db_conn,
             "events",
@@ -567,13 +558,26 @@ class DataStore(
         )
 
 
-def are_all_users_on_domain(txn, database_engine, domain):
+def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+    """Called before upgrading an existing database to check that it is broadly sane
+    compared with the configuration.
+    """
+    domain = config.server_name
+
     sql = database_engine.convert_param_style(
         "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
     )
     pat = "%:" + domain
-    txn.execute(sql, (pat,))
-    num_not_matching = txn.fetchall()[0][0]
+    cur.execute(sql, (pat,))
+    num_not_matching = cur.fetchall()[0][0]
     if num_not_matching == 0:
-        return True
-    return False
+        return
+
+    raise Exception(
+        "Found users in database not native to %s!\n"
+        "You cannot changed a synapse server_name after it's been configured"
+        % (domain,)
+    )
+
+
+__all__ = ["DataStore", "check_database_before_upgrade"]
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index b2f39649fd..efbc06c796 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore(
             may be empty.
         """
         results = yield self.db.simple_select_list(
-            "application_services_state", dict(state=state), ["as_id"]
+            "application_services_state", {"state": state}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore(
         """
         result = yield self.db.simple_select_one(
             "application_services_state",
-            dict(as_id=service.id),
+            {"as_id": service.id},
             ["state"],
             allow_none=True,
             desc="get_appservice_state",
@@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore(
             A Deferred which resolves when the state was set successfully.
         """
         return self.db.simple_upsert(
-            "application_services_state", dict(as_id=service.id), dict(state=state)
+            "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
     def create_appservice_txn(self, service, events):
@@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore(
             self.db.simple_upsert_txn(
                 txn,
                 "application_services_state",
-                dict(as_id=service.id),
-                dict(last_txn=txn_id),
+                {"as_id": service.id},
+                {"last_txn": txn_id},
             )
 
             # Delete txn
             self.db.simple_delete_txn(
-                txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
+                txn,
+                "application_services_txns",
+                {"txn_id": txn_id, "as_id": service.id},
             )
 
         return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 13f4c9c72e..e1ccb27142 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
             for row in rows
         )
-        return list(
+        return [
             {
                 "access_token": access_token,
                 "ip": ip,
@@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 "last_seen": last_seen,
             }
             for (access_token, ip), (user_agent, last_seen) in iteritems(results)
-        )
+        ]
 
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self):
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index b7617efb80..d55733a4cd 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         # get the cross-signing keys of the users in the list, so that we can
         # determine which of the device changes were cross-signing keys
-        users = set(r[0] for r in updates)
+        users = {r[0] for r in updates}
         master_key_by_user = {}
         self_signing_key_by_user = {}
         for user in users:
@@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore):
             a set of user_ids and results_map is a mapping of
             user_id -> device_id -> device_info
         """
-        user_ids = set(user_id for user_id, _ in query_list)
+        user_ids = {user_id for user_id, _ in query_list}
         user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
 
         # We go and check if any of the users need to have their device lists
@@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore):
         users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
             user_ids
         )
-        user_ids_in_cache = (
-            set(user_id for user_id, stream_id in user_map.items() if stream_id)
-            - users_needing_resync
-        )
+        user_ids_in_cache = {
+            user_id for user_id, stream_id in user_map.items() if stream_id
+        } - users_needing_resync
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
         results = {}
@@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = yield self.db.execute(
                 "get_users_whose_signatures_changed", None, sql, user_id, from_key
             )
-            return set(user for row in rows for user in json.loads(row[0]))
+            return {user for row in rows for user in json.loads(row[0])}
         else:
             return set()
 
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index e551606f9d..001a53f9b4 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -680,11 +680,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 'user_signing' for a user-signing key
             key (dict): the key data
         """
-        # the cross-signing keys need to occupy the same namespace as devices,
-        # since signatures are identified by device ID.  So add an entry to the
-        # device table to make sure that we don't have a collision with device
-        # IDs
-
         # the 'key' dict will look something like:
         # {
         #   "user_id": "@alice:example.com",
@@ -701,16 +696,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         # The "keys" property must only have one entry, which will be the public
         # key, so we just grab the first value in there
         pubkey = next(iter(key["keys"].values()))
-        self.db.simple_insert_txn(
-            txn,
-            "devices",
-            values={
-                "user_id": user_id,
-                "device_id": pubkey,
-                "display_name": key_type + " signing key",
-                "hidden": True,
-            },
-        )
+
+        # The cross-signing keys need to occupy the same namespace as devices,
+        # since signatures are identified by device ID.  So add an entry to the
+        # device table to make sure that we don't have a collision with device
+        # IDs.
+        # We only need to do this for local users, since remote servers should be
+        # responsible for checking this for their own users.
+        if self.hs.is_mine_id(user_id):
+            self.db.simple_insert_txn(
+                txn,
+                "devices",
+                values={
+                    "user_id": user_id,
+                    "device_id": pubkey,
+                    "display_name": key_type + " signing key",
+                    "hidden": True,
+                },
+            )
 
         # and finally, store the key itself
         with self._cross_signing_id_gen.get_next() as stream_id:
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..62d4e9f599 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 import itertools
 import logging
+from typing import Dict, List, Optional, Set, Tuple
 
-from six.moves import range
 from six.moves.queue import Empty, PriorityQueue
 
 from twisted.internet import defer
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
@@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_ids, include_given=include_given
         ).addCallback(self.get_events_as_list)
 
-    def get_auth_chain_ids(self, event_ids, include_given=False):
+    def get_auth_chain_ids(
+        self,
+        event_ids: List[str],
+        include_given: bool = False,
+        ignore_events: Optional[Set[str]] = None,
+    ):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
+            ignore_events: Set of events to exclude from the returned auth
+                chain. This is useful if the caller will just discard the
+                given events anyway, and saves us from figuring out their auth
+                chains if not required.
 
         Returns:
             list of event_ids
         """
         return self.db.runInteraction(
-            "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+            "get_auth_chain_ids",
+            self._get_auth_chain_ids_txn,
+            event_ids,
+            include_given,
+            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+        if ignore_events is None:
+            ignore_events = set()
+
         if include_given:
             results = set(event_ids)
         else:
@@ -71,15 +88,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         front = set(event_ids)
         while front:
             new_front = set()
-            front_list = list(front)
-            chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
-            for chunk in chunks:
+            for chunk in batch_iter(front, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "event_id", chunk
                 )
-                txn.execute(base_sql + clause, list(args))
-                new_front.update([r[0] for r in txn])
+                txn.execute(base_sql + clause, args)
+                new_front.update(r[0] for r in txn)
 
+            new_front -= ignore_events
             new_front -= results
 
             front = new_front
@@ -87,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
+
+        This equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
+
+        Returns:
+            Deferred[Set[str]]
+        """
+
+        return self.db.runInteraction(
+            "get_auth_chain_difference",
+            self._get_auth_chain_difference_txn,
+            state_sets,
+        )
+
+    def _get_auth_chain_difference_txn(
+        self, txn, state_sets: List[Set[str]]
+    ) -> Set[str]:
+
+        # Algorithm Description
+        # ~~~~~~~~~~~~~~~~~~~~~
+        #
+        # The idea here is to basically walk the auth graph of each state set in
+        # tandem, keeping track of which auth events are reachable by each state
+        # set. If we reach an auth event we've already visited (via a different
+        # state set) then we mark that auth event and all ancestors as reachable
+        # by the state set. This requires that we keep track of the auth chains
+        # in memory.
+        #
+        # Doing it in a such a way means that we can stop early if all auth
+        # events we're currently walking are reachable by all state sets.
+        #
+        # *Note*: We can't stop walking an event's auth chain if it is reachable
+        # by all state sets. This is because other auth chains we're walking
+        # might be reachable only via the original auth chain. For example,
+        # given the following auth chain:
+        #
+        #       A -> C -> D -> E
+        #           /         /
+        #       B -´---------´
+        #
+        # and state sets {A} and {B} then walking the auth chains of A and B
+        # would immediately show that C is reachable by both. However, if we
+        # stopped at C then we'd only reach E via the auth chain of B and so E
+        # would errornously get included in the returned difference.
+        #
+        # The other thing that we do is limit the number of auth chains we walk
+        # at once, due to practical limits (i.e. we can only query the database
+        # with a limited set of parameters). We pick the auth chains we walk
+        # each iteration based on their depth, in the hope that events with a
+        # lower depth are likely reachable by those with higher depths.
+        #
+        # We could use any ordering that we believe would give a rough
+        # topological ordering, e.g. origin server timestamp. If the ordering
+        # chosen is not topological then the algorithm still produces the right
+        # result, but perhaps a bit more inefficiently. This is why it is safe
+        # to use "depth" here.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Dict from events in auth chains to which sets *cannot* reach them.
+        # I.e. if the set is empty then all sets can reach the event.
+        event_to_missing_sets = {
+            event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+            for event_id in initial_events
+        }
+
+        # We need to get the depth of the initial events for sorting purposes.
+        sql = """
+            SELECT depth, event_id FROM events
+            WHERE %s
+            ORDER BY depth ASC
+        """
+        clause, args = make_in_list_sql_clause(
+            txn.database_engine, "event_id", initial_events
+        )
+        txn.execute(sql % (clause,), args)
+
+        # The sorted list of events whose auth chains we should walk.
+        search = txn.fetchall()  # type: List[Tuple[int, str]]
+
+        # Map from event to its auth events
+        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
+
+        while search:
+            # Check whether all our current walks are reachable by all state
+            # sets. If so we can bail.
+            if all(not event_to_missing_sets[eid] for _, eid in search):
+                break
+
+            # Fetch the auth events and their depths of the N last events we're
+            # currently walking
+            search, chunk = search[:-100], search[-100:]
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+            )
+            txn.execute(base_sql + clause, args)
+
+            for event_id, auth_event_id, auth_event_depth in txn:
+                event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+                sets = event_to_missing_sets.get(auth_event_id)
+                if sets is None:
+                    # First time we're seeing this event, so we add it to the
+                    # queue of things to fetch.
+                    search.append((auth_event_depth, auth_event_id))
+
+                    # Assume that this event is unreachable from any of the
+                    # state sets until proven otherwise
+                    sets = event_to_missing_sets[auth_event_id] = set(
+                        range(len(state_sets))
+                    )
+                else:
+                    # We've previously seen this event, so look up its auth
+                    # events and recursively mark all ancestors as reachable
+                    # by the current event's state set.
+                    a_ids = event_to_auth_events.get(auth_event_id)
+                    while a_ids:
+                        new_aids = set()
+                        for a_id in a_ids:
+                            event_to_missing_sets[a_id].intersection_update(
+                                event_to_missing_sets[event_id]
+                            )
+
+                            b = event_to_auth_events.get(a_id)
+                            if b:
+                                new_aids.update(b)
+
+                        a_ids = new_aids
+
+                # Mark that the auth event is reachable by the approriate sets.
+                sets.intersection_update(event_to_missing_sets[event_id])
+
+            search.sort()
+
+        # Return all events where not all sets can reach them.
+        return {eid for eid, n in event_to_missing_sets.items() if n}
+
     def get_oldest_events_in_room(self, room_id):
         return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
@@ -410,7 +574,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                     query, (room_id, event_id, False, limit - len(event_results))
                 )
 
-                new_results = set(t[0] for t in txn) - seen_events
+                new_results = {t[0] for t in txn} - seen_events
 
                 new_front |= new_results
                 seen_events |= new_results
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 9988a6d3fc..8eed590929 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return range_end
 
+    @defer.inlineCallbacks
+    def get_time_of_last_push_action_before(self, stream_ordering):
+        def f(txn):
+            sql = (
+                "SELECT e.received_ts"
+                " FROM event_push_actions AS ep"
+                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
+                " WHERE ep.stream_ordering > ?"
+                " ORDER BY ep.stream_ordering ASC"
+                " LIMIT 1"
+            )
+            txn.execute(sql, (stream_ordering,))
+            return txn.fetchone()
+
+        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+        return result[0] if result else None
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -736,23 +753,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         return push_actions
 
     @defer.inlineCallbacks
-    def get_time_of_last_push_action_before(self, stream_ordering):
-        def f(txn):
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ?"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
-            txn.execute(sql, (stream_ordering,))
-            return txn.fetchone()
-
-        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
-        return result[0] if result else None
-
-    @defer.inlineCallbacks
     def get_latest_push_action_stream_ordering(self):
         def f(txn):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index c9d0d68c3a..d593ef47b8 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -145,7 +145,7 @@ class EventsStore(
             return txn.fetchall()
 
         res = yield self.db.runInteraction("read_forward_extremities", fetch)
-        self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
+        self._current_forward_extremities_amount = c_counter([x[0] for x in res])
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -598,11 +598,11 @@ class EventsStore(
             # We find out which membership events we may have deleted
             # and which we have added, then we invlidate the caches for all
             # those users.
-            members_changed = set(
+            members_changed = {
                 state_key
                 for ev_type, state_key in itertools.chain(to_delete, to_insert)
                 if ev_type == EventTypes.Member
-            )
+            }
 
             for member in members_changed:
                 txn.call_after(
@@ -1168,7 +1168,11 @@ class EventsStore(
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
+                pruned_json = encode_json(
+                    prune_event_dict(
+                        original_event.room_version, original_event.get_dict()
+                    )
+                )
             else:
                 # Redaction wasn't allowed
                 pruned_json = None
@@ -1615,7 +1619,7 @@ class EventsStore(
         """
         )
 
-        referenced_state_groups = set(sg for sg, in txn)
+        referenced_state_groups = {sg for sg, in txn}
         logger.info(
             "[purge] found %i referenced state groups", len(referenced_state_groups)
         )
@@ -1929,7 +1933,9 @@ class EventsStore(
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(prune_event_dict(event.get_dict()))
+            pruned_json = encode_json(
+                prune_event_dict(event.room_version, event.get_dict())
+            )
 
             # Update the event_json table to replace the event's JSON with the pruned
             # JSON.
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 5177b71016..f54c8b1ee0 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                     keyvalues={},
                     retcols=("room_id",),
                 )
-                room_ids = set(row["room_id"] for row in rows)
+                room_ids = {row["room_id"] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
                         self.get_latest_event_ids_in_room.invalidate, (room_id,)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 7251e819f5..ca237c6f12 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -28,9 +28,12 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import NotFoundError
-from synapse.api.room_versions import EventFormatVersions
-from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.api.room_versions import (
+    KNOWN_ROOM_VERSIONS,
+    EventFormatVersions,
+    RoomVersions,
+)
+from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -494,9 +497,9 @@ class EventsWorkerStore(SQLBaseStore):
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                events_to_fetch = set(
+                events_to_fetch = {
                     event_id for events, _ in event_list for event_id in events
-                )
+                }
 
                 row_dict = self.db.new_transaction(
                     conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
@@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore):
                 # of a event format version, so it must be a V1 event.
                 format_version = EventFormatVersions.V1
 
-            original_ev = event_type_from_format_version(format_version)(
+            room_version_id = row["room_version_id"]
+
+            if not room_version_id:
+                # this should only happen for out-of-band membership events
+                if not internal_metadata.get("out_of_band_membership"):
+                    logger.warning(
+                        "Room %s for event %s is unknown", d["room_id"], event_id
+                    )
+                    continue
+
+                # take a wild stab at the room version based on the event format
+                if format_version == EventFormatVersions.V1:
+                    room_version = RoomVersions.V1
+                elif format_version == EventFormatVersions.V2:
+                    room_version = RoomVersions.V3
+                else:
+                    room_version = RoomVersions.V5
+            else:
+                room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+                if not room_version:
+                    logger.error(
+                        "Event %s in room %s has unknown room version %s",
+                        event_id,
+                        d["room_id"],
+                        room_version_id,
+                    )
+                    continue
+
+                if room_version.event_format != format_version:
+                    logger.error(
+                        "Event %s in room %s with version %s has wrong format: "
+                        "expected %s, was %s",
+                        event_id,
+                        d["room_id"],
+                        room_version_id,
+                        room_version.event_format,
+                        format_version,
+                    )
+                    continue
+
+            original_ev = make_event_from_dict(
                 event_dict=d,
+                room_version=room_version,
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
@@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore):
            of EventFormatVersions. 'None' means the event predates
            EventFormatVersions (so the event is format V1).
 
+         * room_version_id (str|None): The version of the room which contains the event.
+           Hopefully one of RoomVersions.
+
+           Due to historical reasons, there may be a few events in the database which
+           do not have an associated room; in this case None will be returned here.
+
          * rejected_reason (str|None): if the event was rejected, the reason
            why.
 
@@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore):
         """
         event_dict = {}
         for evs in batch_iter(event_ids, 200):
-            sql = (
-                "SELECT "
-                " e.event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " e.format_version, "
-                " rej.reason "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " WHERE "
-            )
+            sql = """\
+                SELECT
+                  e.event_id,
+                  e.internal_metadata,
+                  e.json,
+                  e.format_version,
+                  r.room_version,
+                  rej.reason
+                FROM event_json as e
+                  LEFT JOIN rooms r USING (room_id)
+                  LEFT JOIN rejections as rej USING (event_id)
+                WHERE """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "e.event_id", evs
@@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore):
                     "internal_metadata": row[1],
                     "json": row[2],
                     "format_version": row[3],
-                    "rejected_reason": row[4],
+                    "room_version_id": row[4],
+                    "rejected_reason": row[5],
                     "redactions": [],
                 }
 
@@ -804,7 +856,7 @@ class EventsWorkerStore(SQLBaseStore):
             desc="have_events_in_timeline",
         )
 
-        return set(r["event_id"] for r in rows)
+        return {r["event_id"] for r in rows}
 
     @defer.inlineCallbacks
     def have_seen_events(self, event_ids):
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 1507a14e09..925bc5691b 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
             txn.execute(sql)
             (count,) = txn.fetchone()
             return count
 
         return self.db.runInteraction("count_users", _count_users)
 
+    @cached(num_args=0)
+    def get_monthly_active_count_by_service(self):
+        """Generates current count of monthly active users broken down by service.
+        A service is typically an appservice but also includes native matrix users.
+        Since the `monthly_active_users` table is populated from the `user_ips` table
+        `config.track_appservice_user_ips` must be set to `true` for this
+        method to return anything other than native matrix users.
+
+        Returns:
+            Deferred[dict]: dict that includes a mapping between app_service_id
+                and the number of occurrences.
+
+        """
+
+        def _count_users_by_service(txn):
+            sql = """
+                SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+                FROM monthly_active_users
+                LEFT JOIN users ON monthly_active_users.user_id=users.name
+                GROUP BY appservice_id;
+            """
+
+            txn.execute(sql)
+            result = txn.fetchall()
+            return dict(result)
+
+        return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+
     @defer.inlineCallbacks
     def get_registered_reserved_users(self):
         """Of the reserved threepids defined in config, which are associated
@@ -292,6 +319,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
 
         self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
         self._invalidate_cache_and_stream(
+            txn, self.get_monthly_active_count_by_service, ()
+        )
+        self._invalidate_cache_and_stream(
             txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index e2673ae073..62ac88d9f2 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -276,21 +276,21 @@ class PushRulesWorkerStore(
         # We ignore app service users for now. This is so that we don't fill
         # up the `get_if_users_have_pushers` cache with AS entries that we
         # know don't have pushers, nor even read receipts.
-        local_users_in_room = set(
+        local_users_in_room = {
             u
             for u in users_in_room
             if self.hs.is_mine_id(u)
             and not self.get_if_app_services_interested_in_user(u)
-        )
+        }
 
         # users in the room who have pushers need to get push rules run because
         # that's how their pushers work
         if_users_with_pushers = yield self.get_if_users_have_pushers(
             local_users_in_room, on_invalidate=cache_context.invalidate
         )
-        user_ids = set(
+        user_ids = {
             uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        )
+        }
 
         users_with_receipts = yield self.get_users_with_read_receipts_in_room(
             room_id, on_invalidate=cache_context.invalidate
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 6b03233262..547b9d69cb 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -197,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore):
 
         return result
 
+    @defer.inlineCallbacks
+    def update_pusher_last_stream_ordering(
+        self, app_id, pushkey, user_id, last_stream_ordering
+    ):
+        yield self.db.simple_update_one(
+            "pushers",
+            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            {"last_stream_ordering": last_stream_ordering},
+            desc="update_pusher_last_stream_ordering",
+        )
+
+    @defer.inlineCallbacks
+    def update_pusher_last_stream_ordering_and_success(
+        self, app_id, pushkey, user_id, last_stream_ordering, last_success
+    ):
+        """Update the last stream ordering position we've processed up to for
+        the given pusher.
+
+        Args:
+            app_id (str)
+            pushkey (str)
+            last_stream_ordering (int)
+            last_success (int)
+
+        Returns:
+            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+        """
+        updated = yield self.db.simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={
+                "last_stream_ordering": last_stream_ordering,
+                "last_success": last_success,
+            },
+            desc="update_pusher_last_stream_ordering_and_success",
+        )
+
+        return bool(updated)
+
+    @defer.inlineCallbacks
+    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
+        yield self.db.simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={"failing_since": failing_since},
+            desc="update_pusher_failing_since",
+        )
+
+    @defer.inlineCallbacks
+    def get_throttle_params_by_room(self, pusher_id):
+        res = yield self.db.simple_select_list(
+            "pusher_throttle",
+            {"pusher": pusher_id},
+            ["room_id", "last_sent_ts", "throttle_ms"],
+            desc="get_throttle_params_by_room",
+        )
+
+        params_by_room = {}
+        for row in res:
+            params_by_room[row["room_id"]] = {
+                "last_sent_ts": row["last_sent_ts"],
+                "throttle_ms": row["throttle_ms"],
+            }
+
+        return params_by_room
+
+    @defer.inlineCallbacks
+    def set_throttle_params(self, pusher_id, room_id, params):
+        # no need to lock because `pusher_throttle` has a primary key on
+        # (pusher, room_id) so simple_upsert will retry
+        yield self.db.simple_upsert(
+            "pusher_throttle",
+            {"pusher": pusher_id, "room_id": room_id},
+            params,
+            desc="set_throttle_params",
+            lock=False,
+        )
+
 
 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self):
@@ -282,81 +360,3 @@ class PusherStore(PusherWorkerStore):
 
         with self._pushers_id_gen.get_next() as stream_id:
             yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
-
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering(
-        self, app_id, pushkey, user_id, last_stream_ordering
-    ):
-        yield self.db.simple_update_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            {"last_stream_ordering": last_stream_ordering},
-            desc="update_pusher_last_stream_ordering",
-        )
-
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering_and_success(
-        self, app_id, pushkey, user_id, last_stream_ordering, last_success
-    ):
-        """Update the last stream ordering position we've processed up to for
-        the given pusher.
-
-        Args:
-            app_id (str)
-            pushkey (str)
-            last_stream_ordering (int)
-            last_success (int)
-
-        Returns:
-            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
-        """
-        updated = yield self.db.simple_update(
-            table="pushers",
-            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            updatevalues={
-                "last_stream_ordering": last_stream_ordering,
-                "last_success": last_success,
-            },
-            desc="update_pusher_last_stream_ordering_and_success",
-        )
-
-        return bool(updated)
-
-    @defer.inlineCallbacks
-    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self.db.simple_update(
-            table="pushers",
-            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            updatevalues={"failing_since": failing_since},
-            desc="update_pusher_failing_since",
-        )
-
-    @defer.inlineCallbacks
-    def get_throttle_params_by_room(self, pusher_id):
-        res = yield self.db.simple_select_list(
-            "pusher_throttle",
-            {"pusher": pusher_id},
-            ["room_id", "last_sent_ts", "throttle_ms"],
-            desc="get_throttle_params_by_room",
-        )
-
-        params_by_room = {}
-        for row in res:
-            params_by_room[row["room_id"]] = {
-                "last_sent_ts": row["last_sent_ts"],
-                "throttle_ms": row["throttle_ms"],
-            }
-
-        return params_by_room
-
-    @defer.inlineCallbacks
-    def set_throttle_params(self, pusher_id, room_id, params):
-        # no need to lock because `pusher_throttle` has a primary key on
-        # (pusher, room_id) so simple_upsert will retry
-        yield self.db.simple_upsert(
-            "pusher_throttle",
-            {"pusher": pusher_id, "room_id": room_id},
-            params,
-            desc="set_throttle_params",
-            lock=False,
-        )
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 96e54d145e..0d932a0672 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
-        return set(r["user_id"] for r in receipts)
+        return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
@@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 args.append(limit)
             txn.execute(sql, args)
 
-            return list(r[0:5] + (json.loads(r[5]),) for r in txn)
+            return [r[0:5] + (json.loads(r[5]),) for r in txn]
 
         return self.db.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 9a17e336ba..e6c10c6316 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -954,6 +954,23 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
         self.config = hs.config
 
+    async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+        """Ensure that the room is stored in the table
+
+        Called when we join a room over federation, and overwrites any room version
+        currently in the table.
+        """
+        await self.db.simple_upsert(
+            desc="upsert_room_on_join",
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            values={"room_version": room_version.identifier},
+            insertion_values={"is_public": False, "creator": ""},
+            # rooms has a unique constraint on room_id, so no need to lock when doing an
+            # emulated upsert.
+            lock=False,
+        )
+
     @defer.inlineCallbacks
     def store_room(
         self,
@@ -1003,6 +1020,26 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
+    async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+        """
+        When we receive an invite over federation, store the version of the room if we
+        don't already know the room version.
+        """
+        await self.db.simple_upsert(
+            desc="maybe_store_room_on_invite",
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            values={},
+            insertion_values={
+                "room_version": room_version.identifier,
+                "is_public": False,
+                "creator": "",
+            },
+            # rooms has a unique constraint on room_id, so no need to lock when doing an
+            # emulated upsert.
+            lock=False,
+        )
+
     @defer.inlineCallbacks
     def set_room_is_public(self, room_id, is_public):
         def set_room_is_public_txn(txn, next_id):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index d5ced05701..d5bd0cb5cf 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             txn.execute(sql % (clause,), args)
 
-            return set(row[0] for row in txn)
+            return {row[0] for row in txn}
 
         return await self.db.runInteraction(
             "get_users_server_still_shares_room_with",
@@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 GROUP BY room_id, user_id;
             """
             txn.execute(sql, (user_id,))
-            return set(row[0] for row in txn if row[1] == 0)
+            return {row[0] for row in txn if row[1] == 0}
 
         return self.db.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
new file mode 100644
index 0000000000..92aaadde0d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- When we first added the room_version column to the rooms table, it was populated from
+-- the current_state_events table. However, there was an issue causing a background
+-- update to clean up the current_state_events table for rooms where the server is no
+-- longer participating, before that column could be populated. Therefore, some rooms had
+-- a NULL room_version.
+
+-- The rooms_version_column_2.sql.* delta files were introduced to make the populating
+-- synchronous instead of running it in a background update, which fixed this issue.
+-- However, all of the instances of Synapse installed or updated in the meantime got
+-- their rooms table corrupted with NULL room_versions.
+
+-- This query fishes out the room versions from the create event using the state_events
+-- table instead of the current_state_events one, as the former still have all of the
+-- create events.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json::json->'content'->>'room_version','1')
+    FROM state_events se INNER JOIN event_json ej USING (event_id)
+    WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+    LIMIT 1
+) WHERE rooms.room_version IS NULL;
+
+-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
new file mode 100644
index 0000000000..e19dab97cb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_3.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+    FROM state_events se INNER JOIN event_json ej USING (event_id)
+    WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+    LIMIT 1
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
index bbd3f18604..c00f287190 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.md
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -1,13 +1,21 @@
-# Building full schema dumps
+# Synapse Database Schemas
 
-These schemas need to be made from a database that has had all background updates run.
+These schemas are used as a basis to create brand new Synapse databases, on both
+SQLite3 and Postgres.
 
-To do so, use `scripts-dev/make_full_schema.sh`. This will produce
-`full.sql.postgres ` and `full.sql.sqlite` files.
+## Building full schema dumps
+
+If you want to recreate these schemas, they need to be made from a database that
+has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
+`full.sql.postgres ` and `full.sql.sqlite` files. 
 
 Ensure postgres is installed and your user has the ability to run bash commands
-such as `createdb`.
+such as `createdb`, then call
+
+    ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
 
-```
-./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
-```
+There are currently two folders with full-schema snapshots. `16` is a snapshot
+from 2015, for historical reference. The other contains the most recent full
+schema snapshot.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3d34103e67..3a3b9a8e72 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="get_referenced_state_groups",
         )
 
-        return set(row["state_group"] for row in rows)
+        return {row["state_group"] for row in rows}
 
 
 class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
@@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             """
 
             txn.execute(sql, (last_room_id, batch_size))
-            room_ids = list(row[0] for row in txn)
+            room_ids = [row[0] for row in txn]
             if not room_ids:
                 return True, set()
 
@@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
             txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
 
-            joined_room_ids = set(row[0] for row in txn)
+            joined_room_ids = {row[0] for row in txn}
 
             left_rooms = set(room_ids) - joined_room_ids
 
@@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
                 retcols=("state_key",),
             )
 
-            potentially_left_users = set(row["state_key"] for row in rows)
+            potentially_left_users = {row["state_key"] for row in rows}
 
             # Now lets actually delete the rooms from the DB.
             self.db.simple_delete_many_txn(
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 12c982cb26..725e12507f 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,6 +15,8 @@
 
 import logging
 
+from twisted.internet import defer
+
 from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
@@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
             # if the CSDs haven't changed between prev_stream_id and now, we
             # know for certain that they haven't changed between prev_stream_id and
             # max_stream_id.
-            return max_stream_id, []
+            return defer.succeed((max_stream_id, []))
 
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 056b25b13a..ada5cce6c2 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_key (str): The room_key portion of a StreamToken
         """
         from_key = RoomStreamToken.parse_stream_token(from_key).stream
-        return set(
+        return {
             room_id
             for room_id in room_ids
             if self._events_stream_cache.has_entity_changed(room_id, from_key)
-        )
+        }
 
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(
@@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         events_before = yield self.get_events_as_list(
-            [e for e in results["before"]["event_ids"]], get_prev_content=True
+            list(results["before"]["event_ids"]), get_prev_content=True
         )
 
         events_after = yield self.get_events_as_list(
-            [e for e in results["after"]["event_ids"]], get_prev_content=True
+            list(results["after"]["event_ids"]), get_prev_content=True
         )
 
         return {
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index af8025bc17..ec6b8a4ffd 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore):
             retcols=("user_id",),
             desc="are_users_erased",
         )
-        erased_users = set(row["user_id"] for row in rows)
+        erased_users = {row["user_id"] for row in rows}
 
-        res = dict((u, u in erased_users) for u in user_ids)
+        res = {u: u in erased_users for u in user_ids}
         return res
 
 
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index c4ee9b7ccb..57a5267663 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             retcols=("state_group",),
         )
 
-        remaining_state_groups = set(
+        remaining_state_groups = {
             row["state_group"]
             for row in rows
             if row["state_group"] not in state_groups_to_delete
-        )
+        }
 
         logger.info(
             "[purge] de-delta-ing %i remaining state groups",
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3eeb2f7c04..e61595336c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import sys
 import time
-from typing import Iterable, Tuple
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 from six import iteritems, iterkeys, itervalues
 from six.moves import intern, range
@@ -29,27 +29,21 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import (
+    LoggingContext,
+    LoggingContextOrSentinel,
+    make_deferred_yieldable,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
 from synapse.util.stringutils import exception_to_unicode
 
-# import a function which will return a monotonic time, in seconds
-try:
-    # on python 3, use time.monotonic, since time.clock can go backwards
-    from time import monotonic as monotonic_time
-except ImportError:
-    # ... but python 2 doesn't have it
-    from time import clock as monotonic_time
-
 logger = logging.getLogger(__name__)
 
-try:
-    MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
-    # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2 ** 63 - 1
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +71,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine
+    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database.
     """
@@ -90,7 +84,9 @@ def make_pool(
     )
 
 
-def make_conn(db_config: DatabaseConnectionConfig, engine):
+def make_conn(
+    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
     """Make a new connection to the database and return it.
 
     Returns:
@@ -107,20 +103,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
     return db_conn
 
 
-class LoggingTransaction(object):
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method.
 
     Args:
         txn: The database transcation object to wrap.
-        name (str): The name of this transactions for logging.
-        database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list that callbacks will be appended to
+        name: The name of this transactions for logging.
+        database_engine
+        after_callbacks: A list that callbacks will be appended to
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
-        exception_callbacks(list|None): A list that callbacks will be appended
+        exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
             should be allowed to be scheduled to run.
@@ -135,46 +138,67 @@ class LoggingTransaction(object):
     ]
 
     def __init__(
-        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+        self,
+        txn: Cursor,
+        name: str,
+        database_engine: BaseDatabaseEngine,
+        after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
-        object.__setattr__(self, "txn", txn)
-        object.__setattr__(self, "name", name)
-        object.__setattr__(self, "database_engine", database_engine)
-        object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "exception_callbacks", exception_callbacks)
+        self.txn = txn
+        self.name = name
+        self.database_engine = database_engine
+        self.after_callbacks = after_callbacks
+        self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback, *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
         """
+        # if self.after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback, *args, **kwargs):
+    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+        # if self.exception_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.exception_callbacks is not None
         self.exception_callbacks.append((callback, args, kwargs))
 
-    def __getattr__(self, name):
-        return getattr(self.txn, name)
+    def fetchall(self) -> List[Tuple]:
+        return self.txn.fetchall()
 
-    def __setattr__(self, name, value):
-        setattr(self.txn, name, value)
+    def fetchone(self) -> Tuple:
+        return self.txn.fetchone()
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple]:
         return self.txn.__iter__()
 
+    @property
+    def rowcount(self) -> int:
+        return self.txn.rowcount
+
+    @property
+    def description(self) -> Any:
+        return self.txn.description
+
     def execute_batch(self, sql, args):
         if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch
+            from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql, *args):
+    def execute(self, sql: str, *args: Any):
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql, *args):
+    def executemany(self, sql: str, *args: Any):
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql):
@@ -207,6 +231,9 @@ class LoggingTransaction(object):
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
+    def close(self):
+        self.txn.close()
+
 
 class PerformanceCounters(object):
     def __init__(self):
@@ -251,7 +278,9 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
+    def __init__(
+        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self._database_config = database_config
@@ -259,9 +288,9 @@ class Database(object):
 
         self.updates = BackgroundUpdater(hs, self)
 
-        self._previous_txn_total_time = 0
-        self._current_txn_total_time = 0
-        self._previous_loop_ts = 0
+        self._previous_txn_total_time = 0.0
+        self._current_txn_total_time = 0.0
+        self._previous_loop_ts = 0.0
 
         # TODO(paul): These can eventually be removed once the metrics code
         #   is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +492,23 @@ class Database(object):
             sql_txn_timer.labels(desc).observe(duration)
 
     @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
+    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
         """Starts a transaction on the database and runs a given function
 
         Arguments:
-            desc (str): description of the transaction, for logging and metrics
-            func (func): callback function, which will be called with a
+            desc: description of the transaction, for logging and metrics
+            func: callback function, which will be called with a
                 database transaction (twisted.enterprise.adbapi.Transaction) as
                 its first argument, followed by `args` and `kwargs`.
 
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
 
         Returns:
             Deferred: The result of func
         """
-        after_callbacks = []
-        exception_callbacks = []
+        after_callbacks = []  # type: List[_CallbackListEntry]
+        exception_callbacks = []  # type: List[_CallbackListEntry]
 
         if LoggingContext.current_context() == LoggingContext.sentinel:
             logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,20 +534,22 @@ class Database(object):
         return result
 
     @defer.inlineCallbacks
-    def runWithConnection(self, func, *args, **kwargs):
+    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
-            func (func): callback function, which will be called with a
+            func: callback function, which will be called with a
                 database connection (twisted.enterprise.adbapi.Connection) as
                 its first argument, followed by `args` and `kwargs`.
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
 
         Returns:
             Deferred: The result of func
         """
-        parent_context = LoggingContext.current_context()
+        parent_context = (
+            LoggingContext.current_context()
+        )  # type: Optional[LoggingContextOrSentinel]
         if parent_context == LoggingContext.sentinel:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
@@ -554,8 +585,8 @@ class Database(object):
         Returns:
             A list of dicts where the key is the column header.
         """
-        col_headers = list(intern(str(column[0])) for column in cursor.description)
-        results = list(dict(zip(col_headers, row)) for row in cursor)
+        col_headers = [intern(str(column[0])) for column in cursor.description]
+        results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
     def execute(self, desc, decoder, query, *args):
@@ -800,7 +831,7 @@ class Database(object):
                 return False
 
         # We didn't find any existing rows, so insert a new one
-        allvalues = {}
+        allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)
@@ -829,7 +860,7 @@ class Database(object):
         Returns:
             None
         """
-        allvalues = {}
+        allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
         allvalues.update(insertion_values)
 
@@ -916,7 +947,7 @@ class Database(object):
         Returns:
             None
         """
-        allnames = []
+        allnames = []  # type: List[str]
         allnames.extend(key_names)
         allnames.extend(value_names)
 
@@ -1100,7 +1131,7 @@ class Database(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
-        results = []
+        results = []  # type: List[Dict[str, Any]]
 
         if not iterable:
             return results
@@ -1439,7 +1470,7 @@ class Database(object):
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
 
         where_clause = "WHERE " if filters or keyvalues else ""
-        arg_list = []
+        arg_list = []  # type: List[Any]
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
@@ -1504,7 +1535,7 @@ class Database(object):
 
 def make_in_list_sql_clause(
     database_engine, column: str, iterable: Iterable
-) -> Tuple[str, Iterable]:
+) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
     On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import importlib
 import platform
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
 
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
     name = database_config["name"]
-    engine_class = SUPPORTED_MODULE.get(name, None)
 
-    if engine_class:
+    if name == "sqlite3":
+        import sqlite3
+
+        return Sqlite3Engine(sqlite3, database_config)
+
+    if name == "psycopg2":
         # pypy requires psycopg2cffi rather than psycopg2
-        if name == "psycopg2" and platform.python_implementation() == "PyPy":
-            name = "psycopg2cffi"
-        module = importlib.import_module(name)
-        return engine_class(module, database_config)
+        if platform.python_implementation() == "PyPy":
+            import psycopg2cffi as psycopg2  # type: ignore
+        else:
+            import psycopg2  # type: ignore
+
+        return PostgresEngine(psycopg2, database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
 
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
 
 
 class IncorrectDatabaseSetup(RuntimeError):
     pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+    def __init__(self, module, database_config: dict):
+        self.module = module
+
+    @property
+    @abc.abstractmethod
+    def single_threaded(self) -> bool:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def can_native_upsert(self) -> bool:
+        """
+        Do we support native UPSERTs?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_tuple_comparison(self) -> bool:
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_using_any_list(self) -> bool:
+        """
+        Do we support using `a = ANY(?)` and passing a list
+        """
+        ...
+
+    @abc.abstractmethod
+    def check_database(
+        self, db_conn: ConnectionType, allow_outdated_version: bool = False
+    ) -> None:
+        ...
+
+    @abc.abstractmethod
+    def check_new_database(self, txn) -> None:
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+        ...
+
+    @abc.abstractmethod
+    def convert_param_style(self, sql: str) -> str:
+        ...
+
+    @abc.abstractmethod
+    def on_new_connection(self, db_conn: ConnectionType) -> None:
+        ...
+
+    @abc.abstractmethod
+    def is_deadlock(self, error: Exception) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def is_connection_closed(self, conn: ConnectionType) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def lock_table(self, txn, table: str) -> None:
+        ...
+
+    @abc.abstractmethod
+    def get_next_state_group_id(self, txn) -> int:
+        """Returns an int that can be used as a new state_group ID
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a077345960..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@
 
 import logging
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(object):
-    single_threaded = False
-
+class PostgresEngine(BaseDatabaseEngine):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
         self.module.extensions.register_type(self.module.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet
 
+    @property
+    def single_threaded(self) -> bool:
+        return False
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
@@ -53,7 +55,7 @@ class PostgresEngine(object):
             if rows and rows[0][0] != "UTF8":
                 raise IncorrectDatabaseSetup(
                     "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                    "See docs/postgres.rst for more information." % (rows[0][0],)
+                    "See docs/postgres.md for more information." % (rows[0][0],)
                 )
 
             txn.execute(
@@ -62,12 +64,16 @@ class PostgresEngine(object):
             collation, ctype = txn.fetchone()
             if collation != "C":
                 logger.warning(
-                    "Database has incorrect collation of %r. Should be 'C'", collation
+                    "Database has incorrect collation of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    collation,
                 )
 
             if ctype != "C":
                 logger.warning(
-                    "Database has incorrect ctype of %r. Should be 'C'", ctype
+                    "Database has incorrect ctype of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    ctype,
                 )
 
     def check_new_database(self, txn):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..2bfeefd54e 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import sqlite3
 import struct
 import threading
 
+from synapse.storage.engines import BaseDatabaseEngine
 
-class Sqlite3Engine(object):
-    single_threaded = True
 
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (None, ":memory:",)
@@ -32,6 +32,10 @@ class Sqlite3Engine(object):
         self._current_state_group_id_lock = threading.Lock()
 
     @property
+    def single_threaded(self) -> bool:
+        return True
+
+    @property
     def can_native_upsert(self):
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -68,7 +72,6 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
-
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index b950550f23..0f9ac1cf09 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -602,14 +602,14 @@ class EventsPersistenceStorage(object):
             event_id_to_state_group.update(event_to_groups)
 
         # State groups of old_latest_event_ids
-        old_state_groups = set(
+        old_state_groups = {
             event_id_to_state_group[evid] for evid in old_latest_event_ids
-        )
+        }
 
         # State groups of new_latest_event_ids
-        new_state_groups = set(
+        new_state_groups = {
             event_id_to_state_group[evid] for evid in new_latest_event_ids
-        )
+        }
 
         # If they old and new groups are the same then we don't need to do
         # anything.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c285ef52a0..6cb7d4b922 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -278,13 +278,17 @@ def _upgrade_existing_database(
             the current_version wasn't generated by applying those delta files.
         database_engine (DatabaseEngine)
         config (synapse.config.homeserver.HomeServerConfig|None):
-            application config, or None if we are connecting to an existing
-            database which we expect to be configured already
+            None if we are initialising a blank database, otherwise the application
+            config
         data_stores (list[str]): The names of the data stores to instantiate
             on the given database.
         is_empty (bool): Is this a blank database? I.e. do we need to run the
             upgrade portions of the delta scripts.
     """
+    if is_empty:
+        assert not applied_delta_files
+    else:
+        assert config
 
     if current_version > SCHEMA_VERSION:
         raise ValueError(
@@ -292,6 +296,13 @@ def _upgrade_existing_database(
             + "new for the server to understand"
         )
 
+    # some of the deltas assume that config.server_name is set correctly, so now
+    # is a good time to run the sanity check.
+    if not is_empty and "main" in data_stores:
+        from synapse.storage.data_stores.main import check_database_before_upgrade
+
+        check_database_before_upgrade(cur, database_engine, config)
+
     start_ver = current_version
     if not upgraded:
         start_ver += 1
@@ -345,9 +356,9 @@ def _upgrade_existing_database(
                     "Could not open delta dir for version %d: %s" % (v, directory)
                 )
 
-        duplicates = set(
+        duplicates = {
             file_name for file_name, count in file_name_counter.items() if count > 1
-        )
+        }
         if duplicates:
             # We don't support using the same file name in the same delta version.
             raise PrepareDatabaseException(
@@ -454,7 +465,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         ),
         (modname,),
     )
-    applied_deltas = set(d for d, in cur)
+    applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
         if name in applied_deltas:
             continue
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+        ...
+
+    def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+        ...
+
+    def fetchall(self) -> List[Tuple]:
+        ...
+
+    def fetchone(self) -> Tuple:
+        ...
+
+    @property
+    def description(self) -> Any:
+        return None
+
+    @property
+    def rowcount(self) -> int:
+        return 0
+
+    def __iter__(self) -> Iterator[Tuple]:
+        ...
+
+    def close(self) -> None:
+        ...
+
+
+class Connection(Protocol):
+    def cursor(self) -> Cursor:
+        ...
+
+    def close(self) -> None:
+        ...
+
+    def commit(self) -> None:
+        ...
+
+    def rollback(self, *args, **kwargs) -> None:
+        ...