author     Erik Johnston <erik@matrix.org>  2020-10-30 12:08:09 +0000
committer  Erik Johnston <erik@matrix.org>  2020-10-30 12:08:09 +0000
commit     1ff3bc332ac5bd7a3c9ac4fcbcbf1692111f33d1 (patch)
tree       9abed3402979f5105af3cb7ef739e8646e13f148 /synapse
parent     Merge branch 'develop' into matrix-org-hotfixes (diff)
parent     Implement and use an @lru_cache decorator (#8595) (diff)
download   synapse-1ff3bc332ac5bd7a3c9ac4fcbcbf1692111f33d1.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
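
The headline change pulled in by this merge is the @lru_cache decorator of #8595 (the bulk of the synapse/util/caches/descriptors.py churn in the diffstat below). As a rough illustration only — the names and eviction details here are assumptions, not the implementation from #8595 — a per-instance LRU method cache of this general shape looks like:

    import functools
    from collections import OrderedDict


    def lru_cache_method(max_entries=128):
        """Cache a method's results per instance, evicting least-recently-used entries."""

        def decorator(func):
            @functools.wraps(func)
            def wrapper(self, *args):
                # One cache per (instance, method), keyed on the positional args.
                caches = self.__dict__.setdefault("_method_caches", {})
                cache = caches.setdefault(func.__name__, OrderedDict())
                if args in cache:
                    cache.move_to_end(args)  # mark as most recently used
                    return cache[args]
                result = func(self, *args)
                cache[args] = result
                if len(cache) > max_entries:
                    cache.popitem(last=False)  # evict the least recently used
                return result

            return wrapper

        return decorator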
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py | 113
-rw-r--r--  synapse/appservice/__init__.py | 14
-rw-r--r--  synapse/config/logger.py | 96
-rw-r--r--  synapse/events/__init__.py | 4
-rw-r--r--  synapse/federation/transport/server.py | 2
-rw-r--r--  synapse/handlers/appservice.py | 75
-rw-r--r--  synapse/handlers/auth.py | 31
-rw-r--r--  synapse/handlers/message.py | 40
-rw-r--r--  synapse/handlers/presence.py | 12
-rw-r--r--  synapse/handlers/register.py | 7
-rw-r--r--  synapse/handlers/room.py | 29
-rw-r--r--  synapse/handlers/room_member.py | 8
-rw-r--r--  synapse/http/client.py | 2
-rw-r--r--  synapse/http/server.py | 4
-rw-r--r--  synapse/http/site.py | 30
-rw-r--r--  synapse/logging/__init__.py | 20
-rw-r--r--  synapse/logging/_remote.py | 122
-rw-r--r--  synapse/logging/_structured.py | 329
-rw-r--r--  synapse/logging/_terse_json.py | 192
-rw-r--r--  synapse/logging/filter.py | 33
-rw-r--r--  synapse/notifier.py | 68
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 37
-rw-r--r--  synapse/replication/http/membership.py | 6
-rw-r--r--  synapse/replication/http/send_event.py | 3
-rw-r--r--  synapse/replication/tcp/client.py | 20
-rw-r--r--  synapse/replication/tcp/streams/events.py | 21
-rw-r--r--  synapse/rest/admin/__init__.py | 4
-rw-r--r--  synapse/rest/admin/users.py | 52
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 165
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 30
-rw-r--r--  synapse/storage/database.py | 4
-rw-r--r--  synapse/storage/databases/main/appservice.py | 98
-rw-r--r--  synapse/storage/databases/main/censor_events.py | 6
-rw-r--r--  synapse/storage/databases/main/events.py | 10
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 62
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 27
-rw-r--r--  synapse/storage/databases/main/registration.py | 48
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/22puppet_token.sql | 17
-rw-r--r--  synapse/types.py | 33
-rw-r--r--  synapse/util/__init__.py | 24
-rw-r--r--  synapse/util/caches/descriptors.py | 235
-rw-r--r--  synapse/util/frozenutils.py | 22
-rw-r--r--  synapse/util/retryutils.py | 2
43 files changed, 1205 insertions(+), 952 deletions(-)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 526cb58c5f..bfcaf68b2a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
@@ -190,10 +191,6 @@ class Auth:
 
             user_id, app_service = await self._get_appservice_user_id(request)
             if user_id:
-                request.authenticated_entity = user_id
-                opentracing.set_tag("authenticated_entity", user_id)
-                opentracing.set_tag("appservice_id", app_service.id)
-
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -203,31 +200,38 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )
 
-                return synapse.types.create_requester(user_id, app_service=app_service)
+                requester = synapse.types.create_requester(
+                    user_id, app_service=app_service
+                )
+
+                request.requester = user_id
+                opentracing.set_tag("authenticated_entity", user_id)
+                opentracing.set_tag("user_id", user_id)
+                opentracing.set_tag("appservice_id", app_service.id)
+
+                return requester
 
             user_info = await self.get_user_by_access_token(
                 access_token, rights, allow_expired=allow_expired
             )
-            user = user_info["user"]
-            token_id = user_info["token_id"]
-            is_guest = user_info["is_guest"]
-            shadow_banned = user_info["shadow_banned"]
+            token_id = user_info.token_id
+            is_guest = user_info.is_guest
+            shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
-                user_id = user.to_string()
-                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
+                if await self.store.is_account_expired(
+                    user_info.user_id, self.clock.time_msec()
+                ):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
 
-            # device_id may not be present if get_user_by_access_token has been
-            # stubbed out.
-            device_id = user_info.get("device_id")
+            device_id = user_info.device_id
 
-            if user and access_token and ip_addr:
+            if access_token and ip_addr:
                 await self.store.insert_client_ip(
-                    user_id=user.to_string(),
+                    user_id=user_info.token_owner,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
@@ -241,19 +245,23 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            request.authenticated_entity = user.to_string()
-            opentracing.set_tag("authenticated_entity", user.to_string())
-            if device_id:
-                opentracing.set_tag("device_id", device_id)
-
-            return synapse.types.create_requester(
-                user,
+            requester = synapse.types.create_requester(
+                user_info.user_id,
                 token_id,
                 is_guest,
                 shadow_banned,
                 device_id,
                 app_service=app_service,
+                authenticated_entity=user_info.token_owner,
             )
+
+            request.requester = requester
+            opentracing.set_tag("authenticated_entity", user_info.token_owner)
+            opentracing.set_tag("user_id", user_info.user_id)
+            if device_id:
+                opentracing.set_tag("device_id", device_id)
+
+            return requester
         except KeyError:
             raise MissingClientTokenError()
 
@@ -284,7 +292,7 @@ class Auth:
 
     async def get_user_by_access_token(
         self, token: str, rights: str = "access", allow_expired: bool = False,
-    ) -> dict:
+    ) -> TokenLookupResult:
         """ Validate access token and get user_id from it
 
         Args:
@@ -293,13 +301,7 @@ class Auth:
                 allow this
             allow_expired: If False, raises an InvalidClientTokenError
                 if the token is expired
-        Returns:
-            dict that includes:
-               `user` (UserID)
-               `is_guest` (bool)
-               `shadow_banned` (bool)
-               `token_id` (int|None): access token id. May be None if guest
-               `device_id` (str|None): device corresponding to access token
+
         Raises:
             InvalidClientTokenError if a user by that token exists, but the token is
                 expired
@@ -309,9 +311,9 @@ class Auth:
 
         if rights == "access":
             # first look in the database
-            r = await self._look_up_user_by_access_token(token)
+            r = await self.store.get_user_by_access_token(token)
             if r:
-                valid_until_ms = r["valid_until_ms"]
+                valid_until_ms = r.valid_until_ms
                 if (
                     not allow_expired
                     and valid_until_ms is not None
@@ -328,7 +330,6 @@ class Auth:
         # otherwise it needs to be a valid macaroon
         try:
             user_id, guest = self._parse_and_validate_macaroon(token, rights)
-            user = UserID.from_string(user_id)
 
             if rights == "access":
                 if not guest:
@@ -354,23 +355,17 @@ class Auth:
                     raise InvalidClientTokenError(
                         "Guest access token used for regular user"
                     )
-                ret = {
-                    "user": user,
-                    "is_guest": True,
-                    "shadow_banned": False,
-                    "token_id": None,
+
+                ret = TokenLookupResult(
+                    user_id=user_id,
+                    is_guest=True,
                     # all guests get the same device id
-                    "device_id": GUEST_DEVICE_ID,
-                }
+                    device_id=GUEST_DEVICE_ID,
+                )
             elif rights == "delete_pusher":
                 # We don't store these tokens in the database
-                ret = {
-                    "user": user,
-                    "is_guest": False,
-                    "shadow_banned": False,
-                    "token_id": None,
-                    "device_id": None,
-                }
+
+                ret = TokenLookupResult(user_id=user_id, is_guest=False)
             else:
                 raise RuntimeError("Unknown rights setting %s", rights)
             return ret
@@ -479,31 +474,15 @@ class Auth:
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
-    async def _look_up_user_by_access_token(self, token):
-        ret = await self.store.get_user_by_access_token(token)
-        if not ret:
-            return None
-
-        # we use ret.get() below because *lots* of unit tests stub out
-        # get_user_by_access_token in a way where it only returns a couple of
-        # the fields.
-        user_info = {
-            "user": UserID.from_string(ret.get("name")),
-            "token_id": ret.get("token_id", None),
-            "is_guest": False,
-            "shadow_banned": ret.get("shadow_banned"),
-            "device_id": ret.get("device_id"),
-            "valid_until_ms": ret.get("valid_until_ms"),
-        }
-        return user_info
-
     def get_appservice_by_req(self, request):
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.authenticated_entity = service.sender
+        request.requester = synapse.types.create_requester(
+            service.sender, app_service=service
+        )
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
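
The auth.py changes above replace the loosely-typed user-info dict with a TokenLookupResult returned by the registration store, so callers use attribute access (user_info.token_id, user_info.token_owner) instead of dict lookups. A plausible shape for the class, inferred from the fields this diff uses — the defaults and the token_owner fallback are assumptions:

    import attr


    @attr.s(frozen=True, slots=True)
    class TokenLookupResult:
        """Sketch of the token-lookup result, inferred from its usage above."""

        user_id = attr.ib(type=str)
        is_guest = attr.ib(type=bool, default=False)
        shadow_banned = attr.ib(type=bool, default=False)
        token_id = attr.ib(type=int, default=None)  # None for guest tokens
        device_id = attr.ib(type=str, default=None)
        valid_until_ms = attr.ib(type=int, default=None)

        @property
        def token_owner(self) -> str:
            # Assumed fallback: with token puppeting the owner can differ
            # from user_id; this sketch simply defaults to the user.
            return self.user_id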
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3862d9c08f..3944780a42 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationServiceApi
@@ -52,11 +52,11 @@ class ApplicationService:
         self,
         token,
         hostname,
+        id,
+        sender,
         url=None,
         namespaces=None,
         hs_token=None,
-        sender=None,
-        id=None,
         protocols=None,
         rate_limited=True,
         ip_range_whitelist=None,
@@ -164,9 +164,9 @@ class ApplicationService:
         does_match = await self.matches_user_in_member_list(event.room_id, store)
         return does_match
 
-    @cached(num_args=1)
+    @cached(num_args=1, cache_context=True)
     async def matches_user_in_member_list(
-        self, room_id: str, store: "DataStore"
+        self, room_id: str, store: "DataStore", cache_context: _CacheContext,
     ) -> bool:
         """Check if this service is interested a room based upon it's membership
 
@@ -177,7 +177,9 @@ class ApplicationService:
         Returns:
             True if this service would like to know about this room.
         """
-        member_list = await store.get_users_in_room(room_id)
+        member_list = await store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
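
The cache_context change above chains invalidation between caches: matches_user_in_member_list passes cache_context.invalidate as on_invalidate to the inner cached call, so when get_users_in_room's entry for the room is invalidated, the outer cached result is dropped too. A minimal sketch of the idea, independent of Synapse's descriptor machinery:

    class SimpleCache:
        """Illustrative cache with chained invalidation callbacks."""

        def __init__(self):
            self._entries = {}    # key -> cached value
            self._callbacks = {}  # key -> callbacks to run on invalidation

        def get_or_compute(self, key, compute, on_invalidate=None):
            if key not in self._entries:
                self._entries[key] = compute()
            if on_invalidate is not None:
                self._callbacks.setdefault(key, []).append(on_invalidate)
            return self._entries[key]

        def invalidate(self, key):
            self._entries.pop(key, None)
            for callback in self._callbacks.pop(key, []):
                callback()  # cascade: drop entries that depended on this one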
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 6b7be28aee..d4e887a3e0 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -23,7 +23,6 @@ from string import Template
 import yaml
 
 from twisted.logger import (
-    ILogObserver,
     LogBeginner,
     STDLibLogObserver,
     eventAsText,
@@ -32,11 +31,9 @@ from twisted.logger import (
 
 import synapse
 from synapse.app import _base as appbase
-from synapse.logging._structured import (
-    reload_structured_logging,
-    setup_structured_logging,
-)
+from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
+from synapse.logging.filter import MetadataFilter
 from synapse.util.versionstring import get_version_string
 
 from ._base import Config, ConfigError
@@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template(
 # This is a YAML file containing a standard Python logging configuration
 # dictionary. See [1] for details on the valid settings.
 #
+# Synapse also supports structured logging for machine readable logs which can
+# be ingested by ELK stacks. See [2] for details.
+#
 # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
+# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md
 
 version: 1
 
@@ -176,11 +177,11 @@ class LoggingConfig(Config):
                 log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
 
 
-def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
+def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
     """
-    Set up Python stdlib logging.
+    Set up Python standard library logging.
     """
-    if log_config is None:
+    if log_config_path is None:
         log_format = (
             "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
             " - %(message)s"
@@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
         handler.setFormatter(formatter)
         logger.addHandler(handler)
     else:
-        logging.config.dictConfig(log_config)
+        # Load the logging configuration.
+        _load_logging_config(log_config_path)
 
     # We add a log record factory that runs all messages through the
     # LoggingContextFilter so that we get the context *at the time we log*
@@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
    # filter options, but care must be taken when using e.g. MemoryHandler to buffer
     # writes.
 
-    log_filter = LoggingContextFilter(request="")
+    log_context_filter = LoggingContextFilter(request="")
+    log_metadata_filter = MetadataFilter({"server_name": config.server_name})
     old_factory = logging.getLogRecordFactory()
 
     def factory(*args, **kwargs):
         record = old_factory(*args, **kwargs)
-        log_filter.filter(record)
+        log_context_filter.filter(record)
+        log_metadata_filter.filter(record)
         return record
 
     logging.setLogRecordFactory(factory)
@@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
     if not config.no_redirect_stdio:
         print("Redirected stdout/stderr to logs")
 
-    return observer
-
 
-def _reload_stdlib_logging(*args, log_config=None):
-    logger = logging.getLogger("")
+def _load_logging_config(log_config_path: str) -> None:
+    """
+    Configure logging from a log config path.
+    """
+    with open(log_config_path, "rb") as f:
+        log_config = yaml.safe_load(f.read())
 
     if not log_config:
-        logger.warning("Reloaded a blank config?")
+        logging.warning("Loaded a blank logging config?")
+
+    # If the old structured logging configuration is being used, convert it to
+    # the new style configuration.
+    if "structured" in log_config and log_config.get("structured"):
+        log_config = setup_structured_logging(log_config)
 
     logging.config.dictConfig(log_config)
 
 
+def _reload_logging_config(log_config_path):
+    """
+    Reload the log configuration from the file and apply it.
+    """
+    # If no log config path was given, it cannot be reloaded.
+    if log_config_path is None:
+        return
+
+    _load_logging_config(log_config_path)
+    logging.info("Reloaded log config from %s due to SIGHUP", log_config_path)
+
+
 def setup_logging(
     hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
-) -> ILogObserver:
+) -> None:
     """
     Set up the logging subsystem.
 
@@ -282,41 +305,18 @@ def setup_logging(
 
         logBeginner: The Twisted logBeginner to use.
 
-    Returns:
-        The "root" Twisted Logger observer, suitable for sending logs to from a
-        Logger instance.
     """
-    log_config = config.worker_log_config if use_worker_options else config.log_config
-
-    def read_config(*args, callback=None):
-        if log_config is None:
-            return None
-
-        with open(log_config, "rb") as f:
-            log_config_body = yaml.safe_load(f.read())
-
-        if callback:
-            callback(log_config=log_config_body)
-            logging.info("Reloaded log config from %s due to SIGHUP", log_config)
-
-        return log_config_body
+    log_config_path = (
+        config.worker_log_config if use_worker_options else config.log_config
+    )
 
-    log_config_body = read_config()
+    # Perform one-time logging configuration.
+    _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
+    # Add a SIGHUP handler to reload the logging configuration, if one is available.
+    appbase.register_sighup(_reload_logging_config, log_config_path)
 
-    if log_config_body and log_config_body.get("structured") is True:
-        logger = setup_structured_logging(
-            hs, config, log_config_body, logBeginner=logBeginner
-        )
-        appbase.register_sighup(read_config, callback=reload_structured_logging)
-    else:
-        logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
-        appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
-
-    # make sure that the first thing we log is a thing we can grep backwards
-    # for
+    # Log immediately so we can grep backwards.
     logging.warning("***** STARTING SERVER *****")
     logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
     logging.info("Server hostname: %s", config.server_name)
     logging.info("Instance name: %s", hs.get_instance_name())
-
-    return logger
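
With the rewrite above, every log configuration — including structured logging — is applied through logging.config.dictConfig; legacy configs with structured: true are converted by setup_structured_logging first, and a SIGHUP simply re-runs _load_logging_config. Below is a sketch of a post-change config as the dict that dictConfig receives; the handler and formatter classes are the ones this diff exports from synapse.logging, while the host, port, and levels are example values:

    import logging.config

    LOG_CONFIG = {
        "version": 1,
        "formatters": {
            "json": {"class": "synapse.logging.TerseJsonFormatter"},
        },
        "handlers": {
            "remote": {
                "class": "synapse.logging.RemoteHandler",
                "formatter": "json",
                "host": "127.0.0.1",  # example logging target
                "port": 9000,
                "maximum_buffer": 1000,
            },
        },
        "root": {"level": "INFO", "handlers": ["remote"]},
    }

    logging.config.dictConfig(LOG_CONFIG)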
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e203206865..8028663fa8 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -368,7 +368,7 @@ class FrozenEvent(EventBase):
         return self.__repr__()
 
     def __repr__(self):
-        return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
+        return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
             self.get("event_id", None),
             self.get("type", None),
             self.get("state_key", None),
@@ -451,7 +451,7 @@ class FrozenEventV2(EventBase):
         return self.__repr__()
 
     def __repr__(self):
-        return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+        return "<%s event_id=%r, type=%r, state_key=%r>" % (
             self.__class__.__name__,
             self.event_id,
             self.get("type", None),
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3a6b95631e..a0933fae88 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -154,7 +154,7 @@ class Authenticator:
         )
 
         logger.debug("Request from %s", origin)
-        request.authenticated_entity = origin
+        request.requester = origin
 
        # If we get a valid signed request from the other side, it's probably
        # alive
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3ed29a2c16..9fc8444228 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,9 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from prometheus_client import Counter
 
@@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
 class ApplicationServicesHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -247,7 +250,9 @@ class ApplicationServicesHandler:
                         service, "presence", new_token
                     )
 
-    async def _handle_typing(self, service: ApplicationService, new_token: int):
+    async def _handle_typing(
+        self, service: ApplicationService, new_token: int
+    ) -> List[JsonDict]:
         typing_source = self.event_sources.sources["typing"]
         # Get the typing events from just before current
         typing, _ = await typing_source.get_new_events_as(
@@ -259,7 +264,7 @@ class ApplicationServicesHandler:
         )
         return typing
 
-    async def _handle_receipts(self, service: ApplicationService):
+    async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
@@ -271,7 +276,7 @@ class ApplicationServicesHandler:
 
     async def _handle_presence(
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
-    ):
+    ) -> List[JsonDict]:
         events = []  # type: List[JsonDict]
         presence_source = self.event_sources.sources["presence"]
         from_key = await self.store.get_type_stream_id_for_appservice(
@@ -301,11 +306,11 @@ class ApplicationServicesHandler:
 
         return events
 
-    async def query_user_exists(self, user_id):
+    async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
         Args:
-            user_id(str): The user to query if they exist on any AS.
+            user_id: The user to query if they exist on any AS.
         Returns:
             True if this user exists on at least one application service.
         """
@@ -316,11 +321,13 @@ class ApplicationServicesHandler:
                 return True
         return False
 
-    async def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         """Check if an application service knows this room alias exists.
 
         Args:
-            room_alias(RoomAlias): The room alias to query.
+            room_alias: The room alias to query.
         Returns:
             namedtuple: with keys "room_id" and "servers" or None if no
             association can be found.
@@ -336,10 +343,13 @@ class ApplicationServicesHandler:
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = await self.store.get_association_from_room_alias(room_alias)
-                return result
+                return await self.store.get_association_from_room_alias(room_alias)
+
+        return None
 
-    async def query_3pe(self, kind, protocol, fields):
+    async def query_3pe(
+        self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+    ) -> List[JsonDict]:
         services = self._get_services_for_3pn(protocol)
 
         results = await make_deferred_yieldable(
@@ -361,7 +371,9 @@ class ApplicationServicesHandler:
 
         return ret
 
-    async def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(
+        self, only_protocol: Optional[str] = None
+    ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
         protocols = {}  # type: Dict[str, List[JsonDict]]
 
@@ -379,7 +391,7 @@ class ApplicationServicesHandler:
                 if info is not None:
                     protocols[p].append(info)
 
-        def _merge_instances(infos):
+        def _merge_instances(infos: List[JsonDict]) -> JsonDict:
             if not infos:
                 return {}
 
@@ -394,19 +406,17 @@ class ApplicationServicesHandler:
 
             return combined
 
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
+        return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
 
-        return protocols
-
-    async def _get_services_for_event(self, event):
+    async def _get_services_for_event(
+        self, event: EventBase
+    ) -> List[ApplicationService]:
         """Retrieve a list of application services interested in this event.
 
         Args:
-            event(Event): The event to check. Can be None if alias_list is not.
+            event: The event to check. Can be None if alias_list is not.
         Returns:
-            list<ApplicationService>: A list of services interested in this
-            event based on the service regex.
+            A list of services interested in this event based on the service regex.
         """
         services = self.store.get_app_services()
 
@@ -420,17 +430,15 @@ class ApplicationServicesHandler:
 
         return interested_list
 
-    def _get_services_for_user(self, user_id):
+    def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return interested_list
+        return [s for s in services if (s.is_interested_in_user(user_id))]
 
-    def _get_services_for_3pn(self, protocol):
+    def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return interested_list
+        return [s for s in services if s.is_interested_in_protocol(protocol)]
 
-    async def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id: str) -> bool:
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
@@ -445,9 +453,8 @@ class ApplicationServicesHandler:
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    async def _check_user_exists(self, user_id):
+    async def _check_user_exists(self, user_id: str) -> bool:
         unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = await self.query_user_exists(user_id)
-            return exists
+            return await self.query_user_exists(user_id)
         return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dd14ab69d7..ff103cbb92 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,20 @@ import logging
 import time
 import unicodedata
 import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
-import bcrypt  # type: ignore[import]
+import bcrypt
 import pymacaroons
 
 from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
@@ -982,17 +991,17 @@ class AuthHandler(BaseHandler):
                 # This might return an awaitable, if it does block the log out
                 # until it completes.
                 result = provider.on_logged_out(
-                    user_id=str(user_info["user"]),
-                    device_id=user_info["device_id"],
+                    user_id=user_info.user_id,
+                    device_id=user_info.device_id,
                     access_token=access_token,
                 )
                 if inspect.isawaitable(result):
                     await result
 
         # delete pushers associated with this access token
-        if user_info["token_id"] is not None:
+        if user_info.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                str(user_info["user"]), (user_info["token_id"],)
+                user_info.user_id, (user_info.token_id,)
             )
 
     async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9c0096bae5..31f91e0a1a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.async_helpers import Linearizer
-from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
 
@@ -928,7 +927,7 @@ class EventCreationHandler:
 
         # Ensure that we can round trip before trying to persist in db
         try:
-            dump = frozendict_json_encoder.encode(event.content)
+            dump = json_encoder.encode(event.content)
             json_decoder.decode(dump)
         except Exception:
             logger.exception("Failed to encode content: %r", event.content)
@@ -1100,34 +1099,13 @@ class EventCreationHandler:
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
-
-                def is_inviter_member_event(e):
-                    return e.type == EventTypes.Member and e.sender == event.sender
-
-                current_state_ids = await context.get_current_state_ids()
-
-                # We know this event is not an outlier, so this must be
-                # non-None.
-                assert current_state_ids is not None
-
-                state_to_include_ids = [
-                    e_id
-                    for k, e_id in current_state_ids.items()
-                    if k[0] in self.room_invite_state_types
-                    or k == (EventTypes.Member, event.sender)
-                ]
-
-                state_to_include = await self.store.get_events(state_to_include_ids)
-
-                event.unsigned["invite_room_state"] = [
-                    {
-                        "type": e.type,
-                        "state_key": e.state_key,
-                        "content": e.content,
-                        "sender": e.sender,
-                    }
-                    for e in state_to_include.values()
-                ]
+                event.unsigned[
+                    "invite_room_state"
+                ] = await self.store.get_stripped_room_state_from_event_context(
+                    context,
+                    self.room_invite_state_types,
+                    membership_user_id=event.sender,
+                )
 
                 invitee = UserID.from_string(event.state_key)
                 if not self.hs.is_mine(invitee):
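
The message.py change above moves the construction of invite_room_state into a store helper, but the stripped-event shape is unchanged from the removed code: each included state event is reduced to its type, state_key, content, and sender. Re-stating the removed logic as a standalone sketch (the helper names here are illustrative):

    def strip_event(e):
        # The stripped form produced by the removed code above.
        return {
            "type": e.type,
            "state_key": e.state_key,
            "content": e.content,
            "sender": e.sender,
        }


    def stripped_invite_state(current_state_ids, events_by_id, invite_state_types, inviter):
        # Keep the configured state types, plus the inviter's own membership event.
        state_to_include_ids = [
            e_id
            for k, e_id in current_state_ids.items()
            if k[0] in invite_state_types or k == ("m.room.member", inviter)
        ]
        return [strip_event(events_by_id[e_id]) for e_id in state_to_include_ids]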
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 49a00eed9c..8e014c9bb5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer
 
 MYPY = False
 if MYPY:
-    import synapse.server
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 class BasePresenceHandler(abc.ABC):
     """Parts of the PresenceHandler that are shared between workers and master"""
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
@@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC):
 
 
 class PresenceHandler(BasePresenceHandler):
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
@@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):
 
 
 class PresenceEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # We can't call get_presence_handler here because there's a cycle:
         #
         #   Presence -> Notifier -> PresenceEventSource -> Presence
@@ -1071,12 +1071,14 @@ class PresenceEventSource:
 
             users_interested_in = await self._get_interested_in(user, explicit_room_id)
 
-            user_ids_changed = set()
+            user_ids_changed = set()  # type: Collection[str]
             changed = None
             if from_key:
                 changed = stream_change_cache.get_all_entities_changed(from_key)
 
             if changed is not None and len(changed) < 500:
+                assert isinstance(user_ids_changed, set)
+
                 # For small deltas, its quicker to get all changes and then
                 # work out if we share a room or they're in our presence list
                 get_updates_counter.labels("stream").inc()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6f1d21674..ed1ff62599 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
             user_data = await self.auth.get_user_by_access_token(guest_access_token)
-            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+            if (
+                not user_data.is_guest
+                or UserID.from_string(user_data.user_id).localpart != localpart
+            ):
                 raise AuthError(
                     403,
                     "Cannot register taken user ID without valid guest "
@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
-            token_id = user_tuple["token_id"]
+            token_id = user_tuple.token_id
 
             await self.pusher_pool.add_pusher(
                 user_id=user_id,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index c5b1f1f1e1..e73031475f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
                 ratelimit=False,
             )
 
-        for invitee in invite_list:
+        # we avoid dropping the lock between invites, as otherwise joins can
+        # start coming in and making the createRoom slow.
+        #
+        # we also don't need to check the requester's shadow-ban here, as we
+        # have already done so above (and potentially emptied invite_list).
+        with (await self.room_member_handler.member_linearizer.queue((room_id,))):
             content = {}
             is_direct = config.get("is_direct", None)
             if is_direct:
                 content["is_direct"] = is_direct
 
-            # Note that update_membership with an action of "invite" can raise a
-            # ShadowBanError, but this was handled above by emptying invite_list.
-            _, last_stream_id = await self.room_member_handler.update_membership(
-                requester,
-                UserID.from_string(invitee),
-                room_id,
-                "invite",
-                ratelimit=False,
-                content=content,
-            )
+            for invitee in invite_list:
+                (
+                    _,
+                    last_stream_id,
+                ) = await self.room_member_handler.update_membership_locked(
+                    requester,
+                    UserID.from_string(invitee),
+                    room_id,
+                    "invite",
+                    ratelimit=False,
+                    content=content,
+                )
 
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
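
The room.py change above acquires the room's membership linearizer once and holds it across the whole invite loop (calling the lock-assuming update_membership_locked), rather than re-taking the lock per invite, so concurrent joins cannot interleave with, and slow down, createRoom. The shape of the pattern, with an asyncio.Lock standing in for Synapse's Linearizer:

    import asyncio

    room_locks = {}  # room_id -> asyncio.Lock (stand-in for the member Linearizer)


    async def invite_all(room_id, invitees, do_invite_locked):
        lock = room_locks.setdefault(room_id, asyncio.Lock())
        async with lock:  # taken once for the whole batch, not once per invite
            for invitee in invitees:
                await do_invite_locked(room_id, invitee)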
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0268288600..7e5e53a56f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -327,7 +327,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     # haproxy would have timed the request out anyway...
                    raise SynapseError(504, "took too long to process")
 
-                result = await self._update_membership(
+                result = await self.update_membership_locked(
                     requester,
                     target,
                     room_id,
@@ -342,7 +342,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         return result
 
-    async def _update_membership(
+    async def update_membership_locked(
         self,
         requester: Requester,
         target: UserID,
@@ -355,6 +355,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         content: Optional[dict] = None,
         require_consent: bool = True,
     ) -> Tuple[str, int]:
+        """Helper for update_membership.
+
+        Assumes that the membership linearizer is already held for the room.
+        """
         content_specified = bool(content)
         if content is None:
             content = {}
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8324632cb6..f409368802 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -359,7 +359,7 @@ class SimpleHttpClient:
                     agent=self.agent,
                     data=body_producer,
                     headers=headers,
-                    **self._extra_treq_args
+                    **self._extra_treq_args,
                 )  # type: defer.Deferred
 
                 # we use our own timeout mechanism rather than treq's as a workaround
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 65dbd339ac..c0919f8cb7 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
 from twisted.web.static import File, NoRangeStaticProducer
 from twisted.web.util import redirectTo
 
-import synapse.events
-import synapse.metrics
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
@@ -620,7 +618,7 @@ def respond_with_json(
     if pretty_print:
         encoder = iterencode_pretty_printed_json
     else:
-        if canonical_json or synapse.events.USE_FROZEN_DICTS:
+        if canonical_json:
             encoder = iterencode_canonical_json
         else:
             encoder = _encode_json_bytes
diff --git a/synapse/http/site.py b/synapse/http/site.py
index ddb1770b09..5f0581dc3f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
 
 from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
@@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
 from synapse.http import redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.types import Requester
 
 logger = logging.getLogger(__name__)
 
@@ -54,9 +55,12 @@ class SynapseRequest(Request):
         Request.__init__(self, channel, *args, **kw)
         self.site = channel.site
         self._channel = channel  # this is used by the tests
-        self.authenticated_entity = None
         self.start_time = 0.0
 
+        # The requester, if authenticated. For federation requests this is the
+        # server name, for client requests this is the Requester object.
+        self.requester = None  # type: Optional[Union[Requester, str]]
+
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext = None  # type: Optional[LoggingContext]
 
@@ -271,11 +275,23 @@ class SynapseRequest(Request):
         # to the client (nb may be negative)
         response_send_time = self.finish_time - self._processing_finished_time
 
-        # need to decode as it could be raw utf-8 bytes
-        # from a IDN servname in an auth header
-        authenticated_entity = self.authenticated_entity
-        if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
-            authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+        # Convert the requester into a string that we can log
+        authenticated_entity = None
+        if isinstance(self.requester, str):
+            authenticated_entity = self.requester
+        elif isinstance(self.requester, Requester):
+            authenticated_entity = self.requester.authenticated_entity
+
+            # If this is a request where the target user doesn't match the user who
+            # authenticated (e.g. an admin is puppeting a user) then we log both.
+            if self.requester.user.to_string() != authenticated_entity:
+                authenticated_entity = "{},{}".format(
+                    authenticated_entity, self.requester.user.to_string(),
+                )
+        elif self.requester is not None:
+            # This shouldn't happen, but we log it so we don't lose information
+            # and can see that we're doing something wrong.
+            authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
         # ...or could be raw utf-8 bytes in the User-Agent header.
         # N.B. if you don't do this, the logger explodes cryptically
diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py
index e69de29bb2..b28b7b2ef7 100644
--- a/synapse/logging/__init__.py
+++ b/synapse/logging/__init__.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These are imported to allow for nicer logging configuration files.
+from synapse.logging._remote import RemoteHandler
+from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+
+__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"]
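
Because RemoteHandler (below) is now a plain logging.Handler, these classes can also be attached programmatically rather than via a config file. A sketch — the host, port, and no-argument formatter construction are assumptions:

    import logging

    from synapse.logging import RemoteHandler, TerseJsonFormatter

    handler = RemoteHandler("127.0.0.1", 9000, maximum_buffer=1000)
    handler.setFormatter(TerseJsonFormatter())
    logging.getLogger().addHandler(handler)
    logging.warning("shipped to the remote TCP target as JSON")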
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 0caf325916..fb937b3f28 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import sys
 import traceback
 from collections import deque
@@ -21,10 +22,11 @@ from math import floor
 from typing import Callable, Optional
 
 import attr
+from typing_extensions import Deque
 from zope.interface import implementer
 
 from twisted.application.internet import ClientService
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 from twisted.internet.endpoints import (
     HostnameEndpoint,
     TCP4ClientEndpoint,
@@ -32,7 +34,9 @@ from twisted.internet.endpoints import (
 )
 from twisted.internet.interfaces import IPushProducer, ITransport
 from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import ILogObserver, Logger, LogLevel
+from twisted.python.failure import Failure
+
+logger = logging.getLogger(__name__)
 
 
 @attr.s
@@ -45,11 +49,11 @@ class LogProducer:
     Args:
         buffer: Log buffer to read logs from.
         transport: Transport to write to.
-        format_event: A callable to format the log entry to a string.
+        format: A callable to format the log record to a string.
     """
 
     transport = attr.ib(type=ITransport)
-    format_event = attr.ib(type=Callable[[dict], str])
+    _format = attr.ib(type=Callable[[logging.LogRecord], str])
     _buffer = attr.ib(type=deque)
     _paused = attr.ib(default=False, type=bool, init=False)
 
@@ -61,16 +65,19 @@ class LogProducer:
         self._buffer = deque()
 
     def resumeProducing(self):
+        # If we're already producing, nothing to do.
         self._paused = False
 
+        # Loop until paused.
         while self._paused is False and (self._buffer and self.transport.connected):
             try:
-                # Request the next event and format it.
-                event = self._buffer.popleft()
-                msg = self.format_event(event)
+                # Request the next record and format it.
+                record = self._buffer.popleft()
+                msg = self._format(record)
 
                 # Send it as a new line over the transport.
                 self.transport.write(msg.encode("utf8"))
+                self.transport.write(b"\n")
             except Exception:
                 # Something has gone wrong writing to the transport -- log it
                 # and break out of the while.
@@ -78,76 +85,85 @@ class LogProducer:
                 break
 
 
-@attr.s
-@implementer(ILogObserver)
-class TCPLogObserver:
+class RemoteHandler(logging.Handler):
     """
-    An IObserver that writes JSON logs to a TCP target.
+    A logging handler that writes logs to a TCP target.
 
     Args:
-        hs (HomeServer): The homeserver that is being logged for.
         host: The host of the logging target.
         port: The logging target's port.
-        format_event: A callable to format the log entry to a string.
         maximum_buffer: The maximum buffer size.
     """
 
-    hs = attr.ib()
-    host = attr.ib(type=str)
-    port = attr.ib(type=int)
-    format_event = attr.ib(type=Callable[[dict], str])
-    maximum_buffer = attr.ib(type=int)
-    _buffer = attr.ib(default=attr.Factory(deque), type=deque)
-    _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
-    _logger = attr.ib(default=attr.Factory(Logger))
-    _producer = attr.ib(default=None, type=Optional[LogProducer])
-
-    def start(self) -> None:
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        maximum_buffer: int = 1000,
+        level=logging.NOTSET,
+        _reactor=None,
+    ):
+        super().__init__(level=level)
+        self.host = host
+        self.port = port
+        self.maximum_buffer = maximum_buffer
+
+        self._buffer = deque()  # type: Deque[logging.LogRecord]
+        self._connection_waiter = None  # type: Optional[Deferred]
+        self._producer = None  # type: Optional[LogProducer]
 
         # Connect without DNS lookups if it's a direct IP.
+        if _reactor is None:
+            from twisted.internet import reactor
+
+            _reactor = reactor
+
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(
-                    self.hs.get_reactor(), self.host, self.port
-                )
+                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
             elif isinstance(ip, IPv6Address):
-                endpoint = TCP6ClientEndpoint(
-                    self.hs.get_reactor(), self.host, self.port
-                )
+                endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
                 raise ValueError("Unknown IP address provided: %s" % (self.host,))
         except ValueError:
-            endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+            endpoint = HostnameEndpoint(_reactor, self.host, self.port)
 
         factory = Factory.forProtocol(Protocol)
-        self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+        self._service = ClientService(endpoint, factory, clock=_reactor)
         self._service.startService()
+        self._stopping = False
         self._connect()
 
-    def stop(self):
+    def close(self):
+        self._stopping = True
         self._service.stopService()
 
     def _connect(self) -> None:
         """
         Triggers an attempt to connect then write to the remote if not already writing.
         """
+        # Do not attempt to open multiple connections.
         if self._connection_waiter:
             return
 
         self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
 
-        @self._connection_waiter.addErrback
-        def fail(r):
-            r.printTraceback(file=sys.__stderr__)
+        def fail(failure: Failure) -> None:
+            # If the Deferred was cancelled (e.g. during shutdown) do not try to
+            # reconnect (this will cause an infinite loop of errors).
+            if failure.check(CancelledError) and self._stopping:
+                return
+
+            # For a different error, print the traceback and re-connect.
+            failure.printTraceback(file=sys.__stderr__)
             self._connection_waiter = None
             self._connect()
 
-        @self._connection_waiter.addCallback
-        def writer(r):
+        def writer(result: Protocol) -> None:
             # We have a connection. If we already have a producer, and its
             # transport is the same, just trigger a resumeProducing.
-            if self._producer and r.transport is self._producer.transport:
+            if self._producer and result.transport is self._producer.transport:
                 self._producer.resumeProducing()
                 self._connection_waiter = None
                 return
@@ -158,29 +174,29 @@ class TCPLogObserver:
 
             # Make a new producer and start it.
             self._producer = LogProducer(
-                buffer=self._buffer,
-                transport=r.transport,
-                format_event=self.format_event,
+                buffer=self._buffer, transport=result.transport, format=self.format,
             )
-            r.transport.registerProducer(self._producer, True)
+            result.transport.registerProducer(self._producer, True)
             self._producer.resumeProducing()
             self._connection_waiter = None
 
+        self._connection_waiter.addCallbacks(writer, fail)
+
     def _handle_pressure(self) -> None:
         """
-        Handle backpressure by shedding events.
+        Handle backpressure by shedding records.
 
        The buffer will be shed, in this order, until it is below the maximum:
-            - Shed DEBUG events
-            - Shed INFO events
-            - Shed the middle 50% of the events.
+            - Shed DEBUG records.
+            - Shed INFO records.
+            - Shed the middle 50% of the records.
         """
         if len(self._buffer) <= self.maximum_buffer:
             return
 
         # Strip out DEBUGs
         self._buffer = deque(
-            filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
+            filter(lambda record: record.levelno > logging.DEBUG, self._buffer)
         )
 
         if len(self._buffer) <= self.maximum_buffer:
@@ -188,7 +204,7 @@ class TCPLogObserver:
 
         # Strip out INFOs
         self._buffer = deque(
-            filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
+            filter(lambda record: record.levelno > logging.INFO, self._buffer)
         )
 
         if len(self._buffer) <= self.maximum_buffer:
@@ -209,17 +225,17 @@ class TCPLogObserver:
 
         self._buffer.extend(reversed(end_buffer))
 
-    def __call__(self, event: dict) -> None:
-        self._buffer.append(event)
+    def emit(self, record: logging.LogRecord) -> None:
+        self._buffer.append(record)
 
         # Handle backpressure, if it exists.
         try:
             self._handle_pressure()
         except Exception:
-            # If handling backpressure fails,clear the buffer and log the
+            # If handling backpressure fails, clear the buffer and log the
             # exception.
             self._buffer.clear()
-            self._logger.failure("Failed clearing backpressure")
+            logger.warning("Failed clearing backpressure")
 
         # Try and write immediately.
         self._connect()
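
The _handle_pressure strategy above sheds records in stages: drop DEBUG records, then INFO records, and finally the middle 50% of whatever remains, preserving the oldest and newest entries. A standalone sketch of that last step:

    from collections import deque


    def shed_middle_half(buffer):
        # Keep the first and last quarters; drop the middle 50% of records.
        quarter = len(buffer) // 4
        items = list(buffer)
        return deque(items[:quarter] + items[-quarter:])


    print(shed_middle_half(deque(range(8))))  # deque([0, 1, 6, 7])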
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 0fc2ea609e..14d9c104c2 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,138 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import os.path
-import sys
-import typing
-import warnings
-from typing import List
+from typing import Any, Dict, Generator, Optional, Tuple
 
-import attr
-from constantly import NamedConstant, Names, ValueConstant, Values
-from zope.interface import implementer
-
-from twisted.logger import (
-    FileLogObserver,
-    FilteringLogObserver,
-    ILogObserver,
-    LogBeginner,
-    Logger,
-    LogLevel,
-    LogLevelFilterPredicate,
-    LogPublisher,
-    eventAsText,
-    jsonFileLogObserver,
-)
+from constantly import NamedConstant, Names
 
 from synapse.config._base import ConfigError
-from synapse.logging._terse_json import (
-    TerseJSONToConsoleLogObserver,
-    TerseJSONToTCPLogObserver,
-)
-from synapse.logging.context import current_context
-
-
-def stdlib_log_level_to_twisted(level: str) -> LogLevel:
-    """
-    Convert a stdlib log level to Twisted's log level.
-    """
-    lvl = level.lower().replace("warning", "warn")
-    return LogLevel.levelWithName(lvl)
-
-
-@attr.s
-@implementer(ILogObserver)
-class LogContextObserver:
-    """
-    An ILogObserver which adds Synapse-specific log context information.
-
-    Attributes:
-        observer (ILogObserver): The target parent observer.
-    """
-
-    observer = attr.ib()
-
-    def __call__(self, event: dict) -> None:
-        """
-        Consume a log event and emit it to the parent observer after filtering
-        and adding log context information.
-
-        Args:
-            event (dict)
-        """
-        # Filter out some useless events that Twisted outputs
-        if "log_text" in event:
-            if event["log_text"].startswith("DNSDatagramProtocol starting on "):
-                return
-
-            if event["log_text"].startswith("(UDP Port "):
-                return
-
-            if event["log_text"].startswith("Timing out client") or event[
-                "log_format"
-            ].startswith("Timing out client"):
-                return
-
-        context = current_context()
-
-        # Copy the context information to the log event.
-        context.copy_to_twisted_log_entry(event)
-
-        self.observer(event)
-
-
-class PythonStdlibToTwistedLogger(logging.Handler):
-    """
-    Transform a Python stdlib log message into a Twisted one.
-    """
-
-    def __init__(self, observer, *args, **kwargs):
-        """
-        Args:
-            observer (ILogObserver): A Twisted logging observer.
-            *args, **kwargs: Args/kwargs to be passed to logging.Handler.
-        """
-        self.observer = observer
-        super().__init__(*args, **kwargs)
-
-    def emit(self, record: logging.LogRecord) -> None:
-        """
-        Emit a record to Twisted's observer.
-
-        Args:
-            record (logging.LogRecord)
-        """
-
-        self.observer(
-            {
-                "log_time": record.created,
-                "log_text": record.getMessage(),
-                "log_format": "{log_text}",
-                "log_namespace": record.name,
-                "log_level": stdlib_log_level_to_twisted(record.levelname),
-            }
-        )
-
-
-def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
-    """
-    A log observer that formats events like the traditional log formatter and
-    sends them to `outFile`.
-
-    Args:
-        outFile (file object): The file object to write to.
-    """
-
-    def formatEvent(_event: dict) -> str:
-        event = dict(_event)
-        event["log_level"] = event["log_level"].name.upper()
-        event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
-            event.get("log_format", "{log_text}") or "{log_text}"
-        )
-        return eventAsText(event, includeSystem=False) + "\n"
-
-    return FileLogObserver(outFile, formatEvent)
 
 
 class DrainType(Names):
@@ -155,30 +29,12 @@ class DrainType(Names):
     NETWORK_JSON_TERSE = NamedConstant()
 
 
-class OutputPipeType(Values):
-    stdout = ValueConstant(sys.__stdout__)
-    stderr = ValueConstant(sys.__stderr__)
-
-
-@attr.s
-class DrainConfiguration:
-    name = attr.ib()
-    type = attr.ib()
-    location = attr.ib()
-    options = attr.ib(default=None)
-
-
-@attr.s
-class NetworkJSONTerseOptions:
-    maximum_buffer = attr.ib(type=int)
-
-
-DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
+DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
 
 
 def parse_drain_configs(
     drains: dict,
-) -> typing.Generator[DrainConfiguration, None, None]:
+) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
     """
     Parse the drain configurations.
 
@@ -186,11 +42,12 @@ def parse_drain_configs(
         drains (dict): A dict mapping drain names to drain configurations.
 
     Yields:
-        DrainConfiguration instances.
+        Tuples of (handler name, handler configuration dict), one per drain.
 
     Raises:
         ConfigError: If any of the drain configuration items are invalid.
     """
+
     for name, config in drains.items():
         if "type" not in config:
             raise ConfigError("Logging drains require a 'type' key.")
@@ -202,6 +59,18 @@ def parse_drain_configs(
                 "%s is not a known logging drain type." % (config["type"],)
             )
 
+        # Either use the default formatter or one of the JSON formatters.
+        if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
+            formatter = "json"  # type: Optional[str]
+        elif logging_type in (
+            DrainType.CONSOLE_JSON_TERSE,
+            DrainType.NETWORK_JSON_TERSE,
+        ):
+            formatter = "tersejson"
+        else:
+            # A formatter of None implies using the default formatter.
+            formatter = None
+
         if logging_type in [
             DrainType.CONSOLE,
             DrainType.CONSOLE_JSON,
@@ -217,9 +86,11 @@ def parse_drain_configs(
                     % (logging_type,)
                 )
 
-            pipe = OutputPipeType.lookupByName(location).value
-
-            yield DrainConfiguration(name=name, type=logging_type, location=pipe)
+            yield name, {
+                "class": "logging.StreamHandler",
+                "formatter": formatter,
+                "stream": "ext://sys." + location,
+            }
 
         elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
             if "location" not in config:
@@ -233,18 +104,25 @@ def parse_drain_configs(
                     "File paths need to be absolute, '%s' is a relative path"
                     % (location,)
                 )
-            yield DrainConfiguration(name=name, type=logging_type, location=location)
+
+            yield name, {
+                "class": "logging.FileHandler",
+                "formatter": formatter,
+                "filename": location,
+            }
 
         elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
             host = config.get("host")
             port = config.get("port")
             maximum_buffer = config.get("maximum_buffer", 1000)
-            yield DrainConfiguration(
-                name=name,
-                type=logging_type,
-                location=(host, port),
-                options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
-            )
+
+            yield name, {
+                "class": "synapse.logging.RemoteHandler",
+                "formatter": formatter,
+                "host": host,
+                "port": port,
+                "maximum_buffer": maximum_buffer,
+            }
 
         else:
             raise ConfigError(
@@ -253,126 +131,29 @@ def parse_drain_configs(
             )
 
 
-class StoppableLogPublisher(LogPublisher):
+def setup_structured_logging(log_config: dict,) -> dict:
     """
-    A log publisher that can tell its observers to shut down any external
-    communications.
-    """
-
-    def stop(self):
-        for obs in self._observers:
-            if hasattr(obs, "stop"):
-                obs.stop()
-
-
-def setup_structured_logging(
-    hs,
-    config,
-    log_config: dict,
-    logBeginner: LogBeginner,
-    redirect_stdlib_logging: bool = True,
-) -> LogPublisher:
-    """
-    Set up Twisted's structured logging system.
-
-    Args:
-        hs: The homeserver to use.
-        config (HomeserverConfig): The configuration of the Synapse homeserver.
-        log_config (dict): The log configuration to use.
+    Convert a legacy structured logging configuration (from Synapse < v1.23.0)
+    to one compatible with the new standard library handlers.
     """
-    if config.no_redirect_stdio:
-        raise ConfigError(
-            "no_redirect_stdio cannot be defined using structured logging."
-        )
-
-    logger = Logger()
-
     if "drains" not in log_config:
         raise ConfigError("The logging configuration requires a list of drains.")
 
-    observers = []  # type: List[ILogObserver]
-
-    for observer in parse_drain_configs(log_config["drains"]):
-        # Pipe drains
-        if observer.type == DrainType.CONSOLE:
-            logger.debug(
-                "Starting up the {name} console logger drain", name=observer.name
-            )
-            observers.append(SynapseFileLogObserver(observer.location))
-        elif observer.type == DrainType.CONSOLE_JSON:
-            logger.debug(
-                "Starting up the {name} JSON console logger drain", name=observer.name
-            )
-            observers.append(jsonFileLogObserver(observer.location))
-        elif observer.type == DrainType.CONSOLE_JSON_TERSE:
-            logger.debug(
-                "Starting up the {name} terse JSON console logger drain",
-                name=observer.name,
-            )
-            observers.append(
-                TerseJSONToConsoleLogObserver(observer.location, metadata={})
-            )
-
-        # File drains
-        elif observer.type == DrainType.FILE:
-            logger.debug("Starting up the {name} file logger drain", name=observer.name)
-            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
-            observers.append(SynapseFileLogObserver(log_file))
-        elif observer.type == DrainType.FILE_JSON:
-            logger.debug(
-                "Starting up the {name} JSON file logger drain", name=observer.name
-            )
-            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
-            observers.append(jsonFileLogObserver(log_file))
-
-        elif observer.type == DrainType.NETWORK_JSON_TERSE:
-            metadata = {"server_name": hs.config.server_name}
-            log_observer = TerseJSONToTCPLogObserver(
-                hs=hs,
-                host=observer.location[0],
-                port=observer.location[1],
-                metadata=metadata,
-                maximum_buffer=observer.options.maximum_buffer,
-            )
-            log_observer.start()
-            observers.append(log_observer)
-        else:
-            # We should never get here, but, just in case, throw an error.
-            raise ConfigError("%s drain type cannot be configured" % (observer.type,))
-
-    publisher = StoppableLogPublisher(*observers)
-    log_filter = LogLevelFilterPredicate()
-
-    for namespace, namespace_config in log_config.get(
-        "loggers", DEFAULT_LOGGERS
-    ).items():
-        # Set the log level for twisted.logger.Logger namespaces
-        log_filter.setLogLevelForNamespace(
-            namespace,
-            stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
-        )
-
-        # Also set the log levels for the stdlib logger namespaces, to prevent
-        # them getting to PythonStdlibToTwistedLogger and having to be formatted
-        if "level" in namespace_config:
-            logging.getLogger(namespace).setLevel(namespace_config.get("level"))
-
-    f = FilteringLogObserver(publisher, [log_filter])
-    lco = LogContextObserver(f)
-
-    if redirect_stdlib_logging:
-        stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
-        stdliblogger = logging.getLogger()
-        stdliblogger.addHandler(stuff_into_twisted)
-
-    # Always redirect standard I/O, otherwise other logging outputs might miss
-    # it.
-    logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
+    new_config = {
+        "version": 1,
+        "formatters": {
+            "json": {"class": "synapse.logging.JsonFormatter"},
+            "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
+        },
+        "handlers": {},
+        "loggers": log_config.get("loggers", DEFAULT_LOGGERS),
+        "root": {"handlers": []},
+    }
 
-    return publisher
+    for handler_name, handler in parse_drain_configs(log_config["drains"]):
+        new_config["handlers"][handler_name] = handler
 
+        # Add each handler to the root logger.
+        new_config["root"]["handlers"].append(handler_name)
 
-def reload_structured_logging(*args, log_config=None) -> None:
-    warnings.warn(
-        "Currently the structured logging system can not be reloaded, doing nothing"
-    )
+    return new_config
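
To make the conversion concrete, here is a sketch of what a small legacy `drains` section turns into (drain names and addresses are illustrative; the formatters are the two defined in `new_config` above):

    legacy = {
        "drains": {
            "console": {"type": "console", "location": "stdout"},
            "remote": {
                "type": "network_json_terse",
                "host": "127.0.0.1",
                "port": 9000,
            },
        }
    }

    # setup_structured_logging(legacy) should yield handlers roughly like:
    expected_handlers = {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": None,  # None means the default formatter
            "stream": "ext://sys.stdout",
        },
        "remote": {
            "class": "synapse.logging.RemoteHandler",
            "formatter": "tersejson",
            "host": "127.0.0.1",
            "port": 9000,
            "maximum_buffer": 1000,
        },
    }

Both handler names also end up in the root logger's "handlers" list.
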
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 9b46956ca9..2fbf5549a1 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -16,141 +16,65 @@
 """
 Log formatters that output terse JSON.
 """
-
 import json
-from typing import IO
-
-from twisted.logger import FileLogObserver
-
-from synapse.logging._remote import TCPLogObserver
+import logging
 
 _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
 
-
-def flatten_event(event: dict, metadata: dict, include_time: bool = False):
-    """
-    Flatten a Twisted logging event to an dictionary capable of being sent
-    as a log event to a logging aggregation system.
-
-    The format is vastly simplified and is not designed to be a "human readable
-    string" in the sense that traditional logs are. Instead, the structure is
-    optimised for searchability and filtering, with human-understandable log
-    keys.
-
-    Args:
-        event (dict): The Twisted logging event we are flattening.
-        metadata (dict): Additional data to include with each log message. This
-            can be information like the server name. Since the target log
-            consumer does not know who we are other than by host IP, this
-            allows us to forward through static information.
-        include_time (bool): Should we include the `time` key? If False, the
-            event time is stripped from the event.
-    """
-    new_event = {}
-
-    # If it's a failure, make the new event's log_failure be the traceback text.
-    if "log_failure" in event:
-        new_event["log_failure"] = event["log_failure"].getTraceback()
-
-    # If it's a warning, copy over a string representation of the warning.
-    if "warning" in event:
-        new_event["warning"] = str(event["warning"])
-
-    # Stdlib logging events have "log_text" as their human-readable portion,
-    # Twisted ones have "log_format". For now, include the log_format, so that
-    # context only given in the log format (e.g. what is being logged) is
-    # available.
-    if "log_text" in event:
-        new_event["log"] = event["log_text"]
-    else:
-        new_event["log"] = event["log_format"]
-
-    # We want to include the timestamp when forwarding over the network, but
-    # exclude it when we are writing to stdout. This is because the log ingester
-    # (e.g. logstash, fluentd) can add its own timestamp.
-    if include_time:
-        new_event["time"] = round(event["log_time"], 2)
-
-    # Convert the log level to a textual representation.
-    new_event["level"] = event["log_level"].name.upper()
-
-    # Ignore these keys, and do not transfer them over to the new log object.
-    # They are either useless (isError), transferred manually above (log_time,
-    # log_level, etc), or contain Python objects which are not useful for output
-    # (log_logger, log_source).
-    keys_to_delete = [
-        "isError",
-        "log_failure",
-        "log_format",
-        "log_level",
-        "log_logger",
-        "log_source",
-        "log_system",
-        "log_time",
-        "log_text",
-        "observer",
-        "warning",
-    ]
-
-    # If it's from the Twisted legacy logger (twisted.python.log), it adds some
-    # more keys we want to purge.
-    if event.get("log_namespace") == "log_legacy":
-        keys_to_delete.extend(["message", "system", "time"])
-
-    # Rather than modify the dictionary in place, construct a new one with only
-    # the content we want. The original event should be considered 'frozen'.
-    for key in event.keys():
-
-        if key in keys_to_delete:
-            continue
-
-        if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
-            # If it's a plain type, include it as is.
-            new_event[key] = event[key]
-        else:
-            # If it's not one of those basic types, write out a string
-            # representation. This should probably be a warning in development,
-            # so that we are sure we are only outputting useful data.
-            new_event[key] = str(event[key])
-
-    # Add the metadata information to the event (e.g. the server_name).
-    new_event.update(metadata)
-
-    return new_event
-
-
-def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
-    """
-    A log observer that formats events to a flattened JSON representation.
-
-    Args:
-        outFile: The file object to write to.
-        metadata: Metadata to be added to each log object.
-    """
-
-    def formatEvent(_event: dict) -> str:
-        flattened = flatten_event(_event, metadata)
-        return _encoder.encode(flattened) + "\n"
-
-    return FileLogObserver(outFile, formatEvent)
-
-
-def TerseJSONToTCPLogObserver(
-    hs, host: str, port: int, metadata: dict, maximum_buffer: int
-) -> FileLogObserver:
-    """
-    A log observer that formats events to a flattened JSON representation.
-
-    Args:
-        hs (HomeServer): The homeserver that is being logged for.
-        host: The host of the logging target.
-        port: The logging target's port.
-        metadata: Metadata to be added to each log object.
-        maximum_buffer: The maximum buffer size.
-    """
-
-    def formatEvent(_event: dict) -> str:
-        flattened = flatten_event(_event, metadata, include_time=True)
-        return _encoder.encode(flattened) + "\n"
-
-    return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
+# The properties of a standard LogRecord.
+_LOG_RECORD_ATTRIBUTES = {
+    "args",
+    "asctime",
+    "created",
+    "exc_info",
+    # exc_text isn't a public attribute, but is used to cache the result of formatException.
+    "exc_text",
+    "filename",
+    "funcName",
+    "levelname",
+    "levelno",
+    "lineno",
+    "message",
+    "module",
+    "msecs",
+    "msg",
+    "name",
+    "pathname",
+    "process",
+    "processName",
+    "relativeCreated",
+    "stack_info",
+    "thread",
+    "threadName",
+}
+
+
+class JsonFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        event = {
+            "log": record.getMessage(),
+            "namespace": record.name,
+            "level": record.levelname,
+        }
+
+        return self._format(record, event)
+
+    def _format(self, record: logging.LogRecord, event: dict) -> str:
+        # Add any extra attributes to the event.
+        for key, value in record.__dict__.items():
+            if key not in _LOG_RECORD_ATTRIBUTES:
+                event[key] = value
+
+        return _encoder.encode(event)
+
+
+class TerseJsonFormatter(JsonFormatter):
+    def format(self, record: logging.LogRecord) -> str:
+        event = {
+            "log": record.getMessage(),
+            "namespace": record.name,
+            "level": record.levelname,
+            "time": round(record.created, 2),
+        }
+
+        return self._format(record, event)
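
A quick way to see what these formatters emit, assuming the module is importable as synapse.logging._terse_json (the logger name is illustrative):

    import logging
    import sys

    from synapse.logging._terse_json import TerseJsonFormatter

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(TerseJsonFormatter())
    log = logging.getLogger("synapse.demo")
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    # Keys passed via `extra` are not in _LOG_RECORD_ATTRIBUTES, so they
    # are serialised as top-level JSON fields.
    log.info("connected", extra={"server_name": "example.com"})
    # e.g. {"log":"connected","namespace":"synapse.demo","level":"INFO",
    #       "time":1604059200.0,"server_name":"example.com"}
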
diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py
new file mode 100644
index 0000000000..1baf8dd679
--- /dev/null
+++ b/synapse/logging/filter.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from typing_extensions import Literal
+
+
+class MetadataFilter(logging.Filter):
+    """Logging filter that adds constant values to each record.
+
+    Args:
+        metadata: Key-value pairs to add to each record.
+    """
+
+    def __init__(self, metadata: dict):
+        self._metadata = metadata
+
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
+        for key, value in self._metadata.items():
+            setattr(record, key, value)
+        return True
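
Usage sketch: attach the filter to a handler so that every record carries static metadata (the server name here is illustrative), which the JSON formatters above then serialise as extra keys:

    import logging

    from synapse.logging.filter import MetadataFilter

    handler = logging.StreamHandler()
    handler.addFilter(MetadataFilter({"server_name": "example.com"}))
    logging.getLogger().addHandler(handler)

    # record.server_name is now set on every record this handler sees.
    logging.warning("disk is filling up")
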
diff --git a/synapse/notifier.py b/synapse/notifier.py
index eb56b26f21..a17352ef46 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -28,6 +28,7 @@ from typing import (
     Union,
 )
 
+import attr
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
         return bool(self.events)
 
 
+@attr.s(slots=True, frozen=True)
+class _PendingRoomEventEntry:
+    event_pos = attr.ib(type=PersistedEventPosition)
+    extra_users = attr.ib(type=Collection[UserID])
+
+    room_id = attr.ib(type=str)
+    type = attr.ib(type=str)
+    state_key = attr.ib(type=Optional[str])
+    membership = attr.ib(type=Optional[str])
+
+
 class Notifier:
     """ This class is responsible for notifying any listeners when there are
     new events available for it.
@@ -190,9 +202,7 @@ class Notifier:
         self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
-        self.pending_new_room_events = (
-            []
-        )  # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
+        self.pending_new_room_events = []  # type: List[_PendingRoomEventEntry]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -255,7 +265,29 @@ class Notifier:
         max_room_stream_token: RoomStreamToken,
         extra_users: Collection[UserID] = [],
     ):
-        """ Used by handlers to inform the notifier something has happened
+        """Unwraps event and calls `on_new_room_event_args`.
+        """
+        self.on_new_room_event_args(
+            event_pos=event_pos,
+            room_id=event.room_id,
+            event_type=event.type,
+            state_key=event.get("state_key"),
+            membership=event.content.get("membership"),
+            max_room_stream_token=max_room_stream_token,
+            extra_users=extra_users,
+        )
+
+    def on_new_room_event_args(
+        self,
+        room_id: str,
+        event_type: str,
+        state_key: Optional[str],
+        membership: Optional[str],
+        event_pos: PersistedEventPosition,
+        max_room_stream_token: RoomStreamToken,
+        extra_users: Collection[UserID] = [],
+    ):
+        """Used by handlers to inform the notifier something has happened
         in the room, room event wise.
 
         This triggers the notifier to wake up any listeners that are
@@ -266,7 +298,16 @@ class Notifier:
         until all previous events have been persisted before notifying
         the client streams.
         """
-        self.pending_new_room_events.append((event_pos, event, extra_users))
+        self.pending_new_room_events.append(
+            _PendingRoomEventEntry(
+                event_pos=event_pos,
+                extra_users=extra_users,
+                room_id=room_id,
+                type=event_type,
+                state_key=state_key,
+                membership=membership,
+            )
+        )
         self._notify_pending_new_room_events(max_room_stream_token)
 
         self.notify_replication()
@@ -284,18 +325,19 @@ class Notifier:
         users = set()  # type: Set[UserID]
         rooms = set()  # type: Set[str]
 
-        for event_pos, event, extra_users in pending:
-            if event_pos.persisted_after(max_room_stream_token):
-                self.pending_new_room_events.append((event_pos, event, extra_users))
+        for entry in pending:
+            if entry.event_pos.persisted_after(max_room_stream_token):
+                self.pending_new_room_events.append(entry)
             else:
                 if (
-                    event.type == EventTypes.Member
-                    and event.membership == Membership.JOIN
+                    entry.type == EventTypes.Member
+                    and entry.membership == Membership.JOIN
+                    and entry.state_key
                 ):
-                    self._user_joined_room(event.state_key, event.room_id)
+                    self._user_joined_room(entry.state_key, entry.room_id)
 
-                users.update(extra_users)
-                rooms.add(event.room_id)
+                users.update(entry.extra_users)
+                rooms.add(entry.room_id)
 
         if users or rooms:
             self.on_new_event(
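
The point of the `*_args` split shows up in the replication client below: a worker that only has stream-row fields can now notify without loading the full event from the database. The entry type itself is a straightforward frozen attrs class; a stand-in sketch (the class name `Entry` is hypothetical):

    from typing import Optional

    import attr

    @attr.s(slots=True, frozen=True)
    class Entry:  # stand-in for _PendingRoomEventEntry above
        room_id = attr.ib(type=str)
        type = attr.ib(type=str)
        state_key = attr.ib(type=Optional[str])
        membership = attr.ib(type=Optional[str])

    e = Entry("!r:example.com", "m.room.member", "@u:example.com", "join")
    # slots=True keeps the pending list compact; frozen=True means an entry
    # cannot be mutated after it is queued.
    # e.room_id = "!other"  # would raise attr.exceptions.FrozenInstanceError
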
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 
+import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
             dict of user_id -> push_rules
         """
         room_id = event.room_id
-        rules_for_room = await self._get_rules_for_room(room_id)
+        rules_for_room = self._get_rules_for_room(room_id)
 
         rules_by_user = await rules_for_room.get_rules(event, context)
 
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
 
         return rules_by_user
 
-    @cached()
+    @lru_cache()
     def _get_rules_for_room(self, room_id):
         """Get the current RulesForRoom object for the given room id
 
@@ -275,12 +276,14 @@ class RulesForRoom:
     the entire cache for the room.
     """
 
-    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+    def __init__(
+        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+    ):
         """
         Args:
             hs (HomeServer)
             room_id (str)
-            rules_for_room_cache(Cache): The cache object that caches these
+            rules_for_room_cache: The cache object that caches these
                 RulesForRoom objects.
             room_push_rule_cache_metrics (CacheMetric)
         """
@@ -489,13 +492,21 @@ class RulesForRoom:
             self.state_group = state_group
 
 
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+    # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
+    # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+    # same `cache` and `room_id`.
+    #
+    # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+    # set `frozen=True`.
+
+    cache = attr.ib(type=LruCache)
+    room_id = attr.ib(type=str)
+
     def __call__(self):
-        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
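
The dedupe claim in the comment above is easy to check: frozen attrs classes get field-based __eq__ and __hash__, so two invalidation callbacks for the same cache and room collapse into one. A minimal sketch (the class name and cache stand-in are illustrative):

    import attr

    @attr.attrs(slots=True, frozen=True)
    class Inval:  # same shape as _Invalidation above
        cache = attr.ib()
        room_id = attr.ib(type=str)

    cache = object()  # stand-in for an LruCache
    a = Inval(cache, "!room:example.com")
    b = Inval(cache, "!room:example.com")
    assert a == b and hash(a) == hash(b)
    assert len({a, b}) == 1  # duplicate callbacks collapse
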
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index e7cc74a5d2..f0c37eaf5e 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         requester = Requester.deserialize(self.store, content["requester"])
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
@@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
 
         requester = Requester.deserialize(self.store, content["requester"])
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         # hopefully we're now on the master, so this won't recurse!
         event_id, stream_id = await self.member_handler.remote_reject_invite(
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index fc129dbaa7..8fa104c8d3 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         logger.info(
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e27ee216f0..2618eb1e53 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler:
                 if row.type != EventsStreamEventRow.TypeId:
                     continue
                 assert isinstance(row, EventsStreamRow)
+                assert isinstance(row.data, EventsStreamEventRow)
 
-                event = await self.store.get_event(
-                    row.data.event_id, allow_rejected=True
-                )
-                if event.rejected_reason:
+                if row.data.rejected:
                     continue
 
                 extra_users = ()  # type: Tuple[UserID, ...]
-                if event.type == EventTypes.Member:
-                    extra_users = (UserID.from_string(event.state_key),)
+                if row.data.type == EventTypes.Member and row.data.state_key:
+                    extra_users = (UserID.from_string(row.data.state_key),)
 
                 max_token = self.store.get_room_max_token()
                 event_pos = PersistedEventPosition(instance_name, token)
-                self.notifier.on_new_room_event(
-                    event, event_pos, max_token, extra_users
+                self.notifier.on_new_room_event_args(
+                    event_pos=event_pos,
+                    max_room_stream_token=max_token,
+                    extra_users=extra_users,
+                    room_id=row.data.room_id,
+                    event_type=row.data.type,
+                    state_key=row.data.state_key,
+                    membership=row.data.membership,
                 )
 
         # Notify any waiting deferreds. The list is ordered by position so we
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 82e9e0d64e..86a62b71eb 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import List, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 import attr
 
 from ._base import Stream, StreamUpdateResult, Token
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 """Handling of the 'events' replication stream
 
 This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
 class EventsStreamEventRow(BaseEventsStreamRow):
     TypeId = "ev"
 
-    event_id = attr.ib()  # str
-    room_id = attr.ib()  # str
-    type = attr.ib()  # str
-    state_key = attr.ib()  # str, optional
-    redacts = attr.ib()  # str, optional
-    relates_to = attr.ib()  # str, optional
+    event_id = attr.ib(type=str)
+    room_id = attr.ib(type=str)
+    type = attr.ib(type=str)
+    state_key = attr.ib(type=Optional[str])
+    redacts = attr.ib(type=Optional[str])
+    relates_to = attr.ib(type=Optional[str])
+    membership = attr.ib(type=Optional[str])
+    rejected = attr.ib(type=bool)
 
 
 @attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
 
     NAME = "events"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index a79996cae1..fa7e9e4043 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
+    PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     UserAdminServlet,
@@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
     DeviceRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
     DeleteDevicesRestServlet(hs).register(http_server)
-    EventReportsRestServlet(hs).register(http_server)
     EventReportDetailRestServlet(hs).register(http_server)
+    EventReportsRestServlet(hs).register(http_server)
+    PushersRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 933bb45346..b337311a37 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
 
 logger = logging.getLogger(__name__)
 
+_GET_PUSHERS_ALLOWED_KEYS = {
+    "app_display_name",
+    "app_id",
+    "data",
+    "device_display_name",
+    "kind",
+    "lang",
+    "profile_tag",
+    "pushkey",
+}
+
 
 class UsersRestServlet(RestServlet):
     PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
@@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
         return 200, ret
 
 
+class PushersRestServlet(RestServlet):
+    """
+    Gets information about all pushers for a specific `user_id`.
+
+    Example:
+        http://localhost:8008/_synapse/admin/v1/users/
+        @user:server/pushers
+
+    Returns:
+        pushers: A list of dictionaries, one per pusher.
+        total: Number of pushers in the `pushers` list.
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
+
+    def __init__(self, hs):
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "Can only lookup local users")
+
+        if not await self.store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        pushers = await self.store.get_pushers_by_user_id(user_id)
+
+        filtered_pushers = [
+            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
+            for p in pushers
+        ]
+
+        return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+
+
 class UserMediaRestServlet(RestServlet):
     """
     Gets information about all uploaded local media for a specific `user_id`.
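
Hedged usage sketch for the new pushers endpoint above (the base URL, user ID and token are placeholders):

    import requests

    resp = requests.get(
        "http://localhost:8008/_synapse/admin/v1/users/@user:server/pushers",
        headers={"Authorization": "Bearer <admin_access_token>"},
    )
    body = resp.json()
    # e.g. {"pushers": [{"app_id": "...", "pushkey": "...", ...}], "total": 1}
    print(body["total"])
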
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 5cce7237a0..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -305,15 +305,12 @@ class MediaRepository:
         # file_id is the ID we use to track the file locally. If we've already
         # seen the file then reuse the existing ID, otherwise generate a new
         # one.
-        if media_info:
-            file_id = media_info["filesystem_id"]
-        else:
-            file_id = random_string(24)
-
-        file_info = FileInfo(server_name, file_id)
 
         # If we have an entry in the DB, try and look for it
         if media_info:
+            file_id = media_info["filesystem_id"]
+            file_info = FileInfo(server_name, file_id)
+
             if media_info["quarantined_by"]:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
 
         # Failed to find the file anywhere, lets download it.
 
-        media_info = await self._download_remote_file(server_name, media_id, file_id)
+        try:
+            media_info = await self._download_remote_file(server_name, media_id,)
+        except SynapseError:
+            raise
+        except Exception as e:
+            # An exception may be because we downloaded media in another
+            # process, so let's check if we magically have the media.
+            media_info = await self.store.get_cached_remote_media(server_name, media_id)
+            if not media_info:
+                raise e
+
+        file_id = media_info["filesystem_id"]
+        file_info = FileInfo(server_name, file_id)
+
+        # We generate thumbnails even if another process downloaded the media
+        # as a) it's conceivable that the other download request dies before it
+        # generates thumbnails, but mainly b) we want to be sure the thumbnails
+        # have finished being generated before responding to the client,
+        # otherwise they'll request thumbnails and get a 404 if they're not
+        # ready yet.
+        await self._generate_thumbnails(
+            server_name, media_id, file_id, media_info["media_type"]
+        )
 
         responder = await self.media_storage.fetch_media(file_info)
         return responder, media_info
 
-    async def _download_remote_file(
-        self, server_name: str, media_id: str, file_id: str
-    ) -> dict:
+    async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -346,6 +363,8 @@ class MediaRepository:
             The media info of the file.
         """
 
+        file_id = random_string(24)
+
         file_info = FileInfo(server_name=server_name, file_id=file_id)
 
         with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
 
             await finish()
 
-        media_type = headers[b"Content-Type"][0].decode("ascii")
-        upload_name = get_filename_from_headers(headers)
-        time_now_ms = self.clock.time_msec()
+            media_type = headers[b"Content-Type"][0].decode("ascii")
+            upload_name = get_filename_from_headers(headers)
+            time_now_ms = self.clock.time_msec()
+
+            # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a constraint violation
+            # exception. If it does we'll delete the newly downloaded file from
+            # disk (as we're in the ctx manager).
+            #
+            # However: we've already called `finish()` so we may have also
+            # written to the storage providers. This is preferable to the
+            # alternative where we call `finish()` *after* this, where we could
+            # end up having an entry in the DB but fail to write the files to
+            # the storage providers.
+            await self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
 
         logger.info("Stored remote media in file %r", fname)
 
-        await self.store.store_cached_remote_media(
-            origin=server_name,
-            media_id=media_id,
-            media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
-            upload_name=upload_name,
-            media_length=length,
-            filesystem_id=file_id,
-        )
-
         media_info = {
             "media_type": media_type,
             "media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
             "filesystem_id": file_id,
         }
 
-        await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
         return media_info
 
     def _get_thumbnail_requirements(self, media_type):
@@ -692,42 +719,60 @@ class MediaRepository:
             if not t_byte_source:
                 continue
 
-            try:
-                file_info = FileInfo(
-                    server_name=server_name,
-                    file_id=file_id,
-                    thumbnail=True,
-                    thumbnail_width=t_width,
-                    thumbnail_height=t_height,
-                    thumbnail_method=t_method,
-                    thumbnail_type=t_type,
-                    url_cache=url_cache,
-                )
-
-                output_path = await self.media_storage.store_file(
-                    t_byte_source, file_info
-                )
-            finally:
-                t_byte_source.close()
-
-            t_len = os.path.getsize(output_path)
+            file_info = FileInfo(
+                server_name=server_name,
+                file_id=file_id,
+                thumbnail=True,
+                thumbnail_width=t_width,
+                thumbnail_height=t_height,
+                thumbnail_method=t_method,
+                thumbnail_type=t_type,
+                url_cache=url_cache,
+            )
 
-            # Write to database
-            if server_name:
-                await self.store.store_remote_media_thumbnail(
-                    server_name,
-                    media_id,
-                    file_id,
-                    t_width,
-                    t_height,
-                    t_type,
-                    t_method,
-                    t_len,
-                )
-            else:
-                await self.store.store_local_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                )
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                try:
+                    await self.media_storage.write_to_file(t_byte_source, f)
+                    await finish()
+                finally:
+                    t_byte_source.close()
+
+                t_len = os.path.getsize(fname)
+
+                # Write to database
+                if server_name:
+                    # Multiple remote media download requests can race (when
+                    # using multiple media repos), so this may throw a
+                    # constraint violation exception. If it does we'll delete
+                    # generated thumbnail from disk (as we're in the ctx
+                    # manager).
+                    #
+                    # However: we've already called `finish()` so we may have
+                    # also written to the storage providers. This is preferable
+                    # to the alternative where we call `finish()` *after* this,
+                    # where we could end up having an entry in the DB but fail
+                    # to write the files to the storage providers.
+                    try:
+                        await self.store.store_remote_media_thumbnail(
+                            server_name,
+                            media_id,
+                            file_id,
+                            t_width,
+                            t_height,
+                            t_type,
+                            t_method,
+                            t_len,
+                        )
+                    except Exception as e:
+                        thumbnail_exists = await self.store.get_remote_media_thumbnail(
+                            server_name, media_id, t_width, t_height, t_type,
+                        )
+                        if not thumbnail_exists:
+                            raise e
+                else:
+                    await self.store.store_local_thumbnail(
+                        media_id, t_width, t_height, t_type, t_method, t_len
+                    )
 
         return {"width": m_width, "height": m_height}
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
         storage_providers: Sequence["StorageProviderWrapper"],
     ):
         self.hs = hs
+        self.reactor = hs.get_reactor()
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
 
         with self.store_into_file(file_info) as (f, fname, finish_cb):
             # Write to the main repository
-            await defer_to_thread(
-                self.hs.get_reactor(), _write_file_synchronously, source, f
-            )
+            await self.write_to_file(source, f)
             await finish_cb()
 
         return fname
 
+    async def write_to_file(self, source: IO, output: IO):
+        """Asynchronously write the `source` to `output`.
+        """
+        await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
     @contextlib.contextmanager
     def store_into_file(self, file_info: FileInfo):
         """Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
 
         finished_called = [False]
 
-        async def finish():
-            for provider in self.storage_providers:
-                await provider.store_file(path, file_info)
-
-            finished_called[0] = True
-
         try:
             with open(fname, "wb") as f:
+
+                async def finish():
+                    # Ensure that all writes have been flushed and close the
+                    # file.
+                    f.flush()
+                    f.close()
+
+                    for provider in self.storage_providers:
+                        await provider.store_file(path, file_info)
+
+                    finished_called[0] = True
+
                 yield f, fname, finish
         except Exception:
             try:
@@ -210,7 +220,7 @@ class MediaStorage:
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
-                        open(local_path, "wb"), self.hs.get_reactor()
+                        open(local_path, "wb"), self.reactor
                     )
                     await res.write_to_consumer(consumer)
                     await consumer.wait()
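
Calling convention for the reworked context manager, sketched (media_storage, source and file_info are supplied by the caller): finish() now flushes and closes the local file before mirroring it to the storage providers, so a provider can never upload a partially flushed file.

    async def persist(media_storage, source, file_info) -> str:
        with media_storage.store_into_file(file_info) as (f, fname, finish):
            await media_storage.write_to_file(source, f)
            # Must be called while still inside the context manager; an
            # exception before this point deletes the partial file.
            await finish()
        return fname
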
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0217e63108..a0572b2952 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -94,7 +94,7 @@ def make_pool(
         cp_openfun=lambda conn: engine.on_new_connection(
             LoggingDatabaseConnection(conn, engine, "on_new_connection")
         ),
-        **db_config.config.get("args", {})
+        **db_config.config.get("args", {}),
     )
 
 
@@ -632,7 +632,7 @@ class DatabasePool:
                 func,
                 *args,
                 db_autocommit=db_autocommit,
-                **kwargs
+                **kwargs,
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 637a938bac..26eef6eb61 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,21 +15,31 @@
 # limitations under the License.
 import logging
 import re
-from typing import List
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
 
-from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.appservice import (
+    ApplicationService,
+    ApplicationServiceState,
+    AppServiceTransaction,
+)
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+    services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
     # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_regex = re.compile(exclusive_user_regex)
+        exclusive_user_pattern = re.compile(
+            exclusive_user_regex
+        )  # type: Optional[Pattern]
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
-        exclusive_user_regex = None
+        exclusive_user_pattern = None
 
-    return exclusive_user_regex
+    return exclusive_user_pattern
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
     def get_app_services(self):
         return self.services_cache
 
-    def get_if_app_services_interested_in_user(self, user_id):
+    def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
         """Check if the user is one associated with an app service (exclusively)
         """
         if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         else:
             return False
 
-    def get_app_service_by_user_id(self, user_id):
+    def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
         """Retrieve an application service from their user ID.
 
         All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         a user ID to an application service.
 
         Args:
-            user_id(str): The user ID to see if it is an application service.
+            user_id: The user ID to look up.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.sender == user_id:
                 return service
         return None
 
-    def get_app_service_by_token(self, token):
+    def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice token.
 
         Args:
-            token (str): The application service token.
+            token: The application service token.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.token == token:
                 return service
         return None
 
-    def get_app_service_by_id(self, as_id):
+    def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice ID.
 
         Args:
-            as_id (str): The application service ID.
+            as_id: The application service ID.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
 class ApplicationServiceTransactionWorkerStore(
     ApplicationServiceWorkerStore, EventsWorkerStore
 ):
-    async def get_appservices_by_state(self, state):
+    async def get_appservices_by_state(
+        self, state: ApplicationServiceState
+    ) -> List[ApplicationService]:
         """Get a list of application services based on their state.
 
         Args:
-            state(ApplicationServiceState): The state to filter on.
+            state: The state to filter on.
         Returns:
             A list of ApplicationServices, which may be empty.
         """
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
                     services.append(service)
         return services
 
-    async def get_appservice_state(self, service):
+    async def get_appservice_state(
+        self, service: ApplicationService
+    ) -> Optional[ApplicationServiceState]:
         """Get the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
+            service: The service whose state to get.
         Returns:
-            An ApplicationServiceState.
+            An ApplicationServiceState or None.
         """
         result = await self.db_pool.simple_select_one(
             "application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    async def set_appservice_state(self, service, state) -> None:
+    async def set_appservice_state(
+        self, service: ApplicationService, state: ApplicationServiceState
+    ) -> None:
         """Set the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
-            state(ApplicationServiceState): The connectivity state to apply.
+            service: The service whose state to set.
+            state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
             "create_appservice_txn", _create_appservice_txn
         )
 
-    async def complete_appservice_txn(self, txn_id, service) -> None:
+    async def complete_appservice_txn(
+        self, txn_id: int, service: ApplicationService
+    ) -> None:
         """Completes an application service transaction.
 
         Args:
-            txn_id(str): The transaction ID being completed.
-            service(ApplicationService): The application service which was sent
-            this transaction.
+            txn_id: The transaction ID being completed.
+            service: The application service which was sent this transaction.
         """
         txn_id = int(txn_id)
 
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
             # has probably missed some events), so whine loudly but still continue,
             # since it shouldn't fail completion of the transaction.
             last_txn_id = self._get_last_txn(txn, service.id)
             if (last_txn_id + 1) != txn_id:
                 logger.error(
                     "appservice: Completing a transaction which has an ID > 1 from "
                     "the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
             "complete_appservice_txn", _complete_appservice_txn
         )
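
As a usage sketch (not code from this diff): taken together, the typed transaction methods describe the app service send loop. The `deliver` callable below is a hypothetical stand-in for the actual HTTP push; `ApplicationServiceState` is real and lives in `synapse.appservice`.

    from synapse.appservice import ApplicationServiceState

    async def send_next_txn(store, service, deliver):
        # `deliver` stands in for pushing the transaction to the AS;
        # assume it returns True on success.
        txn = await store.get_oldest_unsent_txn(service)
        if txn is None:
            return  # nothing pending for this AS
        if await deliver(txn):
            await store.complete_appservice_txn(txn.id, service)
        else:
            # Mark the AS as down so a recoverer can find it via
            # get_appservices_by_state and retry later.
            await store.set_appservice_state(service, ApplicationServiceState.DOWN)
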
 
-    async def get_oldest_unsent_txn(self, service):
-        """Get the oldest transaction which has not been sent for this
-        service.
+    async def get_oldest_unsent_txn(
+        self, service: ApplicationService
+    ) -> Optional[AppServiceTransaction]:
+        """Get the oldest transaction which has not been sent for this service.
 
         Args:
-            service(ApplicationService): The app service to get the oldest txn.
+            service: The app service to get the oldest txn.
         Returns:
             An AppServiceTransaction or None.
         """
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
             service=service, id=entry["txn_id"], events=events, ephemeral=[]
         )
 
-    def _get_last_txn(self, txn, service_id):
+    def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
         txn.execute(
             "SELECT last_txn FROM application_services_state WHERE as_id=?",
             (service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
         else:
             return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_appservice_last_pos(self, pos) -> None:
+    async def set_appservice_last_pos(self, pos: int) -> None:
         def set_appservice_last_pos_txn(txn):
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
-    async def get_new_events_for_appservice(self, current_id, limit):
+    async def get_new_events_for_appservice(
+        self, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
         """Get all new events for an appservice"""
 
         def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: int
+        self, service: ApplicationService, type: str, pos: Optional[int]
     ) -> None:
         if type not in ("read_receipt", "presence"):
             raise ValueError(
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 849bd5ba7a..3e26d5ba87 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = frozendict_json_encoder.encode(
+                pruned_json = json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = frozendict_json_encoder.encode(
+            pruned_json = json_encoder.encode(
                 prune_event_dict(event.room_version, event.get_dict())
             )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 87808c1483..90fb1a1f00 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -769,9 +769,7 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                metadata_json = frozendict_json_encoder.encode(
-                    event.internal_metadata.get_dict()
-                )
+                metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
 
                 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
@@ -826,10 +824,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": frozendict_json_encoder.encode(
+                    "internal_metadata": json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": frozendict_json_encoder.encode(event_dict(event)),
+                    "json": json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6e7f16f39c..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -31,6 +31,7 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import (
@@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_map
 
+    async def get_stripped_room_state_from_event_context(
+        self,
+        context: EventContext,
+        state_types_to_include: List[EventTypes],
+        membership_user_id: Optional[str] = None,
+    ) -> List[JsonDict]:
+        """
+        Retrieve the stripped state from a room, given an event context to retrieve state
+        from as well as the state types to include. Optionally, include the membership
+        events from a specific user.
+
+        "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+        are included from each state event.
+
+        Args:
+            context: The event context to retrieve the state of the room from.
+            state_types_to_include: The types of state events to include.
+            membership_user_id: An optional user ID to include the stripped membership state
+                events of. This is useful when generating the stripped state of a room for
+                invites. We want to send membership events of the inviter, so that the
+                invitee can display the inviter's profile information if the room lacks any.
+
+        Returns:
+            A list of dictionaries, each representing a stripped state event from the room.
+        """
+        current_state_ids = await context.get_current_state_ids()
+
+        # We know this event is not an outlier, so this must be
+        # non-None.
+        assert current_state_ids is not None
+
+        # The state to include
+        state_to_include_ids = [
+            e_id
+            for k, e_id in current_state_ids.items()
+            if k[0] in state_types_to_include
+            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+        ]
+
+        state_to_include = await self.get_events(state_to_include_ids)
+
+        return [
+            {
+                "type": e.type,
+                "state_key": e.state_key,
+                "content": e.content,
+                "sender": e.sender,
+            }
+            for e in state_to_include.values()
+        ]
+
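
A hedged sketch of how a caller might use the new helper when building an invite; the wrapper function is hypothetical, though `EventTypes.Create` and `EventTypes.Name` are real constants from `synapse.api.constants`:

    from synapse.api.constants import EventTypes

    async def stripped_state_for_invite(store, context, inviter_id):
        # Request the state types an invitee needs to render the room, and
        # include the inviter's membership event so their profile is available.
        return await store.get_stripped_room_state_from_event_context(
            context,
            state_types_to_include=[EventTypes.Create, EventTypes.Name],
            membership_user_id=inviter_id,
        )
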
     def _do_fetch(self, conn):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
@@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
         def get_all_new_forward_event_rows(txn):
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < stream_ordering AND stream_ordering <= ?"
                 " AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
@@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
         def get_ex_outlier_stream_rows_txn(txn):
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < event_stream_ordering"
                 " AND event_stream_ordering <= ?"
                 " AND out.instance_name = ?"
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index daf57675d8..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_remote_media_thumbnails",
         )
 
+    async def get_remote_media_thumbnail(
+        self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+    ) -> Optional[Dict[str, Any]]:
+        """Fetch the thumbnail info for the given width, height and type.
+        """
+
+        return await self.db_pool.simple_select_one(
+            table="remote_media_cache_thumbnails",
+            keyvalues={
+                "media_origin": origin,
+                "media_id": media_id,
+                "thumbnail_width": t_width,
+                "thumbnail_height": t_height,
+                "thumbnail_type": t_type,
+            },
+            retcols=(
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_method",
+                "thumbnail_type",
+                "thumbnail_length",
+                "filesystem_id",
+            ),
+            allow_none=True,
+            desc="get_remote_media_thumbnail",
+        )
+
     async def store_remote_media_thumbnail(
         self,
         origin,
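
For illustration, a minimal lookup sketch assuming a store instance; the 96x96 PNG parameters are arbitrary. `None` signals that this exact width/height/type combination is not cached:

    async def cached_thumbnail_info(store, origin, media_id):
        # Returns the stored thumbnail columns, or None on a cache miss.
        return await store.get_remote_media_thumbnail(
            origin, media_id, t_width=96, t_height=96, t_type="image/png"
        )
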
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e7b17a7385..e5d07ce72a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -18,6 +18,8 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 logger = logging.getLogger(__name__)
 
 
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+    """Result of looking up an access token.
+
+    Attributes:
+        user_id: The user that this token authenticates as
+        is_guest: Whether the token belongs to a guest user.
+        shadow_banned: Whether the authenticating user has been shadow-banned.
+        token_id: The ID of the access token looked up
+        device_id: The device associated with the token, if any.
+        valid_until_ms: The timestamp the token expires, if any.
+        token_owner: The "owner" of the token. This is either the same as the
+            user, or a server admin who is logged in as the user.
+    """
+
+    user_id = attr.ib(type=str)
+    is_guest = attr.ib(type=bool, default=False)
+    shadow_banned = attr.ib(type=bool, default=False)
+    token_id = attr.ib(type=Optional[int], default=None)
+    device_id = attr.ib(type=Optional[str], default=None)
+    valid_until_ms = attr.ib(type=Optional[int], default=None)
+    token_owner = attr.ib(type=str)
+
+    # Make the token owner default to the user ID, which is the common case.
+    @token_owner.default
+    def _default_token_owner(self):
+        return self.user_id
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
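
The `@token_owner.default` hook above makes the owner fall back to the authenticated user, so only puppet tokens need to set it explicitly. A small sketch of that behaviour (user IDs illustrative):

    normal = TokenLookupResult(user_id="@alice:example.com")
    assert normal.token_owner == "@alice:example.com"  # defaulted from user_id

    # A server admin puppeting alice is recorded as the token's owner.
    puppet = TokenLookupResult(
        user_id="@alice:example.com", token_owner="@admin:example.com"
    )
    assert puppet.token_owner == "@admin:example.com"
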
@@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         return is_trial
 
     @cached()
-    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+    async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
         """Get a user from the given access token.
 
         Args:
             token: The access token of a user.
         Returns:
-            None, if the token did not match, otherwise dict
-            including the keys `name`, `is_guest`, `device_id`, `token_id`,
-            `valid_until_ms`.
+            None if the token did not match, otherwise a `TokenLookupResult`.
         """
         return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
@@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
-    def _query_for_auth(self, txn, token):
+    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
-            SELECT users.name,
+            SELECT users.name as user_id,
                 users.is_guest,
                 users.shadow_banned,
                 access_tokens.id as token_id,
                 access_tokens.device_id,
-                access_tokens.valid_until_ms
+                access_tokens.valid_until_ms,
+                access_tokens.user_id as token_owner
             FROM users
-            INNER JOIN access_tokens on users.name = access_tokens.user_id
+            INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
             WHERE token = ?
         """
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
         if rows:
-            return rows[0]
+            return TokenLookupResult(**rows[0])
 
         return None
 
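
With the `COALESCE` above, a token whose `puppets_user_id` is set authenticates as the puppeted user, while `token_owner` still records the admin who owns the token. A hypothetical consumer of the result:

    async def whoami(store, token):
        res = await store.get_user_by_access_token(token)
        if res is None:
            return None  # unknown token
        # For a puppet token these two fields differ.
        return {"user_id": res.user_id, "authenticated_as": res.token_owner}
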
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
diff --git a/synapse/types.py b/synapse/types.py
index 5bde67cc07..66bb5bac8d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -29,6 +29,7 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
+    Union,
 )
 
 import attr
@@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
 from synapse.api.errors import Codes, SynapseError
 
 if TYPE_CHECKING:
+    from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore
 
 # define a version of typing.Collection that works on python 3.5
@@ -74,6 +76,7 @@ class Requester(
             "shadow_banned",
             "device_id",
             "app_service",
+            "authenticated_entity",
         ],
     )
 ):
@@ -104,6 +107,7 @@ class Requester(
             "shadow_banned": self.shadow_banned,
             "device_id": self.device_id,
             "app_server_id": self.app_service.id if self.app_service else None,
+            "authenticated_entity": self.authenticated_entity,
         }
 
     @staticmethod
@@ -129,16 +133,18 @@ class Requester(
             shadow_banned=input["shadow_banned"],
             device_id=input["device_id"],
             app_service=appservice,
+            authenticated_entity=input["authenticated_entity"],
         )
 
 
 def create_requester(
-    user_id,
-    access_token_id=None,
-    is_guest=False,
-    shadow_banned=False,
-    device_id=None,
-    app_service=None,
+    user_id: Union[str, "UserID"],
+    access_token_id: Optional[int] = None,
+    is_guest: Optional[bool] = False,
+    shadow_banned: Optional[bool] = False,
+    device_id: Optional[str] = None,
+    app_service: Optional["ApplicationService"] = None,
+    authenticated_entity: Optional[str] = None,
 ):
     """
     Create a new ``Requester`` object
@@ -151,14 +157,27 @@ def create_requester(
         shadow_banned (bool):  True if the user making this request is shadow-banned.
         device_id (str|None):  device_id which was set at authentication time
         app_service (ApplicationService|None):  the AS requesting on behalf of the user
+        authenticated_entity: The entity that authenticated when making the request.
+            This is different to the user_id when an admin user or the server is
+            "puppeting" the user.
 
     Returns:
         Requester
     """
     if not isinstance(user_id, UserID):
         user_id = UserID.from_string(user_id)
+
+    if authenticated_entity is None:
+        authenticated_entity = user_id.to_string()
+
     return Requester(
-        user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+        user_id,
+        access_token_id,
+        is_guest,
+        shadow_banned,
+        device_id,
+        app_service,
+        authenticated_entity,
     )
 
 
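
A sketch of the new default in action: `authenticated_entity` falls back to the user themselves, and an admin puppeting another user supplies it explicitly (IDs illustrative):

    requester = create_requester("@alice:example.com")
    assert requester.authenticated_entity == "@alice:example.com"

    # An admin acting as alice records themselves as the authenticator.
    as_admin = create_requester(
        "@alice:example.com", authenticated_entity="@admin:example.com"
    )
    assert as_admin.user.to_string() == "@alice:example.com"
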
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index d55b93d763..517686f0a6 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -18,6 +18,7 @@ import logging
 import re
 
 import attr
+from frozendict import frozendict
 
 from twisted.internet import defer, task
 
@@ -31,9 +32,26 @@ def _reject_invalid_json(val):
     raise ValueError("Invalid JSON value: '%s'" % val)
 
 
-# Create a custom encoder to reduce the whitespace produced by JSON encoding and
-# ensure that valid JSON is produced.
-json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+def _handle_frozendict(obj):
+    """Helper for json_encoder. Makes frozendicts serializable by returning
+    the underlying dict
+    """
+    if type(obj) is frozendict:
+        # fishing the protected dict out of the object is a bit nasty,
+        # but we don't really want the overhead of copying the dict.
+        return obj._dict
+    raise TypeError(
+        "Object of type %s is not JSON serializable" % obj.__class__.__name__
+    )
+
+
+# A custom JSON encoder which:
+#   * handles frozendicts
+#   * produces valid JSON (no NaNs etc)
+#   * reduces redundant whitespace
+json_encoder = json.JSONEncoder(
+    allow_nan=False, separators=(",", ":"), default=_handle_frozendict
+)
 
 # Create a custom decoder to reject Python extensions to JSON.
 json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
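
A quick sketch of the consolidated encoder's behaviour: it flattens frozendicts, emits compact separators, and raises rather than producing invalid JSON for NaN:

    from frozendict import frozendict
    from synapse.util import json_encoder

    assert json_encoder.encode({"a": frozendict({"b": 1})}) == '{"a":{"b":1}}'

    try:
        json_encoder.encode(float("nan"))  # allow_nan=False rejects this
    except ValueError:
        pass
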
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import inspect
 import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
+        self.cache_key_builder = get_cache_key_builder(
+            self.arg_names, self.arg_defaults
+        )
+
+
+class _LruCachedFunction(Generic[F]):
+    cache = None  # type: LruCache[CacheKey, Any]
+    __call__ = None  # type: F
+
+
+def lru_cache(
+    max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+    """A method decorator that applies a memoizing cache around the function.
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+    """
+
+    def func(orig: F) -> _LruCachedFunction[F]:
+        desc = LruCacheDescriptor(
+            orig, max_entries=max_entries, cache_context=cache_context,
+        )
+        return cast(_LruCachedFunction[F], desc)
+
+    return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+    """Helper for @lru_cache"""
+
+    class _Sentinel(enum.Enum):
+        sentinel = object()
+
+    def __init__(
+        self, orig, max_entries: int = 1000, cache_context: bool = False,
+    ):
+        super().__init__(orig, num_args=None, cache_context=cache_context)
+        self.max_entries = max_entries
+
+    def __get__(self, obj, owner):
+        cache = LruCache(
+            cache_name=self.orig.__name__, max_size=self.max_entries,
+        )  # type: LruCache[CacheKey, Any]
+
+        get_cache_key = self.cache_key_builder
+        sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+        @functools.wraps(self.orig)
+        def _wrapped(*args, **kwargs):
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+            callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+            cache_key = get_cache_key(args, kwargs)
 
-class CacheDescriptor(_CacheDescriptorBase):
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret != sentinel:
+                return ret
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+            ret2 = self.orig(obj, *args, **kwargs)
+            cache.set(cache_key, ret2, callbacks=callbacks)
+
+            return ret2
+
+        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped.cache = cache
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
         cache_context=False,
         iterable=False,
     ):
-
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )  # type: DeferredCache[CacheKey, Any]
 
-        def get_cache_key_gen(args, kwargs):
-            """Given some args/kwargs return a generator that resolves into
-            the cache_key.
-
-            We loop through each arg name, looking up if its in the `kwargs`,
-            otherwise using the next argument in `args`. If there are no more
-            args then we try looking the arg name up in the defaults
-            """
-            pos = 0
-            for nm in self.arg_names:
-                if nm in kwargs:
-                    yield kwargs[nm]
-                elif pos < len(args):
-                    yield args[pos]
-                    pos += 1
-                else:
-                    yield self.arg_defaults[nm]
-
-        # By default our cache key is a tuple, but if there is only one item
-        # then don't bother wrapping in a tuple.  This is to save memory.
-        if self.num_args == 1:
-            nm = self.arg_names[0]
-
-            def get_cache_key(args, kwargs):
-                if nm in kwargs:
-                    return kwargs[nm]
-                elif len(args):
-                    return args[0]
-                else:
-                    return self.arg_defaults[nm]
-
-        else:
-
-            def get_cache_key(args, kwargs):
-                return tuple(get_cache_key_gen(args, kwargs))
+        get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
         def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_all = cache.invalidate_all
             wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         return wrapped
 
 
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
     on a lower level.
     """
 
+    Cache = Union[DeferredCache, LruCache]
+
     _cache_context_objects = (
         WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
 
-    def __init__(self, cache, cache_key):  # type: (DeferredCache, CacheKey) -> None
+    def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
@@ -396,8 +475,8 @@ class _CacheContext:
 
     @classmethod
     def get_instance(
-        cls, cache, cache_key
-    ):  # type: (DeferredCache, CacheKey) -> _CacheContext
+        cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+    ) -> "_CacheContext":
         """Returns an instance constructed with the given arguments.
 
         A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
 ) -> Callable[[F], _CachedFunction[F]]:
-    func = lambda orig: CacheDescriptor(
+    func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    func = lambda orig: CacheListDescriptor(
+    func = lambda orig: DeferredCacheListDescriptor(
         orig,
         cached_method_name=cached_method_name,
         list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+    param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+    """Construct a function which will build cache keys suitable for a cached function
+
+    Args:
+        param_names: list of formal parameter names for the cached function
+        param_defaults: a mapping from parameter name to default value for that param
+
+    Returns:
+        A function which will take an (args, kwargs) pair and return a cache key
+    """
+
+    # By default our cache key is a tuple, but if there is only one item
+    # then don't bother wrapping in a tuple.  This is to save memory.
+
+    if len(param_names) == 1:
+        nm = param_names[0]
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            if nm in kwargs:
+                return kwargs[nm]
+            elif len(args):
+                return args[0]
+            else:
+                return param_defaults[nm]
+
+    else:
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+    return get_cache_key
+
+
+def _get_cache_key_gen(
+    param_names: Iterable[str],
+    param_defaults: Mapping[str, Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+    """Given some args/kwargs, return a generator that yields the
+    components of the cache_key in order.
+
+    This is essentially the same operation as `inspect.getcallargs`, but optimised so
+    that we don't need to inspect the target function for each call.
+    """
+
+    # We loop through each arg name, checking whether it is in `kwargs`,
+    # otherwise using the next argument in `args`. If there are no more
+    # args then we try looking the arg name up in the defaults.
+    pos = 0
+    for nm in param_names:
+        if nm in kwargs:
+            yield kwargs[nm]
+        elif pos < len(args):
+            yield args[pos]
+            pos += 1
+        else:
+            yield param_defaults[nm]
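
The extracted key builder keeps the old memory optimisation: a single-parameter function uses the bare value as its key, while multi-parameter functions get a tuple with defaults filled in. A sketch of the semantics (call sites illustrative):

    build_one = get_cache_key_builder(["user_id"], {})
    assert build_one(("@alice:example.com",), {}) == "@alice:example.com"

    build_two = get_cache_key_builder(["room_id", "limit"], {"limit": 10})
    # A missing trailing argument falls back to its declared default.
    assert build_two(("!room:example.com",), {}) == ("!room:example.com", 10)
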
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index bf094c9386..5f7a6dd1d3 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-
 from frozendict import frozendict
 
 
@@ -49,23 +47,3 @@ def unfreeze(o):
         pass
 
     return o
-
-
-def _handle_frozendict(obj):
-    """Helper for EventEncoder. Makes frozendicts serializable by returning
-    the underlying dict
-    """
-    if type(obj) is frozendict:
-        # fishing the protected dict out of the object is a bit nasty,
-        # but we don't really want the overhead of copying the dict.
-        return obj._dict
-    raise TypeError(
-        "Object of type %s is not JSON serializable" % obj.__class__.__name__
-    )
-
-
-# A JSONEncoder which is capable of encoding frozendicts without barfing.
-# Additionally reduce the whitespace produced by JSON encoding.
-frozendict_json_encoder = json.JSONEncoder(
-    allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
-)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index a5cc9d0551..4ab379e429 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
         failure_ts,
         retry_interval,
         backoff_on_failure=backoff_on_failure,
-        **kwargs
+        **kwargs,
     )