Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py | 8
-rw-r--r--  synapse/storage/_base.py | 32
-rw-r--r--  synapse/storage/account_data.py | 14
-rw-r--r--  synapse/storage/appservice.py | 14
-rw-r--r--  synapse/storage/background_updates.py | 20
-rw-r--r--  synapse/storage/client_ips.py | 26
-rw-r--r--  synapse/storage/deviceinbox.py | 10
-rw-r--r--  synapse/storage/devices.py | 32
-rw-r--r--  synapse/storage/directory.py | 10
-rw-r--r--  synapse/storage/e2e_room_keys.py | 6
-rw-r--r--  synapse/storage/end_to_end_keys.py | 8
-rw-r--r--  synapse/storage/event_federation.py | 12
-rw-r--r--  synapse/storage/event_push_actions.py | 18
-rw-r--r--  synapse/storage/events.py | 330
-rw-r--r--  synapse/storage/events_bg_updates.py | 6
-rw-r--r--  synapse/storage/events_worker.py | 232
-rw-r--r--  synapse/storage/filtering.py | 4
-rw-r--r--  synapse/storage/group_server.py | 38
-rw-r--r--  synapse/storage/monthly_active_users.py | 2
-rw-r--r--  synapse/storage/prepare_database.py | 2
-rw-r--r--  synapse/storage/presence.py | 6
-rw-r--r--  synapse/storage/profile.py | 12
-rw-r--r--  synapse/storage/push_rule.py | 18
-rw-r--r--  synapse/storage/pusher.py | 40
-rw-r--r--  synapse/storage/receipts.py | 36
-rw-r--r--  synapse/storage/registration.py | 74
-rw-r--r--  synapse/storage/relations.py | 4
-rw-r--r--  synapse/storage/room.py | 10
-rw-r--r--  synapse/storage/roommember.py | 350
-rw-r--r--  synapse/storage/schema/delta/56/current_state_events_membership.sql | 22
-rw-r--r--  synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql | 24
-rw-r--r--  synapse/storage/schema/delta/56/room_membership_idx.sql | 18
-rw-r--r--  synapse/storage/search.py | 56
-rw-r--r--  synapse/storage/signatures.py | 2
-rw-r--r--  synapse/storage/state.py | 62
-rw-r--r--  synapse/storage/stats.py | 36
-rw-r--r--  synapse/storage/stream.py | 44
-rw-r--r--  synapse/storage/tags.py | 12
-rw-r--r--  synapse/storage/transactions.py | 24
-rw-r--r--  synapse/storage/user_directory.py | 38
-rw-r--r--  synapse/storage/user_erasure_store.py | 5
41 files changed, 1012 insertions, 705 deletions
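
Most of this diff replaces `defer.returnValue(x)` with a plain `return x` inside
`@defer.inlineCallbacks` generators, which Twisted supports on Python 3 because
generators may return values directly. A minimal sketch of the two equivalent
forms (the helper names below are illustrative, not taken verbatim from the patch):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_user_count_old(store):
        count = yield store.runInteraction("get_user_count", store.get_user_count_txn)
        defer.returnValue(count)  # old style, required on Python 2

    @defer.inlineCallbacks
    def get_user_count_new(store):
        count = yield store.runInteraction("get_user_count", store.get_user_count_txn)
        return count  # equivalent on Python 3; the style adopted throughout this diff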
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b0ca80087..e7f6ea7286 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -469,7 +469,7 @@ class DataStore(
         return self._simple_select_list(
             table="users",
             keyvalues={},
-            retcols=["name", "password_hash", "is_guest", "admin"],
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="get_users",
         )
 
@@ -494,11 +494,11 @@ class DataStore(
             orderby=order,
             start=start,
             limit=limit,
-            retcols=["name", "password_hash", "is_guest", "admin"],
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
         )
         count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
         retval = {"users": users, "total": count}
-        defer.returnValue(retval)
+        return retval
 
     def search_users(self, term):
         """Function to search users list for one or more users with
@@ -514,7 +514,7 @@ class DataStore(
             table="users",
             term=term,
             col="name",
-            retcols=["name", "password_hash", "is_guest", "admin"],
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="search_users",
         )
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2f940dbae6..489ce82fae 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,7 +86,21 @@ _CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
-    method."""
+    method.
+
+    Args:
+        txn: The database transaction object to wrap.
+        name (str): The name of this transaction, for logging.
+        database_engine (Sqlite3Engine|PostgresEngine)
+        after_callbacks(list|None): A list that callbacks added via `call_after`
+            will be appended to; these should be run on successful completion
+            of the transaction. None indicates that no callbacks should be
+            allowed to be scheduled to run.
+        exception_callbacks(list|None): A list that callbacks added via
+            `call_on_exception` will be appended to; these should be run if the
+            transaction ends with an error. None indicates that no callbacks
+            should be allowed to be scheduled to run.
+    """
 
     __slots__ = [
         "txn",
@@ -97,7 +111,7 @@ class LoggingTransaction(object):
     ]
 
     def __init__(
-        self, txn, name, database_engine, after_callbacks, exception_callbacks
+        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
     ):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
@@ -499,7 +513,7 @@ class SQLBaseStore(object):
                 after_callback(*after_args, **after_kwargs)
             raise
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
@@ -539,7 +553,7 @@ class SQLBaseStore(object):
         with PreserveLoggingContext():
             result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
 
-        defer.returnValue(result)
+        return result
 
     @staticmethod
     def cursor_to_dict(cursor):
@@ -601,8 +615,8 @@ class SQLBaseStore(object):
             # a cursor after we receive an error from the db.
             if not or_ignore:
                 raise
-            defer.returnValue(False)
-        defer.returnValue(True)
+            return False
+        return True
 
     @staticmethod
     def _simple_insert_txn(txn, table, values):
@@ -694,7 +708,7 @@ class SQLBaseStore(object):
                     insertion_values,
                     lock=lock,
                 )
-                defer.returnValue(result)
+                return result
             except self.database_engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -1107,7 +1121,7 @@ class SQLBaseStore(object):
         results = []
 
         if not iterable:
-            defer.returnValue(results)
+            return results
 
         # iterables can not be sliced, so convert it to a list first
         it_list = list(iterable)
@@ -1128,7 +1142,7 @@ class SQLBaseStore(object):
 
             results.extend(rows)
 
-        defer.returnValue(results)
+        return results
 
     @classmethod
     def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
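
With `after_callbacks` and `exception_callbacks` now defaulting to None, call
sites that only want query logging can construct a LoggingTransaction without
passing empty lists; the event_push_actions.py hunk below drops exactly those
arguments. A minimal sketch, assuming `db_conn` and `self.database_engine` as
in that hunk (this block is not part of the patch itself):

    cur = LoggingTransaction(
        db_conn.cursor(),
        name="_find_stream_orderings_for_times_txn",
        database_engine=self.database_engine,
        # after_callbacks/exception_callbacks stay None, so call_after() and
        # call_on_exception() must not be used on this transaction.
    )
    self._find_stream_orderings_for_times_txn(cur)
    cur.close()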
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 8394389073..9fa5b4f3d6 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -111,9 +111,9 @@ class AccountDataWorkerStore(SQLBaseStore):
         )
 
         if result:
-            defer.returnValue(json.loads(result))
+            return json.loads(result)
         else:
-            defer.returnValue(None)
+            return None
 
     @cached(num_args=2)
     def get_account_data_for_room(self, user_id, room_id):
@@ -264,11 +264,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             on_invalidate=cache_context.invalidate,
         )
         if not ignored_account_data:
-            defer.returnValue(False)
+            return False
 
-        defer.returnValue(
-            ignored_user_id in ignored_account_data.get("ignored_users", {})
-        )
+        return ignored_user_id in ignored_account_data.get("ignored_users", {})
 
 
 class AccountDataStore(AccountDataWorkerStore):
@@ -332,7 +330,7 @@ class AccountDataStore(AccountDataWorkerStore):
             )
 
         result = self._account_data_id_gen.get_current_token()
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def add_account_data_for_user(self, user_id, account_data_type, content):
@@ -373,7 +371,7 @@ class AccountDataStore(AccountDataWorkerStore):
             )
 
         result = self._account_data_id_gen.get_current_token()
-        defer.returnValue(result)
+        return result
 
     def _update_max_stream_id(self, next_id):
         """Update the max stream_id
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index eb329ebd8b..05d9c05c3f 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore(
             for service in as_list:
                 if service.id == res["as_id"]:
                     services.append(service)
-        defer.returnValue(services)
+        return services
 
     @defer.inlineCallbacks
     def get_appservice_state(self, service):
@@ -164,9 +164,9 @@ class ApplicationServiceTransactionWorkerStore(
             desc="get_appservice_state",
         )
         if result:
-            defer.returnValue(result.get("state"))
+            return result.get("state")
             return
-        defer.returnValue(None)
+        return None
 
     def set_appservice_state(self, service, state):
         """Set the application service state.
@@ -298,15 +298,13 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
         if not entry:
-            defer.returnValue(None)
+            return None
 
         event_ids = json.loads(entry["event_ids"])
 
         events = yield self.get_events_as_list(event_ids)
 
-        defer.returnValue(
-            AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
-        )
+        return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
 
     def _get_last_txn(self, txn, service_id):
         txn.execute(
@@ -360,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
 
         events = yield self.get_events_as_list(event_ids)
 
-        defer.returnValue((upper_bound, events))
+        return (upper_bound, events)
 
 
 class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 50f913a414..e5f0668f09 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -115,7 +115,7 @@ class BackgroundUpdateStore(SQLBaseStore):
                         " Unscheduling background update task."
                     )
                     self._all_done = True
-                    defer.returnValue(None)
+                    return None
 
     @defer.inlineCallbacks
     def has_completed_background_updates(self):
@@ -127,11 +127,11 @@ class BackgroundUpdateStore(SQLBaseStore):
         # if we've previously determined that there is nothing left to do, that
         # is easy
         if self._all_done:
-            defer.returnValue(True)
+            return True
 
         # obviously, if we have things in our queue, we're not done.
         if self._background_update_queue:
-            defer.returnValue(False)
+            return False
 
         # otherwise, check if there are updates to be run. This is important,
         # as we may be running on a worker which doesn't perform the bg updates
@@ -144,9 +144,9 @@ class BackgroundUpdateStore(SQLBaseStore):
         )
         if not updates:
             self._all_done = True
-            defer.returnValue(True)
+            return True
 
-        defer.returnValue(False)
+        return False
 
     @defer.inlineCallbacks
     def do_next_background_update(self, desired_duration_ms):
@@ -173,14 +173,14 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         if not self._background_update_queue:
             # no work left to do
-            defer.returnValue(None)
+            return None
 
         # pop from the front, and add back to the back
         update_name = self._background_update_queue.pop(0)
         self._background_update_queue.append(update_name)
 
         res = yield self._do_background_update(update_name, desired_duration_ms)
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def _do_background_update(self, update_name, desired_duration_ms):
@@ -231,7 +231,7 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         performance.update(items_updated, duration_ms)
 
-        defer.returnValue(len(self._background_update_performance))
+        return len(self._background_update_performance)
 
     def register_background_update_handler(self, update_name, update_handler):
         """Register a handler for doing a background update.
@@ -266,7 +266,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         @defer.inlineCallbacks
         def noop_update(progress, batch_size):
             yield self._end_background_update(update_name)
-            defer.returnValue(1)
+            return 1
 
         self.register_background_update_handler(update_name, noop_update)
 
@@ -370,7 +370,7 @@ class BackgroundUpdateStore(SQLBaseStore):
                 logger.info("Adding index %s to %s", index_name, table)
                 yield self.runWithConnection(runner)
             yield self._end_background_update(update_name)
-            defer.returnValue(1)
+            return 1
 
         self.register_background_update_handler(update_name, updater)
 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index bda68de5be..6db8c54077 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -104,7 +104,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         yield self.runWithConnection(f)
         yield self._end_background_update("user_ips_drop_nonunique_index")
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _analyze_user_ip(self, progress, batch_size):
@@ -121,7 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         yield self._end_background_update("user_ips_analyze")
 
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _remove_user_ip_dupes(self, progress, batch_size):
@@ -291,7 +291,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         if last:
             yield self._end_background_update("user_ips_remove_dupes")
 
-        defer.returnValue(batch_size)
+        return batch_size
 
     @defer.inlineCallbacks
     def insert_client_ip(
@@ -401,7 +401,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                         "device_id": did,
                         "last_seen": last_seen,
                     }
-        defer.returnValue(ret)
+        return ret
 
     @classmethod
     def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
@@ -461,14 +461,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
             for row in rows
         )
-        defer.returnValue(
-            list(
-                {
-                    "access_token": access_token,
-                    "ip": ip,
-                    "user_agent": user_agent,
-                    "last_seen": last_seen,
-                }
-                for (access_token, ip), (user_agent, last_seen) in iteritems(results)
-            )
+        return list(
+            {
+                "access_token": access_token,
+                "ip": ip,
+                "user_agent": user_agent,
+                "last_seen": last_seen,
+            }
+            for (access_token, ip), (user_agent, last_seen) in iteritems(results)
         )
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 4ea0deea4f..79bb0ea46d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -92,7 +92,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 user_id, last_deleted_stream_id
             )
             if not has_changed:
-                defer.returnValue(0)
+                return 0
 
         def delete_messages_for_device_txn(txn):
             sql = (
@@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             last_deleted_stream_id, up_to_stream_id
         )
 
-        defer.returnValue(count)
+        return count
 
     def get_new_device_msgs_for_remote(
         self, destination, last_stream_id, current_stream_id, limit
@@ -263,7 +263,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
                     destination, stream_id
                 )
 
-        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+        return self._device_inbox_id_gen.get_current_token()
 
     @defer.inlineCallbacks
     def add_messages_from_remote_to_device_inbox(
@@ -312,7 +312,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
             for user_id in local_messages_by_user_then_device.keys():
                 self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
 
-        defer.returnValue(stream_id)
+        return stream_id
 
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
@@ -426,4 +426,4 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
 
         yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
 
-        defer.returnValue(1)
+        return 1
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d2b113a4e7..8f72d92895 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -71,7 +71,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="get_devices_by_user",
         )
 
-        defer.returnValue({d["device_id"]: d for d in devices})
+        return {d["device_id"]: d for d in devices}
 
     @defer.inlineCallbacks
     def get_devices_by_remote(self, destination, from_stream_id, limit):
@@ -88,7 +88,7 @@ class DeviceWorkerStore(SQLBaseStore):
             destination, int(from_stream_id)
         )
         if not has_changed:
-            defer.returnValue((now_stream_id, []))
+            return (now_stream_id, [])
 
         # We retrieve n+1 devices from the list of outbound pokes where n is
         # our outbound device update limit. We then check if the very last
@@ -111,7 +111,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         # Return an empty list if there are no updates
         if not updates:
-            defer.returnValue((now_stream_id, []))
+            return (now_stream_id, [])
 
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
@@ -147,13 +147,13 @@ class DeviceWorkerStore(SQLBaseStore):
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
         if not query_map:
-            defer.returnValue((stream_id_cutoff, []))
+            return (stream_id_cutoff, [])
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
-        defer.returnValue((now_stream_id, results))
+        return (now_stream_id, results)
 
     def _get_devices_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
@@ -232,7 +232,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
                 results.append(result)
 
-        defer.returnValue(results)
+        return results
 
     def _get_last_device_update_for_remote_user(
         self, destination, user_id, from_stream_id
@@ -330,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore):
             else:
                 results[user_id] = yield self._get_cached_devices_for_user(user_id)
 
-        defer.returnValue((user_ids_not_in_cache, results))
+        return (user_ids_not_in_cache, results)
 
     @cachedInlineCallbacks(num_args=2, tree=True)
     def _get_cached_user_device(self, user_id, device_id):
@@ -340,7 +340,7 @@ class DeviceWorkerStore(SQLBaseStore):
             retcol="content",
             desc="_get_cached_user_device",
         )
-        defer.returnValue(db_to_json(content))
+        return db_to_json(content)
 
     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
@@ -350,9 +350,9 @@ class DeviceWorkerStore(SQLBaseStore):
             retcols=("device_id", "content"),
             desc="_get_cached_devices_for_user",
         )
-        defer.returnValue(
-            {device["device_id"]: db_to_json(device["content"]) for device in devices}
-        )
+        return {
+            device["device_id"]: db_to_json(device["content"]) for device in devices
+        }
 
     def get_devices_with_keys_by_user(self, user_id):
         """Get all devices (with any device keys) for a user
@@ -482,7 +482,7 @@ class DeviceWorkerStore(SQLBaseStore):
         results = {user_id: None for user_id in user_ids}
         results.update({row["user_id"]: row["stream_id"] for row in rows})
 
-        defer.returnValue(results)
+        return results
 
 
 class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
@@ -543,7 +543,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         """
         key = (user_id, device_id)
         if self.device_id_exists_cache.get(key, None):
-            defer.returnValue(False)
+            return False
 
         try:
             inserted = yield self._simple_insert(
@@ -557,7 +557,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
                 or_ignore=True,
             )
             self.device_id_exists_cache.prefill(key, True)
-            defer.returnValue(inserted)
+            return inserted
         except Exception as e:
             logger.error(
                 "store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -780,7 +780,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
                 hosts,
                 stream_id,
             )
-        defer.returnValue(stream_id)
+        return stream_id
 
     def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
         now = self._clock.time_msec()
@@ -889,4 +889,4 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
 
         yield self.runWithConnection(f)
         yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
-        defer.returnValue(1)
+        return 1
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 201bbd430c..e966a73f3d 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -46,7 +46,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         )
 
         if not room_id:
-            defer.returnValue(None)
+            return None
             return
 
         servers = yield self._simple_select_onecol(
@@ -57,10 +57,10 @@ class DirectoryWorkerStore(SQLBaseStore):
         )
 
         if not servers:
-            defer.returnValue(None)
+            return None
             return
 
-        defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
+        return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
     def get_room_alias_creator(self, room_alias):
         return self._simple_select_one_onecol(
@@ -125,7 +125,7 @@ class DirectoryStore(DirectoryWorkerStore):
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
             )
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
@@ -133,7 +133,7 @@ class DirectoryStore(DirectoryWorkerStore):
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
 
-        defer.returnValue(room_id)
+        return room_id
 
     def _delete_room_alias_txn(self, txn, room_alias):
         txn.execute(
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index f40ef2ab64..99128f2df7 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -61,7 +61,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         row["session_data"] = json.loads(row["session_data"])
 
-        defer.returnValue(row)
+        return row
 
     @defer.inlineCallbacks
     def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
@@ -118,7 +118,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         try:
             version = int(version)
         except ValueError:
-            defer.returnValue({"rooms": {}})
+            return {"rooms": {}}
 
         keyvalues = {"user_id": user_id, "version": version}
         if room_id:
@@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 "session_data": json.loads(row["session_data"]),
             }
 
-        defer.returnValue(sessions)
+        return sessions
 
     @defer.inlineCallbacks
     def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2fabb9e2cb..1e07474e70 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -41,7 +41,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             dict containing "key_json", "device_display_name".
         """
         if not query_list:
-            defer.returnValue({})
+            return {}
 
         results = yield self.runInteraction(
             "get_e2e_device_keys",
@@ -55,7 +55,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             for device_id, device_info in iteritems(device_keys):
                 device_info["keys"] = db_to_json(device_info.pop("key_json"))
 
-        defer.returnValue(results)
+        return results
 
     def _get_e2e_device_keys_txn(
         self, txn, query_list, include_all_devices=False, include_deleted_devices=False
@@ -130,9 +130,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             desc="add_e2e_one_time_keys_check",
         )
 
-        defer.returnValue(
-            {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
-        )
+        return {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
 
     @defer.inlineCallbacks
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index cb4478342f..4f500d893e 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -131,9 +131,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
         if not rows:
-            defer.returnValue(0)
+            return 0
         else:
-            defer.returnValue(max(row["depth"] for row in rows))
+            return max(row["depth"] for row in rows)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
@@ -169,7 +169,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             # make sure that we don't completely ignore the older events.
             res = res[0:5] + random.sample(res[5:], 5)
 
-        defer.returnValue(res)
+        return res
 
     def get_latest_event_ids_and_hashes_in_room(self, room_id):
         """
@@ -411,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             limit,
         )
         events = yield self.get_events_as_list(ids)
-        defer.returnValue(events)
+        return events
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
@@ -463,7 +463,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             desc="get_successor_events",
         )
 
-        defer.returnValue([row["event_id"] for row in rows])
+        return [row["event_id"] for row in rows]
 
 
 class EventFederationStore(EventFederationWorkerStore):
@@ -654,4 +654,4 @@ class EventFederationStore(EventFederationWorkerStore):
         if not result:
             yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
 
-        defer.returnValue(batch_size)
+        return batch_size
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index eca77069fd..22025effbc 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -79,8 +79,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
             database_engine=self.database_engine,
-            after_callbacks=[],
-            exception_callbacks=[],
         )
         self._find_stream_orderings_for_times_txn(cur)
         cur.close()
@@ -102,7 +100,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             user_id,
             last_read_event_id,
         )
-        defer.returnValue(ret)
+        return ret
 
     def _get_unread_counts_by_receipt_txn(
         self, txn, room_id, user_id, last_read_event_id
@@ -180,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             return [r[0] for r in txn]
 
         ret = yield self.runInteraction("get_push_action_users_in_range", f)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def get_unread_push_actions_for_user_in_range_for_http(
@@ -281,7 +279,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # Take only up to the limit. We have to stop at the limit because
         # one of the subqueries may have hit the limit.
-        defer.returnValue(notifs[:limit])
+        return notifs[:limit]
 
     @defer.inlineCallbacks
     def get_unread_push_actions_for_user_in_range_for_email(
@@ -382,7 +380,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         notifs.sort(key=lambda r: -(r["received_ts"] or 0))
 
         # Now return the first `limit`
-        defer.returnValue(notifs[:limit])
+        return notifs[:limit]
 
     def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
         """A fast check to see if there might be something to push for the
@@ -479,7 +477,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
             )
-            defer.returnValue(res)
+            return res
         except Exception:
             # this method is called from an exception handler, so propagating
             # another exception here really isn't helpful - there's nothing
@@ -734,7 +732,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         push_actions = yield self.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        defer.returnValue(push_actions)
+        return push_actions
 
     @defer.inlineCallbacks
     def get_time_of_last_push_action_before(self, stream_ordering):
@@ -751,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             return txn.fetchone()
 
         result = yield self.runInteraction("get_time_of_last_push_action_before", f)
-        defer.returnValue(result[0] if result else None)
+        return result[0] if result else None
 
     @defer.inlineCallbacks
     def get_latest_push_action_stream_ordering(self):
@@ -760,7 +758,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             return txn.fetchone()
 
         result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
-        defer.returnValue(result[0] or 0)
+        return result[0] or 0
 
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
         # Sad that we have to blow away the cache for the whole room here
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b486ca50eb..ac876287fc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -223,7 +223,7 @@ def _retry_on_integrity_error(func):
         except self.database_engine.module.IntegrityError:
             logger.exception("IntegrityError, retrying.")
             res = yield func(self, *args, delete_existing=True, **kwargs)
-        defer.returnValue(res)
+        return res
 
     return f
 
@@ -309,7 +309,7 @@ class EventsStore(
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
 
-        defer.returnValue(max_persisted_id)
+        return max_persisted_id
 
     @defer.inlineCallbacks
     @log_function
@@ -334,7 +334,7 @@ class EventsStore(
         yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
-        defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
+        return (event.internal_metadata.stream_ordering, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
@@ -364,147 +364,161 @@ class EventsStore(
         if not events_and_contexts:
             return
 
-        if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-        else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-
-        with stream_ordering_manager as stream_orderings:
-            for (event, context), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
-
-            chunks = [
-                events_and_contexts[x : x + 100]
-                for x in range(0, len(events_and_contexts), 100)
-            ]
-
-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
+        chunks = [
+            events_and_contexts[x : x + 100]
+            for x in range(0, len(events_and_contexts), 100)
+        ]
 
-                # NB: Assumes that we are only persisting events for one room
-                # at a time.
+        for chunk in chunks:
+            # We can't easily parallelize these since different chunks
+            # might contain the same event. :(
 
-                # map room_id->list[event_ids] giving the new forward
-                # extremities in each room
-                new_forward_extremeties = {}
+            # NB: Assumes that we are only persisting events for one room
+            # at a time.
 
-                # map room_id->(type,state_key)->event_id tracking the full
-                # state in each room after adding these events.
-                # This is simply used to prefill the get_current_state_ids
-                # cache
-                current_state_for_room = {}
+            # map room_id->list[event_ids] giving the new forward
+            # extremities in each room
+            new_forward_extremeties = {}
 
-                # map room_id->(to_delete, to_insert) where to_delete is a list
-                # of type/state keys to remove from current state, and to_insert
-                # is a map (type,key)->event_id giving the state delta in each
-                # room
-                state_delta_for_room = {}
+            # map room_id->(type,state_key)->event_id tracking the full
+            # state in each room after adding these events.
+            # This is simply used to prefill the get_current_state_ids
+            # cache
+            current_state_for_room = {}
 
-                if not backfilled:
-                    with Measure(self._clock, "_calculate_state_and_extrem"):
-                        # Work out the new "current state" for each room.
-                        # We do this by working out what the new extremities are and then
-                        # calculating the state from that.
-                        events_by_room = {}
-                        for event, context in chunk:
-                            events_by_room.setdefault(event.room_id, []).append(
-                                (event, context)
-                            )
+            # map room_id->(to_delete, to_insert) where to_delete is a list
+            # of type/state keys to remove from current state, and to_insert
+            # is a map (type,key)->event_id giving the state delta in each
+            # room
+            state_delta_for_room = {}
 
-                        for room_id, ev_ctx_rm in iteritems(events_by_room):
-                            latest_event_ids = yield self.get_latest_event_ids_in_room(
-                                room_id
-                            )
-                            new_latest_event_ids = yield self._calculate_new_extremities(
-                                room_id, ev_ctx_rm, latest_event_ids
+            if not backfilled:
+                with Measure(self._clock, "_calculate_state_and_extrem"):
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in iteritems(events_by_room):
+                        latest_event_ids = yield self.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = yield self._calculate_new_extremities(
+                            room_id, ev_ctx_rm, latest_event_ids
+                        )
+
+                        latest_event_ids = set(latest_event_ids)
+                        if new_latest_event_ids == latest_event_ids:
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # there should always be at least one forward extremity.
+                        # (except during the initial persistence of the send_join
+                        # results, in which case there will be no existing
+                        # extremities, so we'll `continue` above and skip this bit.)
+                        assert new_latest_event_ids, "No forward extremities left!"
+
+                        new_forward_extremeties[room_id] = new_latest_event_ids
+
+                        len_1 = (
+                            len(latest_event_ids) == 1
+                            and len(new_latest_event_ids) == 1
+                        )
+                        if len_1:
+                            all_single_prev_not_state = all(
+                                len(event.prev_event_ids()) == 1
+                                and not event.is_state()
+                                for event, ctx in ev_ctx_rm
                             )
-
-                            latest_event_ids = set(latest_event_ids)
-                            if new_latest_event_ids == latest_event_ids:
-                                # No change in extremities, so no change in state
+                            # Don't bother calculating state if they're just
+                            # a long chain of single ancestor non-state events.
+                            if all_single_prev_not_state:
                                 continue
 
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremeties[room_id] = new_latest_event_ids
-
-                            len_1 = (
-                                len(latest_event_ids) == 1
-                                and len(new_latest_event_ids) == 1
+                        state_delta_counter.inc()
+                        if len(new_latest_event_ids) == 1:
+                            state_delta_single_event_counter.inc()
+
+                            # This is a fairly handwavey check to see if we could
+                            # have guessed what the delta would have been when
+                            # processing one of these events.
+                            # What we're interested in is if the latest extremities
+                            # were the same when we created the event as they are
+                            # now. When this server creates a new event (as opposed
+                            # to receiving it over federation) it will use the
+                            # forward extremities as the prev_events, so we can
+                            # guess this by looking at the prev_events and checking
+                            # if they match the current forward extremities.
+                            for ev, _ in ev_ctx_rm:
+                                prev_event_ids = set(ev.prev_event_ids())
+                                if latest_event_ids == prev_event_ids:
+                                    state_delta_reuse_delta_counter.inc()
+                                    break
+
+                        logger.info("Calculating state delta for room %s", room_id)
+                        with Measure(
+                            self._clock, "persist_events.get_new_state_after_events"
+                        ):
+                            res = yield self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
                             )
-                            if len_1:
-                                all_single_prev_not_state = all(
-                                    len(event.prev_event_ids()) == 1
-                                    and not event.is_state()
-                                    for event, ctx in ev_ctx_rm
-                                )
-                                # Don't bother calculating state if they're just
-                                # a long chain of single ancestor non-state events.
-                                if all_single_prev_not_state:
-                                    continue
-
-                            state_delta_counter.inc()
-                            if len(new_latest_event_ids) == 1:
-                                state_delta_single_event_counter.inc()
-
-                                # This is a fairly handwavey check to see if we could
-                                # have guessed what the delta would have been when
-                                # processing one of these events.
-                                # What we're interested in is if the latest extremities
-                                # were the same when we created the event as they are
-                                # now. When this server creates a new event (as opposed
-                                # to receiving it over federation) it will use the
-                                # forward extremities as the prev_events, so we can
-                                # guess this by looking at the prev_events and checking
-                                # if they match the current forward extremities.
-                                for ev, _ in ev_ctx_rm:
-                                    prev_event_ids = set(ev.prev_event_ids())
-                                    if latest_event_ids == prev_event_ids:
-                                        state_delta_reuse_delta_counter.inc()
-                                        break
-
-                            logger.info("Calculating state delta for room %s", room_id)
+                            current_state, delta_ids = res
+
+                        # If either are not None then there has been a change,
+                        # and we need to work out the delta (or use that
+                        # given)
+                        if delta_ids is not None:
+                            # If there is a delta we know that we've
+                            # only added or replaced state, never
+                            # removed keys entirely.
+                            state_delta_for_room[room_id] = ([], delta_ids)
+                        elif current_state is not None:
                             with Measure(
-                                self._clock, "persist_events.get_new_state_after_events"
+                                self._clock, "persist_events.calculate_state_delta"
                             ):
-                                res = yield self._get_new_state_after_events(
-                                    room_id,
-                                    ev_ctx_rm,
-                                    latest_event_ids,
-                                    new_latest_event_ids,
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state
                                 )
-                                current_state, delta_ids = res
-
-                            # If either are not None then there has been a change,
-                            # and we need to work out the delta (or use that
-                            # given)
-                            if delta_ids is not None:
-                                # If there is a delta we know that we've
-                                # only added or replaced state, never
-                                # removed keys entirely.
-                                state_delta_for_room[room_id] = ([], delta_ids)
-                            elif current_state is not None:
-                                with Measure(
-                                    self._clock, "persist_events.calculate_state_delta"
-                                ):
-                                    delta = yield self._calculate_state_delta(
-                                        room_id, current_state
-                                    )
-                                state_delta_for_room[room_id] = delta
-
-                            # If we have the current_state then lets prefill
-                            # the cache with it.
-                            if current_state is not None:
-                                current_state_for_room[room_id] = current_state
+                            state_delta_for_room[room_id] = delta
+
+                        # If we have the current_state then lets prefill
+                        # the cache with it.
+                        if current_state is not None:
+                            current_state_for_room[room_id] = current_state
+
+            # We want to calculate the stream orderings as late as possible, as
+            # we only notify after all events with a lesser stream ordering have
+            # been persisted. I.e. if we spend 10s inside the with block then
+            # that will delay all subsequent events from being notified about.
+            # Hence why we do it down here rather than wrapping the entire
+            # function.
+            #
+            # Its safe to do this after calculating the state deltas etc as we
+            # only need to protect the *persistence* of the events. This is to
+            # ensure that queries of the form "fetch events since X" don't
+            # return events and stream positions after events that are still in
+            # flight, as otherwise subsequent requests "fetch event since Y"
+            # will not return those events.
+            #
+            # Note: Multiple instances of this function cannot be in flight at
+            # the same time for the same room.
+            if backfilled:
+                stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                    len(chunk)
+                )
+            else:
+                stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
+
+            with stream_ordering_manager as stream_orderings:
+                for (event, context), stream in zip(chunk, stream_orderings):
+                    event.internal_metadata.stream_ordering = stream
 
                 yield self.runInteraction(
                     "persist_events",
@@ -595,7 +609,7 @@ class EventsStore(
             stale = latest_event_ids & result
             stale_forward_extremities_counter.observe(len(stale))
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def _get_events_which_are_prevs(self, event_ids):
@@ -633,7 +647,7 @@ class EventsStore(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
-        defer.returnValue(results)
+        return results
 
     @defer.inlineCallbacks
     def _get_prevs_before_rejected(self, event_ids):
@@ -695,7 +709,7 @@ class EventsStore(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
-        defer.returnValue(existing_prevs)
+        return existing_prevs
 
     @defer.inlineCallbacks
     def _get_new_state_after_events(
@@ -796,7 +810,7 @@ class EventsStore(
         # If they old and new groups are the same then we don't need to do
         # anything.
         if old_state_groups == new_state_groups:
-            defer.returnValue((None, None))
+            return (None, None)
 
         if len(new_state_groups) == 1 and len(old_state_groups) == 1:
             # If we're going from one state group to another, lets check if
@@ -813,7 +827,7 @@ class EventsStore(
                 # the current state in memory then lets also return that,
                 # but it doesn't matter if we don't.
                 new_state = state_groups_map.get(new_state_group)
-                defer.returnValue((new_state, delta_ids))
+                return (new_state, delta_ids)
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
@@ -825,7 +839,7 @@ class EventsStore(
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+            return (state_groups_map[new_state_groups.pop()], None)
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
@@ -854,7 +868,7 @@ class EventsStore(
             state_res_store=StateResolutionStore(self),
         )
 
-        defer.returnValue((res.state, None))
+        return (res.state, None)
 
     @defer.inlineCallbacks
     def _calculate_state_delta(self, room_id, current_state):
@@ -877,7 +891,7 @@ class EventsStore(
             if ev_id != existing_state.get(key)
         }
 
-        defer.returnValue((to_delete, to_insert))
+        return (to_delete, to_insert)
 
     @log_function
     def _persist_events_txn(
@@ -918,8 +932,6 @@ class EventsStore(
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremeties,
@@ -993,6 +1005,10 @@ class EventsStore(
             backfilled=backfilled,
         )
 
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
     def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
         for room_id, current_state_tuple in iteritems(state_delta_by_room):
             to_delete, to_insert = current_state_tuple
@@ -1062,16 +1078,16 @@ class EventsStore(
                 ),
             )
 
-            self._simple_insert_many_txn(
-                txn,
-                table="current_state_events",
-                values=[
-                    {
-                        "event_id": ev_id,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                    }
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.executemany(
+                """INSERT INTO current_state_events
+                    (room_id, type, state_key, event_id, membership)
+                VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                """,
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id)
                     for key, ev_id in iteritems(to_insert)
                 ],
             )
@@ -1562,7 +1578,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_messages", _count_messages)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def count_daily_sent_messages(self):
@@ -1583,7 +1599,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def count_daily_active_rooms(self):
@@ -1598,7 +1614,7 @@ class EventsStore(
             return count
 
         ret = yield self.runInteraction("count_daily_active_rooms", _count)
-        defer.returnValue(ret)
+        return ret
 
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
@@ -2181,7 +2197,7 @@ class EventsStore(
         """
         to_1, so_1 = yield self._get_event_ordering(event_id1)
         to_2, so_2 = yield self._get_event_ordering(event_id2)
-        defer.returnValue((to_1, so_1) > (to_2, so_2))
+        return (to_1, so_1) > (to_2, so_2)
 
     @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
@@ -2195,9 +2211,7 @@ class EventsStore(
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        defer.returnValue(
-            (int(res["topological_ordering"]), int(res["stream_ordering"]))
-        )
+        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
     def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
         def get_all_updated_current_state_deltas_txn(txn):
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
index 1ce21d190c..6587f31e2b 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/events_bg_updates.py
@@ -135,7 +135,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
         if not result:
             yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def _background_reindex_origin_server_ts(self, progress, batch_size):
@@ -212,7 +212,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
         if not result:
             yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def _cleanup_extremities_bg_update(self, progress, batch_size):
@@ -396,4 +396,4 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
-        defer.returnValue(num_handled)
+        return num_handled
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 06379281b6..c6fa7f82fd 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
-from synapse.logging.context import (
-    LoggingContext,
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    run_in_background,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
@@ -139,8 +134,11 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred : A FrozenEvent.
+            Deferred[EventBase|None]
         """
+        if not isinstance(event_id, str):
+            raise TypeError("Invalid event event_id %r" % (event_id,))
+
         events = yield self.get_events_as_list(
             [event_id],
             check_redacted=check_redacted,
@@ -157,7 +155,7 @@ class EventsWorkerStore(SQLBaseStore):
         if event is None and not allow_none:
             raise NotFoundError("Could not find event %s" % (event_id,))
 
-        defer.returnValue(event)
+        return event
 
     @defer.inlineCallbacks
     def get_events(
@@ -187,7 +185,7 @@ class EventsWorkerStore(SQLBaseStore):
             allow_rejected=allow_rejected,
         )
 
-        defer.returnValue({e.event_id: e for e in events})
+        return {e.event_id: e for e in events}
 
     @defer.inlineCallbacks
     def get_events_as_list(
@@ -217,7 +215,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
 
         if not event_ids:
-            defer.returnValue([])
+            return []
 
         # there may be duplicates so we cast the list to a set
         event_entry_map = yield self._get_events_from_cache_or_db(
@@ -313,7 +311,7 @@ class EventsWorkerStore(SQLBaseStore):
                         event.unsigned["prev_content"] = prev.content
                         event.unsigned["prev_sender"] = prev.sender
 
-        defer.returnValue(events)
+        return events
 
     @defer.inlineCallbacks
     def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
@@ -339,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore):
             log_ctx = LoggingContext.current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
-            # Note that _enqueue_events is also responsible for turning db rows
+            # Note that _get_events_from_db is also responsible for turning db rows
             # into FrozenEvents (via _get_event_from_row), which involves seeing if
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            # _enqueue_events is a bit of a rubbish name but naming is hard.
-            missing_events = yield self._enqueue_events(
+            missing_events = yield self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -418,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore):
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
+
+                The deferreds are callbacked with a dictionary mapping from event id
+                to event row. Note that it may well contain additional events that
+                were not part of this request.
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                event_id_lists = list(zip(*event_list))[0]
-                event_ids = [item for sublist in event_id_lists for item in sublist]
+                events_to_fetch = set(
+                    event_id for events, _ in event_list for event_id in events
+                )
 
                 row_dict = self._new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
+                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([res[i] for i in ids if i in res])
-                            except Exception:
-                                logger.exception("Failed to callback")
+                def fire():
+                    for _, d in event_list:
+                        d.callback(row_dict)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+                    self.hs.get_reactor().callFromThread(fire)
             except Exception as e:
                 logger.exception("do_fetch")
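(Aside: a minimal sketch, with hypothetical names, of the queue-and-fire batching documented above: waiters enqueue (event_ids, Deferred) pairs, one database fetch resolves every waiter with a shared row map, and each waiter filters out the ids it actually asked for.)

    from twisted.internet import defer

    _event_fetch_list = []  # pairs of (event_ids, Deferred)

    def queue_fetch(event_ids):
        # each waiter gets a deferred that later fires with the shared row map
        d = defer.Deferred()
        _event_fetch_list.append((event_ids, d))
        return d

    def fire_fetches(load_rows):
        # load_rows(ids) is assumed to return {event_id: row} for the union of ids
        batch, _event_fetch_list[:] = list(_event_fetch_list), []
        ids = set(eid for eids, _ in batch for eid in eids)
        row_map = load_rows(ids)
        for _, d in batch:
            d.callback(row_map)  # may include events this waiter did not request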
 
@@ -454,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore):
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
     @defer.inlineCallbacks
-    def _enqueue_events(self, events, allow_rejected=False):
+    def _get_events_from_db(self, event_ids, allow_rejected=False):
+        """Fetch a bunch of events from the database.
+
+        Returned events will be added to the cache for future lookups.
+
+        Args:
+            event_ids (Iterable[str]): The event_ids of the events to fetch
+            allow_rejected (bool): Whether to include rejected events
+
+        Returns:
+            Deferred[Dict[str, _EventCacheEntry]]:
+                map from event id to result. May return extra events which
+                weren't asked for.
+        """
+        fetched_events = {}
+        events_to_fetch = event_ids
+
+        while events_to_fetch:
+            row_map = yield self._enqueue_events(events_to_fetch)
+
+            # we need to recursively fetch any redactions of those events
+            redaction_ids = set()
+            for event_id in events_to_fetch:
+                row = row_map.get(event_id)
+                fetched_events[event_id] = row
+                if row:
+                    redaction_ids.update(row["redactions"])
+
+            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            if events_to_fetch:
+                logger.debug("Also fetching redaction events %s", events_to_fetch)
+
+        # build a map from event_id to EventBase
+        event_map = {}
+        for event_id, row in fetched_events.items():
+            if not row:
+                continue
+            assert row["event_id"] == event_id
+
+            rejected_reason = row["rejected_reason"]
+
+            if not allow_rejected and rejected_reason:
+                continue
+
+            d = json.loads(row["json"])
+            internal_metadata = json.loads(row["internal_metadata"])
+
+            format_version = row["format_version"]
+            if format_version is None:
+                # This means that we stored the event before we had the concept
+            # of an event format version, so it must be a V1 event.
+                format_version = EventFormatVersions.V1
+
+            original_ev = event_type_from_format_version(format_version)(
+                event_dict=d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            event_map[event_id] = original_ev
+
+        # finally, we can decide whether each one needs redacting, and build
+        # the cache entries.
+        result_map = {}
+        for event_id, original_ev in event_map.items():
+            redactions = fetched_events[event_id]["redactions"]
+            redacted_event = self._maybe_redact_event_row(
+                original_ev, redactions, event_map
+            )
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev, redacted_event=redacted_event
+            )
+
+            self._get_event_cache.prefill((event_id,), cache_entry)
+            result_map[event_id] = cache_entry
+
+        return result_map
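(Aside: the while loop above is a small fixed-point iteration; a standalone sketch of the same shape, with a hypothetical fetch_rows function and plain dicts standing in for database rows.)

    def fetch_with_redactions(event_ids, fetch_rows):
        # fetch_rows(ids) -> {event_id: {"redactions": [...], ...}}; missing ids map to None
        fetched = {}
        to_fetch = set(event_ids)
        while to_fetch:
            rows = fetch_rows(to_fetch)
            redaction_ids = set()
            for event_id in to_fetch:
                row = rows.get(event_id)
                fetched[event_id] = row
                if row:
                    redaction_ids.update(row["redactions"])
            # only chase redaction events we haven't already pulled back
            to_fetch = redaction_ids.difference(fetched.keys())
        return fetched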
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
+
+        Args:
+            events (Iterable[str]): events to be fetched.
+
+        Returns:
+            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+                May contain events that weren't requested.
         """
-        if not events:
-            defer.returnValue({})
 
         events_d = defer.Deferred()
         with self._event_fetch_lock:
@@ -479,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore):
                 "fetch_events", self.runWithConnection, self._do_fetch
             )
 
-        logger.debug("Loading %d events", len(events))
+        logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if r["rejected_reason"] is None]
-
-        res = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self._get_event_from_row,
-                        row["internal_metadata"],
-                        row["json"],
-                        row["redactions"],
-                        rejected_reason=row["rejected_reason"],
-                        format_version=row["format_version"],
-                    )
-                    for row in rows
-                ],
-                consumeErrors=True,
-            )
-        )
+            row_map = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
-        defer.returnValue({e.event.event_id: e for e in res if e})
+        return row_map
 
     def _fetch_event_rows(self, txn, event_ids):
         """Fetch event rows from the database
@@ -577,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    @defer.inlineCallbacks
-    def _get_event_from_row(
-        self, internal_metadata, js, redactions, format_version, rejected_reason=None
-    ):
-        """Parse an event row which has been read from the database
-
-        Args:
-            internal_metadata (str): json-encoded internal_metadata column
-            js (str): json-encoded event body from event_json
-            redactions (list[str]): a list of the events which claim to have redacted
-                this event, from the redactions table
-            format_version: (str): the 'format_version' column
-            rejected_reason (str|None): the reason this event was rejected, if any
-
-        Returns:
-            _EventCacheEntry
-        """
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if format_version is None:
-                # This means that we stored the event before we had the concept
-                # of a event format version, so it must be a V1 event.
-                format_version = EventFormatVersions.V1
-
-            original_ev = event_type_from_format_version(format_version)(
-                event_dict=d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev, redacted_event=redacted_event
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        defer.returnValue(cache_entry)
-
-    @defer.inlineCallbacks
-    def _maybe_redact_event_row(self, original_ev, redactions):
+    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
@@ -628,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore):
         Args:
              original_ev (EventBase):
              redactions (iterable[str]): list of event ids of potential redaction events
+             event_map (dict[str, EventBase]): other events which have been fetched, in
+                 which we can look up the redaction events. Map from event id to event.
 
         Returns:
             Deferred[EventBase|None]: if the event should be redacted, a pruned
@@ -637,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore):
             # we choose to ignore redactions of m.room.create events.
             return None
 
-        if original_ev.type == "m.room.redaction":
-            # ... and redaction events
-            return None
-
-        redaction_map = yield self._get_events_from_cache_or_db(redactions)
-
         for redaction_id in redactions:
-            redaction_entry = redaction_map.get(redaction_id)
-            if not redaction_entry:
+            redaction_event = event_map.get(redaction_id)
+            if not redaction_event or redaction_event.rejected_reason:
                 # we don't have the redaction event, or the redaction event was not
                 # authorized.
                 logger.debug(
@@ -655,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            redaction_event = redaction_entry.event
             if redaction_event.room_id != original_ev.room_id:
                 logger.debug(
                     "%s was redacted by %s but redaction was in a different room!",
@@ -710,7 +724,7 @@ class EventsWorkerStore(SQLBaseStore):
             desc="have_events_in_timeline",
         )
 
-        defer.returnValue(set(r["event_id"] for r in rows))
+        return set(r["event_id"] for r in rows)
 
     @defer.inlineCallbacks
     def have_seen_events(self, event_ids):
@@ -736,7 +750,7 @@ class EventsWorkerStore(SQLBaseStore):
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
             yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
-        defer.returnValue(results)
+        return results
 
     def get_seen_events_with_rejections(self, event_ids):
         """Given a list of event ids, check if we rejected them.
@@ -847,4 +861,4 @@ class EventsWorkerStore(SQLBaseStore):
         # it.
         complexity_v1 = round(state_events / 500, 2)
 
-        defer.returnValue({"v1": complexity_v1})
+        return {"v1": complexity_v1}
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index b195dc66a0..23b48f6cea 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -15,8 +15,6 @@
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
-
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
@@ -41,7 +39,7 @@ class FilteringStore(SQLBaseStore):
             desc="get_user_filter",
         )
 
-        defer.returnValue(db_to_json(def_json))
+        return db_to_json(def_json)
 
     def add_user_filter(self, user_localpart, user_filter):
         def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 73e6fc6de2..15b01c6958 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -307,15 +307,13 @@ class GroupServerStore(SQLBaseStore):
             desc="get_group_categories",
         )
 
-        defer.returnValue(
-            {
-                row["category_id"]: {
-                    "is_public": row["is_public"],
-                    "profile": json.loads(row["profile"]),
-                }
-                for row in rows
+        return {
+            row["category_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
             }
-        )
+            for row in rows
+        }
 
     @defer.inlineCallbacks
     def get_group_category(self, group_id, category_id):
@@ -328,7 +326,7 @@ class GroupServerStore(SQLBaseStore):
 
         category["profile"] = json.loads(category["profile"])
 
-        defer.returnValue(category)
+        return category
 
     def upsert_group_category(self, group_id, category_id, profile, is_public):
         """Add/update room category for group
@@ -370,15 +368,13 @@ class GroupServerStore(SQLBaseStore):
             desc="get_group_roles",
         )
 
-        defer.returnValue(
-            {
-                row["role_id"]: {
-                    "is_public": row["is_public"],
-                    "profile": json.loads(row["profile"]),
-                }
-                for row in rows
+        return {
+            row["role_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
             }
-        )
+            for row in rows
+        }
 
     @defer.inlineCallbacks
     def get_group_role(self, group_id, role_id):
@@ -391,7 +387,7 @@ class GroupServerStore(SQLBaseStore):
 
         role["profile"] = json.loads(role["profile"])
 
-        defer.returnValue(role)
+        return role
 
     def upsert_group_role(self, group_id, role_id, profile, is_public):
         """Add/remove user role
@@ -960,7 +956,7 @@ class GroupServerStore(SQLBaseStore):
                 _register_user_group_membership_txn,
                 next_id,
             )
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def create_group(
@@ -1057,9 +1053,9 @@ class GroupServerStore(SQLBaseStore):
 
         now = int(self._clock.time_msec())
         if row and now < row["valid_until_ms"]:
-            defer.returnValue(json.loads(row["attestation_json"]))
+            return json.loads(row["attestation_json"])
 
-        defer.returnValue(None)
+        return None
 
     def get_joined_groups(self, user_id):
         return self._simple_select_onecol(
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 081564360f..752e9788a2 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             )
             if user_id:
                 count = count + 1
-        defer.returnValue(count)
+        return count
 
     @defer.inlineCallbacks
     def upsert_monthly_active_user(self, user_id):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 7c4e1dc7ec..d20eacda59 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 55
+SCHEMA_VERSION = 56
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 42ec8c6bb8..1a0f2d5768 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -90,9 +90,7 @@ class PresenceStore(SQLBaseStore):
                 presence_states,
             )
 
-        defer.returnValue(
-            (stream_orderings[-1], self._presence_id_gen.get_current_token())
-        )
+        return (stream_orderings[-1], self._presence_id_gen.get_current_token())
 
     def _update_presence_txn(self, txn, stream_orderings, presence_states):
         for stream_id, state in zip(stream_orderings, presence_states):
@@ -180,7 +178,7 @@ class PresenceStore(SQLBaseStore):
         for row in rows:
             row["currently_active"] = bool(row["currently_active"])
 
-        defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
+        return {row["user_id"]: UserPresenceState(**row) for row in rows}
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 0ff392bdb4..8a5d8e9b18 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -34,15 +34,13 @@ class ProfileWorkerStore(SQLBaseStore):
         except StoreError as e:
             if e.code == 404:
                 # no match
-                defer.returnValue(ProfileInfo(None, None))
+                return ProfileInfo(None, None)
                 return
             else:
                 raise
 
-        defer.returnValue(
-            ProfileInfo(
-                avatar_url=profile["avatar_url"], display_name=profile["displayname"]
-            )
+        return ProfileInfo(
+            avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
     def get_profile_displayname(self, user_localpart):
@@ -168,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
         )
 
         if res:
-            defer.returnValue(True)
+            return True
 
         res = yield self._simple_select_one_onecol(
             table="group_invites",
@@ -179,4 +177,4 @@ class ProfileStore(ProfileWorkerStore):
         )
 
         if res:
-            defer.returnValue(True)
+            return True
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 98cec8c82b..a6517c4cf3 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -120,7 +120,7 @@ class PushRulesWorkerStore(
 
         rules = _load_rules(rows, enabled_map)
 
-        defer.returnValue(rules)
+        return rules
 
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_enabled_for_user(self, user_id):
@@ -130,9 +130,7 @@ class PushRulesWorkerStore(
             retcols=("user_name", "rule_id", "enabled"),
             desc="get_push_rules_enabled_for_user",
         )
-        defer.returnValue(
-            {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
-        )
+        return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
 
     def have_push_rules_changed_for_user(self, user_id, last_id):
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
@@ -160,7 +158,7 @@ class PushRulesWorkerStore(
     )
     def bulk_get_push_rules(self, user_ids):
         if not user_ids:
-            defer.returnValue({})
+            return {}
 
         results = {user_id: [] for user_id in user_ids}
 
@@ -182,7 +180,7 @@ class PushRulesWorkerStore(
         for user_id, rules in results.items():
             results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
 
-        defer.returnValue(results)
+        return results
 
     @defer.inlineCallbacks
     def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
@@ -253,7 +251,7 @@ class PushRulesWorkerStore(
         result = yield self._bulk_get_push_rules_for_room(
             event.room_id, state_group, current_state_ids, event=event
         )
-        defer.returnValue(result)
+        return result
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(
@@ -312,7 +310,7 @@ class PushRulesWorkerStore(
 
         rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
 
-        defer.returnValue(rules_by_user)
+        return rules_by_user
 
     @cachedList(
         cached_method_name="get_push_rules_enabled_for_user",
@@ -322,7 +320,7 @@ class PushRulesWorkerStore(
     )
     def bulk_get_push_rules_enabled(self, user_ids):
         if not user_ids:
-            defer.returnValue({})
+            return {}
 
         results = {user_id: {} for user_id in user_ids}
 
@@ -336,7 +334,7 @@ class PushRulesWorkerStore(
         for row in rows:
             enabled = bool(row["enabled"])
             results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
-        defer.returnValue(results)
+        return results
 
 
 class PushRuleStore(PushRulesWorkerStore):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index cfe0a94330..b431d24b8a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
         ret = yield self._simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
-        defer.returnValue(ret is not None)
+        return ret is not None
 
     def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
         return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
@@ -95,7 +95,7 @@ class PusherWorkerStore(SQLBaseStore):
             ],
             desc="get_pushers_by",
         )
-        defer.returnValue(self._decode_pushers_rows(ret))
+        return self._decode_pushers_rows(ret)
 
     @defer.inlineCallbacks
     def get_all_pushers(self):
@@ -106,7 +106,7 @@ class PusherWorkerStore(SQLBaseStore):
             return self._decode_pushers_rows(rows)
 
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
-        defer.returnValue(rows)
+        return rows
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
@@ -205,7 +205,7 @@ class PusherWorkerStore(SQLBaseStore):
         result = {user_id: False for user_id in user_ids}
         result.update({r["user_name"]: True for r in rows})
 
-        defer.returnValue(result)
+        return result
 
 
 class PusherStore(PusherWorkerStore):
@@ -308,22 +308,36 @@ class PusherStore(PusherWorkerStore):
     def update_pusher_last_stream_ordering_and_success(
         self, app_id, pushkey, user_id, last_stream_ordering, last_success
     ):
-        yield self._simple_update_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            {
+        """Update the last stream ordering position we've processed up to for
+        the given pusher.
+
+        Args:
+            app_id (str)
+            pushkey (str)
+            user_id (str)
+            last_stream_ordering (int)
+            last_success (int)
+
+        Returns:
+            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+        """
+        updated = yield self._simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={
                 "last_stream_ordering": last_stream_ordering,
                 "last_success": last_success,
             },
             desc="update_pusher_last_stream_ordering_and_success",
         )
 
+        return bool(updated)
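(Aside: a hedged sketch of how a caller might use the new boolean, with a hypothetical pusher object and on_stop hook: if the UPDATE matched no rows the pusher row has been deleted, so the in-memory pusher should shut itself down.)

    from twisted.internet import defer

    @defer.inlineCallbacks
    def save_pusher_progress(store, pusher, last_stream_ordering, now_ms):
        still_exists = yield store.update_pusher_last_stream_ordering_and_success(
            pusher.app_id, pusher.pushkey, pusher.user_id, last_stream_ordering, now_ms
        )
        if not still_exists:
            # the pusher was deleted concurrently; stop processing for it
            pusher.on_stop()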
+
     @defer.inlineCallbacks
     def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self._simple_update_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            {"failing_since": failing_since},
+        yield self._simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={"failing_since": failing_since},
             desc="update_pusher_failing_since",
         )
 
@@ -343,7 +357,7 @@ class PusherStore(PusherWorkerStore):
                 "throttle_ms": row["throttle_ms"],
             }
 
-        defer.returnValue(params_by_room)
+        return params_by_room
 
     @defer.inlineCallbacks
     def set_throttle_params(self, pusher_id, room_id, params):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index b477da12b1..6aa6d98ebb 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
-        defer.returnValue(set(r["user_id"] for r in receipts))
+        return set(r["user_id"] for r in receipts)
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
@@ -92,7 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             desc="get_receipts_for_user",
         )
 
-        defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
+        return {row["room_id"]: row["event_id"] for row in rows}
 
     @defer.inlineCallbacks
     def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@@ -110,16 +110,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
             return txn.fetchall()
 
         rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
-        defer.returnValue(
-            {
-                row[0]: {
-                    "event_id": row[1],
-                    "topological_ordering": row[2],
-                    "stream_ordering": row[3],
-                }
-                for row in rows
+        return {
+            row[0]: {
+                "event_id": row[1],
+                "topological_ordering": row[2],
+                "stream_ordering": row[3],
             }
-        )
+            for row in rows
+        }
 
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             room_ids, to_key, from_key=from_key
         )
 
-        defer.returnValue([ev for res in results.values() for ev in res])
+        return [ev for res in results.values() for ev in res]
 
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
@@ -197,7 +195,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
-            defer.returnValue([])
+            return []
 
         content = {}
         for row in rows:
@@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 row["user_id"]
             ] = json.loads(row["data"])
 
-        defer.returnValue(
-            [{"type": "m.receipt", "room_id": room_id, "content": content}]
-        )
+        return [{"type": "m.receipt", "room_id": room_id, "content": content}]
 
     @cachedList(
         cached_method_name="_get_linearized_receipts_for_room",
@@ -217,7 +213,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
-            defer.returnValue({})
+            return {}
 
         def f(txn):
             if from_key:
@@ -264,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             room_id: [results[room_id]] if room_id in results else []
             for room_id in room_ids
         }
-        defer.returnValue(results)
+        return results
 
     def get_all_updated_receipts(self, last_id, current_id, limit=None):
         if last_id == current_id:
@@ -468,7 +464,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             )
 
         if event_ts is None:
-            defer.returnValue(None)
+            return None
 
         now = self._clock.time_msec()
         logger.debug(
@@ -482,7 +478,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         max_persisted_id = self._receipts_id_gen.get_current_token()
 
-        defer.returnValue((stream_id, max_persisted_id))
+        return (stream_id, max_persisted_id)
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
         return self.runInteraction(
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 8b2c2a97ab..55e4e84d71 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -75,12 +75,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         info = yield self.get_user_by_id(user_id)
         if not info:
-            defer.returnValue(False)
+            return False
 
         now = self.clock.time_msec()
         trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
         is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
-        defer.returnValue(is_trial)
+        return is_trial
 
     @cached()
     def get_user_by_access_token(self, token):
@@ -115,7 +115,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             allow_none=True,
             desc="get_expiration_ts_for_user",
         )
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def set_account_validity_for_user(
@@ -190,7 +190,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_from_renewal_token",
         )
 
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def get_renewal_token_for_user(self, user_id):
@@ -209,7 +209,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_renewal_token_for_user",
         )
 
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def get_users_expiring_soon(self):
@@ -237,7 +237,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             self.config.account_validity.renew_at,
         )
 
-        defer.returnValue(res)
+        return res
 
     @defer.inlineCallbacks
     def set_renewal_mail_status(self, user_id, email_sent):
@@ -280,7 +280,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="is_server_admin",
         )
 
-        defer.returnValue(res if res else False)
+        return res if res else False
 
     def _query_for_auth(self, txn, token):
         sql = (
@@ -311,7 +311,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         res = yield self.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
-        defer.returnValue(res)
+        return res
 
     def is_support_user_txn(self, txn, user_id):
         res = self._simple_select_one_onecol_txn(
@@ -349,7 +349,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             return 0
 
         ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
+        return ret
 
     def count_daily_user_type(self):
         """
@@ -395,7 +395,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             return count
 
         ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def find_next_generated_user_id_localpart(self):
@@ -425,7 +425,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 if i not in found:
                     return i
 
-        defer.returnValue(
+        return (
             (
                 yield self.runInteraction(
                     "find_next_generated_user_id", _find_next_generated_user_id
@@ -447,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         user_id = yield self.runInteraction(
             "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
-        defer.returnValue(user_id)
+        return user_id
 
     def get_user_id_by_threepid_txn(self, txn, medium, address):
         """Returns user id from threepid
@@ -487,7 +487,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             ["medium", "address", "validated_at", "added_at"],
             "user_get_threepids",
         )
-        defer.returnValue(ret)
+        return ret
 
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
@@ -569,6 +569,27 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_id_servers_user_bound",
         )
 
+    @cachedInlineCallbacks()
+    def get_user_deactivated_status(self, user_id):
+        """Retrieve the value for the `deactivated` property for the provided user.
+
+        Args:
+            user_id (str): The ID of the user to retrieve the status for.
+
+        Returns:
+            defer.Deferred(bool): The requested value.
+        """
+
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="deactivated",
+            desc="get_user_deactivated_status",
+        )
+
+        # Convert the integer into a boolean.
+        return res == 1
+
 
 class RegistrationStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore
@@ -677,7 +698,7 @@ class RegistrationStore(
         if end:
             yield self._end_background_update("users_set_deactivated_flag")
 
-        defer.returnValue(batch_size)
+        return batch_size
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -957,7 +978,7 @@ class RegistrationStore(
             desc="is_guest",
         )
 
-        defer.returnValue(res if res else False)
+        return res if res else False
 
     def add_user_pending_deactivation(self, user_id):
         """
@@ -1024,7 +1045,7 @@ class RegistrationStore(
 
         yield self._end_background_update("user_threepids_grandfather")
 
-        defer.returnValue(1)
+        return 1
 
     def get_threepid_validation_session(
         self, medium, client_secret, address=None, sid=None, validated=True
@@ -1317,24 +1338,3 @@ class RegistrationStore(
             user_id,
             deactivated,
         )
-
-    @cachedInlineCallbacks()
-    def get_user_deactivated_status(self, user_id):
-        """Retrieve the value for the `deactivated` property for the provided user.
-
-        Args:
-            user_id (str): The ID of the user to retrieve the status for.
-
-        Returns:
-            defer.Deferred(bool): The requested value.
-        """
-
-        res = yield self._simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="deactivated",
-            desc="get_user_deactivated_status",
-        )
-
-        # Convert the integer into a boolean.
-        defer.returnValue(res == 1)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 9954bc094f..fcb5f2f23a 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,8 +17,6 @@ import logging
 
 import attr
 
-from twisted.internet import defer
-
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -363,7 +361,7 @@ class RelationsWorkerStore(SQLBaseStore):
             return
 
         edit_event = yield self.get_event(edit_id, allow_none=True)
-        defer.returnValue(edit_event)
+        return edit_event
 
     def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
         """Check if a user has already annotated an event with the same key
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index fe9d79d792..bc606292b8 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -193,14 +193,12 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
         if row:
-            defer.returnValue(
-                RatelimitOverride(
-                    messages_per_second=row["messages_per_second"],
-                    burst_count=row["burst_count"],
-                )
+            return RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
             )
         else:
-            defer.returnValue(None)
+            return None
 
 
 class RoomStore(RoomWorkerStore, SearchStore):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 32cfd010a5..eecb276465 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,6 +24,8 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import LoggingTransaction
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import get_domain_from_id
 from synapse.util.async_helpers import Linearizer
@@ -53,9 +55,51 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 MemberSummary = namedtuple("MemberSummary", ("members", "count"))
 
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+
+        # Is the current_state_events.membership up to date? Or is the
+        # background update still running?
+        self._current_state_events_membership_up_to_date = False
+
+        txn = LoggingTransaction(
+            db_conn.cursor(),
+            name="_check_safe_current_state_events_membership_updated",
+            database_engine=self.database_engine,
+        )
+        self._check_safe_current_state_events_membership_updated_txn(txn)
+        txn.close()
+
+    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+        """Checks if it is safe to assume the new current_state_events
+        membership column is up to date
+        """
+
+        pending_update = self._simple_select_one_txn(
+            txn,
+            table="background_updates",
+            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+            retcols=["update_name"],
+            allow_none=True,
+        )
+
+        self._current_state_events_membership_up_to_date = not pending_update
+
+        # If the update is still running, reschedule this check to run again.
+        if pending_update:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "_check_safe_current_state_events_membership_updated",
+                self.runInteraction,
+                "_check_safe_current_state_events_membership_updated",
+                self._check_safe_current_state_events_membership_updated_txn,
+            )
+
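(Aside: a minimal standalone sketch of the poll-and-reschedule pattern above, using the Twisted reactor directly rather than Synapse's clock and background-process wrapper; the names are hypothetical.)

    from twisted.internet import reactor

    def poll_until_update_done(is_update_pending, on_ready, interval=15.0):
        # keep re-checking until the background update has disappeared
        if is_update_pending():
            reactor.callLater(interval, poll_until_update_done,
                              is_update_pending, on_ready, interval)
        else:
            on_ready()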
     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
     def get_hosts_in_room(self, room_id, cache_context):
         """Returns the set of all hosts currently in the room
@@ -64,19 +108,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             room_id, on_invalidate=cache_context.invalidate
         )
         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
-        defer.returnValue(hosts)
+        return hosts
 
     @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
-            sql = (
-                "SELECT m.user_id FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id "
-                " AND m.room_id = c.room_id "
-                " AND m.user_id = c.state_key"
-                " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
-            )
+            # If we can assume current_state_events.membership is up to date
+            # then we can avoid a join, which is a Very Good Thing given how
+            # frequently this function gets called.
+            if self._current_state_events_membership_up_to_date:
+                sql = """
+                    SELECT state_key FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+                """
+            else:
+                sql = """
+                    SELECT state_key FROM room_memberships as m
+                    INNER JOIN current_state_events as c
+                    ON m.event_id = c.event_id
+                    AND m.room_id = c.room_id
+                    AND m.user_id = c.state_key
+                    WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+                """
 
             txn.execute(sql, (room_id, Membership.JOIN))
             return [to_ascii(r[0]) for r in txn]
@@ -98,15 +151,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # first get counts.
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
-            sql = """
-                SELECT count(*), m.membership FROM room_memberships as m
-                 INNER JOIN current_state_events as c
-                 ON m.event_id = c.event_id
-                 AND m.room_id = c.room_id
-                 AND m.user_id = c.state_key
-                 WHERE c.type = 'm.room.member' AND c.room_id = ?
-                 GROUP BY m.membership
-            """
+
+            # If we can assume current_state_events.membership is up to date
+            # then we can avoid a join, which is a Very Good Thing given how
+            # frequently this function gets called.
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
+                sql = """
+                    SELECT count(*), membership FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    GROUP BY membership
+                """
+            else:
+                sql = """
+                    SELECT count(*), m.membership FROM room_memberships as m
+                    INNER JOIN current_state_events as c
+                    ON m.event_id = c.event_id
+                    AND m.room_id = c.room_id
+                    AND m.user_id = c.state_key
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    GROUP BY m.membership
+                """
 
             txn.execute(sql, (room_id,))
             res = {}
@@ -115,19 +182,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             # we order by membership and then fairly arbitrarily by event_id so
             # heroes are consistent
-            sql = """
-                SELECT m.user_id, m.membership, m.event_id
-                FROM room_memberships as m
-                 INNER JOIN current_state_events as c
-                 ON m.event_id = c.event_id
-                 AND m.room_id = c.room_id
-                 AND m.user_id = c.state_key
-                 WHERE c.type = 'm.room.member' AND c.room_id = ?
-                 ORDER BY
-                    CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                    m.event_id ASC
-                 LIMIT ?
-            """
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
+                sql = """
+                    SELECT state_key, membership, event_id
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    ORDER BY
+                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        event_id ASC
+                    LIMIT ?
+                """
+            else:
+                sql = """
+                    SELECT c.state_key, m.membership, c.event_id
+                    FROM room_memberships as m
+                    INNER JOIN current_state_events as c USING (room_id, event_id)
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    ORDER BY
+                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        c.event_id ASC
+                    LIMIT ?
+                """
 
             # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
             txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
@@ -189,31 +267,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         invites = yield self.get_invited_rooms_for_user(user_id)
         for invite in invites:
             if invite.room_id == room_id:
-                defer.returnValue(invite)
-        defer.returnValue(None)
+                return invite
+        return None
 
+    @defer.inlineCallbacks
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
         matches one in the membership list.
 
+        Filters out forgotten rooms.
+
         Args:
             user_id (str): The user ID.
             membership_list (list): A list of synapse.api.constants.Membership
             values which the user must be in.
+
         Returns:
-            A list of dictionary objects, with room_id, membership and sender
-            defined.
+            Deferred[list[RoomsForUser]]
         """
         if not membership_list:
             return defer.succeed(None)
 
-        return self.runInteraction(
+        rooms = yield self.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
             user_id,
             membership_list,
         )
 
+        # Now we filter out forgotten rooms
+        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+        return [room for room in rooms if room.room_id not in forgotten_rooms]
+
     def _get_rooms_for_user_where_membership_is_txn(
         self, txn, user_id, membership_list
     ):
@@ -223,26 +308,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         results = []
         if membership_list:
-            where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
-                " OR ".join(["membership = ?" for _ in membership_list]),
-            )
-
-            args = [user_id]
-            args.extend(membership_list)
+            if self._current_state_events_membership_up_to_date:
+                sql = """
+                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND c.membership IN (%s)
+                """ % (
+                    ",".join("?" * len(membership_list))
+                )
+            else:
+                sql = """
+                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN room_memberships AS m USING (room_id, event_id)
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND m.membership IN (%s)
+                """ % (
+                    ",".join("?" * len(membership_list))
+                )
 
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
-                " FROM current_state_events as c"
-                " INNER JOIN room_memberships as m"
-                " ON m.event_id = c.event_id"
-                " INNER JOIN events as e"
-                " ON e.event_id = c.event_id"
-                " AND m.room_id = c.room_id"
-                " AND m.user_id = c.state_key"
-                " WHERE c.type = 'm.room.member' AND %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, *membership_list))
             results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
 
         if do_invite:
@@ -283,11 +375,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN]
         )
-        defer.returnValue(
-            frozenset(
-                GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-                for r in rooms
-            )
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
         )
 
     @defer.inlineCallbacks
@@ -297,7 +387,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         rooms = yield self.get_rooms_for_user_with_stream_ordering(
             user_id, on_invalidate=on_invalidate
         )
-        defer.returnValue(frozenset(r.room_id for r in rooms))
+        return frozenset(r.room_id for r in rooms)
 
     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
     def get_users_who_share_room_with_user(self, user_id, cache_context):
@@ -314,7 +404,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
             user_who_share_room.update(user_ids)
 
-        defer.returnValue(user_who_share_room)
+        return user_who_share_room
 
     @defer.inlineCallbacks
     def get_joined_users_from_context(self, event, context):
@@ -330,7 +420,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         result = yield self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
-        defer.returnValue(result)
+        return result
 
     def get_joined_users_from_state(self, room_id, state_entry):
         state_group = state_entry.state_group
@@ -444,7 +534,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                         avatar_url=to_ascii(event.content.get("avatar_url", None)),
                     )
 
-        defer.returnValue(users_in_room)
+        return users_in_room
 
     @cachedInlineCallbacks(max_entries=10000)
     def is_host_joined(self, room_id, host):
@@ -453,8 +543,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         sql = """
             SELECT state_key FROM current_state_events AS c
-            INNER JOIN room_memberships USING (event_id)
-            WHERE membership = 'join'
+            INNER JOIN room_memberships AS m USING (event_id)
+            WHERE m.membership = 'join'
                 AND type = 'm.room.member'
                 AND c.room_id = ?
                 AND state_key LIKE ?
@@ -469,14 +559,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
-            defer.returnValue(False)
+            return False
 
         user_id = rows[0][0]
         if get_domain_from_id(user_id) != host:
             # This can only happen if the host name has something funky in it
             raise Exception("Invalid host name")
 
-        defer.returnValue(True)
+        return True
 
     @cachedInlineCallbacks()
     def was_host_joined(self, room_id, host):
@@ -509,14 +599,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
-            defer.returnValue(False)
+            return False
 
         user_id = rows[0][0]
         if get_domain_from_id(user_id) != host:
             # This can only happen if the host name has something funky in it
             raise Exception("Invalid host name")
 
-        defer.returnValue(True)
+        return True
 
     def get_joined_hosts(self, room_id, state_entry):
         state_group = state_entry.state_group
@@ -543,7 +633,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         cache = self._get_joined_hosts_cache(room_id)
         joined_hosts = yield cache.get_destinations(state_entry)
 
-        defer.returnValue(joined_hosts)
+        return joined_hosts
 
     @cached(max_entries=10000)
     def _get_joined_hosts_cache(self, room_id):
@@ -573,7 +663,45 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return rows[0][0]
 
         count = yield self.runInteraction("did_forget_membership", f)
-        defer.returnValue(count == 0)
+        return count == 0
+
+    @cached()
+    def get_forgotten_rooms_for_user(self, user_id):
+        """Gets all rooms the user has forgotten.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]
+        """
+
+        def _get_forgotten_rooms_for_user_txn(txn):
+            # This is a slightly convoluted query that first looks up all rooms
+            # that the user has forgotten in the past, then rechecks that list
+            # to see if any have subsequently been updated. This is done so that
+            # we can use a partial index on `forgotten = 1` on the assumption
+            # that few users will actually forget many rooms.
+            #
+            # Note that a room is considered "forgotten" only if *all* membership
+            # events for that user and room have the forgotten field set (when a
+            # user forgets a room we update all rows for that user and room, not
+            # just the current one).
+            sql = """
+                SELECT room_id, (
+                    SELECT count(*) FROM room_memberships
+                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+                ) AS count
+                FROM room_memberships AS m
+                WHERE user_id = ? AND forgotten = 1
+                GROUP BY room_id, user_id;
+            """
+            txn.execute(sql, (user_id,))
+            return set(row[0] for row in txn if row[1] == 0)
+
+        return self.runInteraction(
+            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+        )
 
     @defer.inlineCallbacks
     def get_rooms_user_has_been_in(self, user_id):
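
The comment in _get_forgotten_rooms_for_user_txn above explains the shape of the query: a room only counts as forgotten once every membership row for that user and room is marked, which is why the correlated count of forgotten = 0 rows must come back as zero. A self-contained sketch of the same check against a cut-down table (the table and data here are illustrative, not Synapse's full schema) behaves like this:

    import sqlite3

    # Cut-down, illustrative room_memberships table: just the columns the
    # forgotten-rooms query cares about.
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE room_memberships (user_id TEXT, room_id TEXT, forgotten INTEGER)"
    )
    conn.executemany(
        "INSERT INTO room_memberships VALUES (?, ?, ?)",
        [
            ("@alice:example.com", "!fully_forgotten:example.com", 1),
            ("@alice:example.com", "!fully_forgotten:example.com", 1),      # older row, also marked
            ("@alice:example.com", "!partially_forgotten:example.com", 1),
            ("@alice:example.com", "!partially_forgotten:example.com", 0),  # one row not marked
        ],
    )

    # Same query shape as above: only rooms where *no* row has forgotten = 0
    # survive the row[1] == 0 filter.
    rows = conn.execute(
        """
        SELECT room_id, (
            SELECT count(*) FROM room_memberships
            WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
        ) AS count
        FROM room_memberships AS m
        WHERE user_id = ? AND forgotten = 1
        GROUP BY room_id, user_id
        """,
        ("@alice:example.com",),
    )
    print({row[0] for row in rows if row[1] == 0})  # {'!fully_forgotten:example.com'}
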
@@ -602,6 +730,17 @@ class RoomMemberStore(RoomMemberWorkerStore):
         self.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
+        self.register_background_update_handler(
+            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+            self._background_current_state_membership,
+        )
+        self.register_background_index_update(
+            "room_membership_forgotten_idx",
+            index_name="room_memberships_user_room_forgotten",
+            table="room_memberships",
+            columns=["user_id", "room_id"],
+            where_clause="forgotten = 1",
+        )
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
@@ -703,6 +842,9 @@ class RoomMemberStore(RoomMemberWorkerStore):
             txn.execute(sql, (user_id, room_id))
 
             self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.get_forgotten_rooms_for_user, (user_id,)
+            )
 
         return self.runInteraction("forget_membership", f)
 
@@ -779,7 +921,65 @@ class RoomMemberStore(RoomMemberWorkerStore):
         if not result:
             yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
 
-        defer.returnValue(result)
+        return result
+
+    @defer.inlineCallbacks
+    def _background_current_state_membership(self, progress, batch_size):
+        """Update the new membership column on current_state_events.
+
+        This works by iterating over all rooms in alphabetical order.
+        """
+
+        def _background_current_state_membership_txn(txn, last_processed_room):
+            processed = 0
+            while processed < batch_size:
+                txn.execute(
+                    """
+                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
+                    """,
+                    (last_processed_room,),
+                )
+                row = txn.fetchone()
+                if not row or not row[0]:
+                    return processed, True
+
+                next_room, = row
+
+                sql = """
+                    UPDATE current_state_events
+                    SET membership = (
+                        SELECT membership FROM room_memberships
+                        WHERE event_id = current_state_events.event_id
+                    )
+                    WHERE room_id = ?
+                """
+                txn.execute(sql, (next_room,))
+                processed += txn.rowcount
+
+                last_processed_room = next_room
+
+            self._background_update_progress_txn(
+                txn,
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+                {"last_processed_room": last_processed_room},
+            )
+
+            return processed, False
+
+        # If we haven't got a last processed room then just use the empty
+        # string, which sorts before all room IDs.
+        last_processed_room = progress.get("last_processed_room", "")
+
+        row_count, finished = yield self.runInteraction(
+            "_background_current_state_membership_update",
+            _background_current_state_membership_txn,
+            last_processed_room,
+        )
+
+        if finished:
+            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+
+        return row_count
 
 
 class _JoinedHostsCache(object):
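
The _background_current_state_membership update above resumes from progress["last_processed_room"], and the empty-string default works because room IDs (which all start with "!") sort after it. A rough, self-contained sketch of that resume loop, with an in-memory stand-in for the per-room UPDATE (nothing here is the real background-update machinery):

    rooms = sorted(["!aaa:example.com", "!bbb:example.com", "!ccc:example.com"])

    def run_one_batch(last_processed_room, batch_size=2):
        # Stand-in for the txn above: pick the next rooms in lexical order and
        # pretend to UPDATE current_state_events for each of them.
        todo = [r for r in rooms if r > last_processed_room][:batch_size]
        if not todo:
            return 0, True, last_processed_room
        return len(todo), False, todo[-1]

    progress = {}
    while True:
        last = progress.get("last_processed_room", "")  # "" sorts before every room ID
        count, finished, last = run_one_batch(last)
        if finished:
            break
        progress = {"last_processed_room": last}  # persisted via the progress dict

    print(progress)  # {'last_processed_room': '!ccc:example.com'}
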
@@ -807,7 +1007,7 @@ class _JoinedHostsCache(object):
             state_entry(synapse.state._StateCacheEntry)
         """
         if state_entry.state_group == self.state_group:
-            defer.returnValue(frozenset(self.hosts_to_joined_users))
+            return frozenset(self.hosts_to_joined_users)
 
         with (yield self.linearizer.queue(())):
             if state_entry.state_group == self.state_group:
@@ -844,7 +1044,7 @@ class _JoinedHostsCache(object):
             else:
                 self.state_group = object()
             self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
-        defer.returnValue(frozenset(self.hosts_to_joined_users))
+        return frozenset(self.hosts_to_joined_users)
 
     def __len__(self):
         return self._len
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/schema/delta/56/current_state_events_membership.sql
new file mode 100644
index 0000000000..473018676f
--- /dev/null
+++ b/synapse/storage/schema/delta/56/current_state_events_membership.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership value
+-- for membership events. (It will also be null for membership events until the
+-- background update job has finished.)
+ALTER TABLE current_state_events ADD membership TEXT;
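
The motivation above is easiest to see side by side: a query such as the one in is_host_joined earlier in this diff currently joins room_memberships just to read the membership, whereas once the new column is populated the same information can be read straight off current_state_events. The second query below only illustrates the intent, not necessarily the exact form Synapse ends up using:

    # Today: join against room_memberships to find joined members of a room.
    JOINED_MEMBERS_WITH_JOIN = """
        SELECT state_key FROM current_state_events AS c
        INNER JOIN room_memberships AS m USING (event_id)
        WHERE m.membership = 'join'
            AND type = 'm.room.member'
            AND c.room_id = ?
    """

    # After the background update: read the denormalised column directly
    # (illustrative; assumes the membership column has been backfilled).
    JOINED_MEMBERS_WITHOUT_JOIN = """
        SELECT state_key FROM current_state_events
        WHERE membership = 'join'
            AND type = 'm.room.member'
            AND room_id = ?
    """
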
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql
new file mode 100644
index 0000000000..3133d42d4a
--- /dev/null
+++ b/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership value
+-- for membership events. (It will also be null for membership events until the
+-- background update job has finished.)
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_events_membership', '{}');
diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/schema/delta/56/room_membership_idx.sql
new file mode 100644
index 0000000000..92ab1f5e65
--- /dev/null
+++ b/synapse/storage/schema/delta/56/room_membership_idx.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Adds an index on room_memberships for fetching all forgotten rooms for a user
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('room_membership_forgotten_idx', '{}');
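
This delta only schedules the work; the index itself is created by the "room_membership_forgotten_idx" background update registered in roommember.py earlier in this diff (table room_memberships, columns user_id and room_id, where_clause "forgotten = 1"). The resulting DDL should look roughly like the partial index below; the exact statement is generated by the background index updater (which typically builds it concurrently on PostgreSQL), so this is only a sketch:

    # Rough shape of the index the background update should create; the exact
    # DDL is generated by the background index updater, not written by hand.
    CREATE_FORGOTTEN_IDX = """
        CREATE INDEX room_memberships_user_room_forgotten
            ON room_memberships (user_id, room_id)
            WHERE forgotten = 1
    """
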
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f3b1cec933..df87ab6a6d 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -166,7 +166,7 @@ class SearchStore(BackgroundUpdateStore):
         if not result:
             yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
 
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def _background_reindex_gin_search(self, progress, batch_size):
@@ -209,7 +209,7 @@ class SearchStore(BackgroundUpdateStore):
             yield self.runWithConnection(create_index)
 
         yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _background_reindex_search_order(self, progress, batch_size):
@@ -287,7 +287,7 @@ class SearchStore(BackgroundUpdateStore):
         if not finished:
             yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
 
-        defer.returnValue(num_rows)
+        return num_rows
 
     def store_event_search_txn(self, txn, event, key, value):
         """Add event to the search table
@@ -454,17 +454,15 @@ class SearchStore(BackgroundUpdateStore):
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
 
-        defer.returnValue(
-            {
-                "results": [
-                    {"event": event_map[r["event_id"]], "rank": r["rank"]}
-                    for r in results
-                    if r["event_id"] in event_map
-                ],
-                "highlights": highlights,
-                "count": count,
-            }
-        )
+        return {
+            "results": [
+                {"event": event_map[r["event_id"]], "rank": r["rank"]}
+                for r in results
+                if r["event_id"] in event_map
+            ],
+            "highlights": highlights,
+            "count": count,
+        }
 
     @defer.inlineCallbacks
     def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -599,22 +597,20 @@ class SearchStore(BackgroundUpdateStore):
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
 
-        defer.returnValue(
-            {
-                "results": [
-                    {
-                        "event": event_map[r["event_id"]],
-                        "rank": r["rank"],
-                        "pagination_token": "%s,%s"
-                        % (r["origin_server_ts"], r["stream_ordering"]),
-                    }
-                    for r in results
-                    if r["event_id"] in event_map
-                ],
-                "highlights": highlights,
-                "count": count,
-            }
-        )
+        return {
+            "results": [
+                {
+                    "event": event_map[r["event_id"]],
+                    "rank": r["rank"],
+                    "pagination_token": "%s,%s"
+                    % (r["origin_server_ts"], r["stream_ordering"]),
+                }
+                for r in results
+                if r["event_id"] in event_map
+            ],
+            "highlights": highlights,
+            "count": count,
+        }
 
     def _find_highlights_in_postgres(self, search_query, events):
         """Given a list of events and a search term, return a list of words
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 6bd81e84ad..fb83218f90 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -59,7 +59,7 @@ class SignatureWorkerStore(SQLBaseStore):
             for e_id, h in hashes.items()
         }
 
-        defer.returnValue(list(hashes.items()))
+        return list(hashes.items())
 
     def _get_event_reference_hashes_txn(self, txn, event_id):
         """Get all the hashes for a given PDU.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..1980a87108 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
-        defer.returnValue(create_event.content.get("room_version", "1"))
+        return create_event.content.get("room_version", "1")
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
@@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         create_event = yield self.get_create_event_for_room(room_id)
 
         # Return predecessor if present
-        defer.returnValue(create_event.content.get("predecessor", None))
+        return create_event.content.get("predecessor", None)
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
-        defer.returnValue(create_event)
+        return create_event
 
     @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):
@@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             event ID.
         """
 
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        if not where_clause:
+            # We delegate to the cached version
+            return self.get_current_state_ids(room_id)
+
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
             sql = """
@@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 WHERE room_id = ?
             """
 
-            where_clause, where_args = state_filter.make_sql_filter_clause()
-
             if where_clause:
                 sql += " AND (%s)" % (where_clause,)
 
@@ -559,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if not event:
             return
 
-        defer.returnValue(event.content.get("canonical_alias"))
+        return event.content.get("canonical_alias")
 
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
@@ -609,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
-            defer.returnValue({})
+            return {}
 
         event_to_groups = yield self._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups)
 
-        defer.returnValue(group_to_state)
+        return group_to_state
 
     @defer.inlineCallbacks
     def get_state_ids_for_group(self, state_group):
@@ -630,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         group_to_state = yield self._get_state_for_groups((state_group,))
 
-        defer.returnValue(group_to_state[state_group])
+        return group_to_state[state_group]
 
     @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
@@ -641,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 dict of state_group_id -> list of state events.
         """
         if not event_ids:
-            defer.returnValue({})
+            return {}
 
         group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
 
@@ -654,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_prev_content=False,
         )
 
-        defer.returnValue(
-            {
-                group: [
-                    state_event_map[v]
-                    for v in itervalues(event_id_map)
-                    if v in state_event_map
-                ]
-                for group, event_id_map in iteritems(group_to_ids)
-            }
-        )
+        return {
+            group: [
+                state_event_map[v]
+                for v in itervalues(event_id_map)
+                if v in state_event_map
+            ]
+            for group, event_id_map in iteritems(group_to_ids)
+        }
 
     @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, state_filter):
@@ -690,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             results.update(res)
 
-        defer.returnValue(results)
+        return results
 
     def _get_state_groups_from_groups_txn(
         self, txn, groups, state_filter=StateFilter.all()
@@ -825,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             for event_id, group in iteritems(event_to_groups)
         }
 
-        defer.returnValue({event: event_to_state[event] for event in event_ids})
+        return {event: event_to_state[event] for event in event_ids}
 
     @defer.inlineCallbacks
     def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -851,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             for event_id, group in iteritems(event_to_groups)
         }
 
-        defer.returnValue({event: event_to_state[event] for event in event_ids})
+        return {event: event_to_state[event] for event in event_ids}
 
     @defer.inlineCallbacks
     def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             A deferred dict from (type, state_key) -> state_event
         """
         state_map = yield self.get_state_for_events([event_id], state_filter)
-        defer.returnValue(state_map[event_id])
+        return state_map[event_id]
 
     @defer.inlineCallbacks
     def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             A deferred dict from (type, state_key) -> state_event
         """
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
-        defer.returnValue(state_map[event_id])
+        return state_map[event_id]
 
     @cached(max_entries=50000)
     def _get_state_group_for_event(self, event_id):
@@ -913,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="_get_state_group_for_events",
         )
 
-        defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
+        return {row["event_id"]: row["state_group"] for row in rows}
 
     def _get_state_for_group_using_cache(self, cache, group, state_filter):
         """Checks if group is in cache. See `_get_state_for_groups`
@@ -993,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         incomplete_groups = incomplete_groups_m | incomplete_groups_nm
 
         if not incomplete_groups:
-            defer.returnValue(state)
+            return state
 
         cache_sequence_nm = self._state_group_cache.sequence
         cache_sequence_m = self._state_group_members_cache.sequence
@@ -1020,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # everything we need from the database anyway.
             state[group] = state_filter.filter_state(group_state_dict)
 
-        defer.returnValue(state)
+        return state
 
     def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
         """Gets the state at each of a list of state groups, optionally
@@ -1494,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
             )
 
-        defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
+        return result * BATCH_SIZE_SCALE_FACTOR
 
     @defer.inlineCallbacks
     def _background_index_state(self, progress, batch_size):
@@ -1524,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
 
         yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
 
-        defer.returnValue(1)
+        return 1
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index 1cec84ee2e..e13efed417 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -66,7 +66,7 @@ class StatsStore(StateDeltasStore):
 
         if not self.stats_enabled:
             yield self._end_background_update("populate_stats_createtables")
-            defer.returnValue(1)
+            return 1
 
         # Get all the rooms that we want to process.
         def _make_staging_area(txn):
@@ -120,7 +120,7 @@ class StatsStore(StateDeltasStore):
         self.get_earliest_token_for_room_stats.invalidate_all()
 
         yield self._end_background_update("populate_stats_createtables")
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _populate_stats_cleanup(self, progress, batch_size):
@@ -129,7 +129,7 @@ class StatsStore(StateDeltasStore):
         """
         if not self.stats_enabled:
             yield self._end_background_update("populate_stats_cleanup")
-            defer.returnValue(1)
+            return 1
 
         position = yield self._simple_select_one_onecol(
             TEMP_TABLE + "_position", None, "position"
@@ -143,14 +143,14 @@ class StatsStore(StateDeltasStore):
         yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
 
         yield self._end_background_update("populate_stats_cleanup")
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _populate_stats_process_rooms(self, progress, batch_size):
 
         if not self.stats_enabled:
             yield self._end_background_update("populate_stats_process_rooms")
-            defer.returnValue(1)
+            return 1
 
         # If we don't have any progress recorded, delete everything.
         if not progress:
@@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
             yield self._end_background_update("populate_stats_process_rooms")
-            defer.returnValue(1)
+            return 1
 
         logger.info(
             "Processing the next %d rooms of %d remaining",
@@ -211,16 +211,18 @@ class StatsStore(StateDeltasStore):
             avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
             canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
 
+            event_ids = [
+                join_rules_id,
+                history_visibility_id,
+                encryption_id,
+                name_id,
+                topic_id,
+                avatar_id,
+                canonical_alias_id,
+            ]
+
             state_events = yield self.get_events(
-                [
-                    join_rules_id,
-                    history_visibility_id,
-                    encryption_id,
-                    name_id,
-                    topic_id,
-                    avatar_id,
-                    canonical_alias_id,
-                ]
+                [ev for ev in event_ids if ev is not None]
             )
 
             def _get_or_none(event_id, arg):
@@ -303,9 +305,9 @@ class StatsStore(StateDeltasStore):
 
             if processed_event_count > batch_size:
                 # Don't process any more rooms, we've hit our batch size.
-                defer.returnValue(processed_event_count)
+                return processed_event_count
 
-        defer.returnValue(processed_event_count)
+        return processed_event_count
 
     def delete_all_stats(self):
         """
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index a0465484df..856c2ee8d8 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -300,7 +300,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         if not room_ids:
-            defer.returnValue({})
+            return {}
 
         results = {}
         room_ids = list(room_ids)
@@ -323,7 +323,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             results.update(dict(zip(rm_ids, res)))
 
-        defer.returnValue(results)
+        return results
 
     def get_rooms_that_changed(self, room_ids, from_key):
         """Given a list of rooms and a token, return rooms where there may have
@@ -364,7 +364,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             the chunk of events returned.
         """
         if from_key == to_key:
-            defer.returnValue(([], from_key))
+            return ([], from_key)
 
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -374,7 +374,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         if not has_changed:
-            defer.returnValue(([], from_key))
+            return ([], from_key)
 
         def f(txn):
             sql = (
@@ -407,7 +407,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # get.
             key = from_key
 
-        defer.returnValue((ret, key))
+        return (ret, key)
 
     @defer.inlineCallbacks
     def get_membership_changes_for_user(self, user_id, from_key, to_key):
@@ -415,14 +415,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
         if from_key == to_key:
-            defer.returnValue([])
+            return []
 
         if from_id:
             has_changed = self._membership_stream_cache.has_entity_changed(
                 user_id, int(from_id)
             )
             if not has_changed:
-                defer.returnValue([])
+                return []
 
         def f(txn):
             sql = (
@@ -447,7 +447,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         self._set_before_and_after(ret, rows, topo_order=False)
 
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token):
@@ -477,7 +477,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         self._set_before_and_after(events, rows)
 
-        defer.returnValue((events, token))
+        return (events, token)
 
     @defer.inlineCallbacks
     def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -496,7 +496,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         # Allow a zero limit here, and no-op.
         if limit == 0:
-            defer.returnValue(([], end_token))
+            return ([], end_token)
 
         end_token = RoomStreamToken.parse(end_token)
 
@@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         # We want to return the results in ascending order.
         rows.reverse()
 
-        defer.returnValue((rows, token))
+        return (rows, token)
 
     def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
         """Gets details of the first event in a room at or after a stream ordering
@@ -549,12 +549,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         token = yield self.get_room_max_stream_ordering()
         if room_id is None:
-            defer.returnValue("s%d" % (token,))
+            return "s%d" % (token,)
         else:
             topo = yield self.runInteraction(
                 "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
-            defer.returnValue("t%d-%d" % (topo, token))
+            return "t%d-%d" % (topo, token)
 
     def get_stream_token_for_event(self, event_id):
         """The stream token for an event
@@ -674,14 +674,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             [e for e in results["after"]["event_ids"]], get_prev_content=True
         )
 
-        defer.returnValue(
-            {
-                "events_before": events_before,
-                "events_after": events_after,
-                "start": results["before"]["token"],
-                "end": results["after"]["token"],
-            }
-        )
+        return {
+            "events_before": events_before,
+            "events_after": events_after,
+            "start": results["before"]["token"],
+            "end": results["after"]["token"],
+        }
 
     def _get_events_around_txn(
         self, txn, room_id, event_id, before_limit, after_limit, event_filter
@@ -785,7 +783,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         events = yield self.get_events_as_list(event_ids)
 
-        defer.returnValue((upper_bound, events))
+        return (upper_bound, events)
 
     def get_federation_out_pos(self, typ):
         return self._simple_select_one_onecol(
@@ -939,7 +937,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         self._set_before_and_after(events, rows)
 
-        defer.returnValue((events, token))
+        return (events, token)
 
 
 class StreamStore(StreamWorkerStore):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index e88f8ea35f..20dd6bd53d 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             room_id string, tag string and content string.
         """
         if last_id == current_id:
-            defer.returnValue([])
+            return []
 
         def get_all_updated_tags_txn(txn):
             sql = (
@@ -107,7 +107,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             )
             results.extend(tags)
 
-        defer.returnValue(results)
+        return results
 
     @defer.inlineCallbacks
     def get_updated_tags(self, user_id, stream_id):
@@ -135,7 +135,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             user_id, int(stream_id)
         )
         if not changed:
-            defer.returnValue({})
+            return {}
 
         room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
 
@@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             for room_id in room_ids:
                 results[room_id] = tags_by_room.get(room_id, {})
 
-        defer.returnValue(results)
+        return results
 
     def get_tags_for_room(self, user_id, room_id):
         """Get all the tags for the given room
@@ -194,7 +194,7 @@ class TagsStore(TagsWorkerStore):
         self.get_tags_for_user.invalidate((user_id,))
 
         result = self._account_data_id_gen.get_current_token()
-        defer.returnValue(result)
+        return result
 
     @defer.inlineCallbacks
     def remove_tag_from_room(self, user_id, room_id, tag):
@@ -217,7 +217,7 @@ class TagsStore(TagsWorkerStore):
         self.get_tags_for_user.invalidate((user_id,))
 
         result = self._account_data_id_gen.get_current_token()
-        defer.returnValue(result)
+        return result
 
     def _update_revision_txn(self, txn, user_id, room_id, next_id):
         """Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index fd18619178..b3c3bf55bc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -147,7 +147,7 @@ class TransactionStore(SQLBaseStore):
 
         result = self._destination_retry_cache.get(destination, SENTINEL)
         if result is not SENTINEL:
-            defer.returnValue(result)
+            return result
 
         result = yield self.runInteraction(
             "get_destination_retry_timings",
@@ -158,7 +158,7 @@ class TransactionStore(SQLBaseStore):
         # We don't hugely care about race conditions between getting and
         # invalidating the cache, since we time out fairly quickly anyway.
         self._destination_retry_cache[destination] = result
-        defer.returnValue(result)
+        return result
 
     def _get_destination_retry_timings(self, txn, destination):
         result = self._simple_select_one_txn(
@@ -196,6 +196,26 @@ class TransactionStore(SQLBaseStore):
     def _set_destination_retry_timings(
         self, txn, destination, retry_last_ts, retry_interval
     ):
+
+        if self.database_engine.can_native_upsert:
+            # Upsert retry time interval if retry_interval is zero (i.e. we're
+            # resetting it) or greater than the existing retry interval.
+
+            sql = """
+                INSERT INTO destinations (destination, retry_last_ts, retry_interval)
+                    VALUES (?, ?, ?)
+                ON CONFLICT (destination) DO UPDATE SET
+                        retry_last_ts = EXCLUDED.retry_last_ts,
+                        retry_interval = EXCLUDED.retry_interval
+                    WHERE
+                        EXCLUDED.retry_interval = 0
+                        OR destinations.retry_interval < EXCLUDED.retry_interval
+            """
+
+            txn.execute(sql, (destination, retry_last_ts, retry_interval))
+
+            return
+
         self.database_engine.lock_table(txn, "destinations")
 
         # We need to be careful here as the data may have changed from under us
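
The comment above describes the guard on the upsert: the stored retry interval is only overwritten when the new value resets the backoff (zero) or lengthens it. A standalone sketch of the same ON CONFLICT ... DO UPDATE ... WHERE behaviour, using SQLite purely for illustration (it needs SQLite 3.24+; the real code only takes this path when can_native_upsert is true):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE destinations "
        "(destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval BIGINT)"
    )

    UPSERT = """
        INSERT INTO destinations (destination, retry_last_ts, retry_interval)
            VALUES (?, ?, ?)
        ON CONFLICT (destination) DO UPDATE SET
                retry_last_ts = EXCLUDED.retry_last_ts,
                retry_interval = EXCLUDED.retry_interval
            WHERE
                EXCLUDED.retry_interval = 0
                OR destinations.retry_interval < EXCLUDED.retry_interval
    """

    conn.execute(UPSERT, ("remote.example.com", 1000, 60000))   # first insert
    conn.execute(UPSERT, ("remote.example.com", 2000, 30000))   # shorter interval: ignored
    conn.execute(UPSERT, ("remote.example.com", 3000, 120000))  # longer interval: applied
    conn.execute(UPSERT, ("remote.example.com", 4000, 0))       # zero resets the backoff
    print(conn.execute("SELECT * FROM destinations").fetchone())
    # ('remote.example.com', 4000, 0)
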
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 83466e25d9..b5188d9bee 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -109,7 +109,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
 
         yield self._end_background_update("populate_user_directory_createtables")
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _populate_user_directory_cleanup(self, progress, batch_size):
@@ -131,7 +131,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         )
 
         yield self._end_background_update("populate_user_directory_cleanup")
-        defer.returnValue(1)
+        return 1
 
     @defer.inlineCallbacks
     def _populate_user_directory_process_rooms(self, progress, batch_size):
@@ -177,7 +177,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
             yield self._end_background_update("populate_user_directory_process_rooms")
-            defer.returnValue(1)
+            return 1
 
         logger.info(
             "Processing the next %d rooms of %d remaining"
@@ -257,9 +257,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
             if processed_event_count > batch_size:
                 # Don't process any more rooms, we've hit our batch size.
-                defer.returnValue(processed_event_count)
+                return processed_event_count
 
-        defer.returnValue(processed_event_count)
+        return processed_event_count
 
     @defer.inlineCallbacks
     def _populate_user_directory_process_users(self, progress, batch_size):
@@ -268,7 +268,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         """
         if not self.hs.config.user_directory_search_all_users:
             yield self._end_background_update("populate_user_directory_process_users")
-            defer.returnValue(1)
+            return 1
 
         def _get_next_batch(txn):
             sql = "SELECT user_id FROM %s LIMIT %s" % (
@@ -298,7 +298,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         # No more users -- complete the transaction.
         if not users_to_work_on:
             yield self._end_background_update("populate_user_directory_process_users")
-            defer.returnValue(1)
+            return 1
 
         logger.info(
             "Processing the next %d users of %d remaining"
@@ -322,7 +322,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                 progress,
             )
 
-        defer.returnValue(len(users_to_work_on))
+        return len(users_to_work_on)
 
     @defer.inlineCallbacks
     def is_room_world_readable_or_publicly_joinable(self, room_id):
@@ -344,16 +344,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
-                    defer.returnValue(True)
+                    return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
             hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
             if hist_vis_ev:
                 if hist_vis_ev.content.get("history_visibility") == "world_readable":
-                    defer.returnValue(True)
+                    return True
 
-        defer.returnValue(False)
+        return False
 
     def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
         """
@@ -499,7 +499,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         user_ids = set(user_ids_share_pub)
         user_ids.update(user_ids_share_priv)
 
-        defer.returnValue(user_ids)
+        return user_ids
 
     def add_users_who_share_private_room(self, room_id, user_id_tuples):
         """Insert entries into the users_who_share_private_rooms table. The first
@@ -609,7 +609,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
         users = set(pub_rows)
         users.update(rows)
-        defer.returnValue(list(users))
+        return list(users)
 
     @defer.inlineCallbacks
     def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -618,15 +618,15 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         sql = """
             SELECT room_id FROM (
                 SELECT c.room_id FROM current_state_events AS c
-                INNER JOIN room_memberships USING (event_id)
+                INNER JOIN room_memberships AS m USING (event_id)
                 WHERE type = 'm.room.member'
-                    AND membership = 'join'
+                    AND m.membership = 'join'
                     AND state_key = ?
             ) AS f1 INNER JOIN (
                 SELECT c.room_id FROM current_state_events AS c
-                INNER JOIN room_memberships USING (event_id)
+                INNER JOIN room_memberships AS m USING (event_id)
                 WHERE type = 'm.room.member'
-                    AND membership = 'join'
+                    AND m.membership = 'join'
                     AND state_key = ?
             ) f2 USING (room_id)
         """
@@ -635,7 +635,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
         )
 
-        defer.returnValue([room_id for room_id, in rows])
+        return [room_id for room_id, in rows]
 
     def delete_all_from_user_dir(self):
         """Delete the entire user directory
@@ -782,7 +782,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
         limited = len(results) > limit
 
-        defer.returnValue({"limited": limited, "results": results})
+        return {"limited": limited, "results": results}
 
 
 def _parse_query_sqlite(search_term):
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
index 1815fdc0dd..05cabc2282 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/user_erasure_store.py
@@ -12,9 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import operator
 
-from twisted.internet import defer
+import operator
 
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
@@ -67,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
 
         erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
         res = dict((u, u in erased_users) for u in user_ids)
-        defer.returnValue(res)
+        return res
 
 
 class UserErasureStore(UserErasureWorkerStore):