summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6484.misc1
-rw-r--r--changelog.d/6488.removal1
-rw-r--r--changelog.d/6491.bugfix1
-rw-r--r--changelog.d/6493.bugfix1
-rw-r--r--synapse/federation/federation_client.py89
-rw-r--r--synapse/federation/transport/client.py24
-rw-r--r--synapse/handlers/events.py30
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/handlers/sync.py251
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/notifier.py29
-rw-r--r--synapse/storage/data_stores/main/account_data.py2
-rw-r--r--synapse/storage/data_stores/main/group_server.py4
-rw-r--r--synapse/util/metrics.py83
-rw-r--r--tests/handlers/test_sync.py33
-rw-r--r--tests/handlers/test_typing.py24
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/unittest.py7
18 files changed, 249 insertions, 341 deletions
diff --git a/changelog.d/6484.misc b/changelog.d/6484.misc
new file mode 100644
index 0000000000..b7cd600012
--- /dev/null
+++ b/changelog.d/6484.misc
@@ -0,0 +1 @@
+Port SyncHandler to async/await.
diff --git a/changelog.d/6488.removal b/changelog.d/6488.removal
new file mode 100644
index 0000000000..06e034a213
--- /dev/null
+++ b/changelog.d/6488.removal
@@ -0,0 +1 @@
+Remove fallback for federation with old servers which lack the /federation/v1/state_ids API.
diff --git a/changelog.d/6491.bugfix b/changelog.d/6491.bugfix
new file mode 100644
index 0000000000..78204693b0
--- /dev/null
+++ b/changelog.d/6491.bugfix
@@ -0,0 +1 @@
+Fix inaccurate per-block Prometheus metrics.
diff --git a/changelog.d/6493.bugfix b/changelog.d/6493.bugfix
new file mode 100644
index 0000000000..440c02efbe
--- /dev/null
+++ b/changelog.d/6493.bugfix
@@ -0,0 +1 @@
+Fix small performance regression for sending invites.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 27f6aff004..709449c9e3 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -324,87 +324,32 @@ class FederationClient(FederationBase):
                 A list of events in the state, and a list of events in the auth chain
                 for the given event.
         """
-        try:
-            # First we try and ask for just the IDs, as thats far quicker if
-            # we have most of the state and auth_chain already.
-            # However, this may 404 if the other side has an old synapse.
-            result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id
-            )
-
-            state_event_ids = result["pdu_ids"]
-            auth_event_ids = result.get("auth_chain_ids", [])
-
-            fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-                destination, room_id, set(state_event_ids + auth_event_ids)
-            )
-
-            if failed_to_fetch:
-                logger.warning(
-                    "Failed to fetch missing state/auth events for %s: %s",
-                    room_id,
-                    failed_to_fetch,
-                )
-
-            event_map = {ev.event_id: ev for ev in fetched_events}
-
-            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-            auth_chain = [
-                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
-            ]
-
-            auth_chain.sort(key=lambda e: e.depth)
-
-            return pdus, auth_chain
-        except HttpResponseException as e:
-            if e.code == 400 or e.code == 404:
-                logger.info("Failed to use get_room_state_ids API, falling back")
-            else:
-                raise e
-
-        result = yield self.transport_layer.get_room_state(
+        result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
-        room_version = yield self.store.get_room_version(room_id)
-        format_ver = room_version_to_event_format(room_version)
-
-        pdus = [
-            event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
-        ]
+        state_event_ids = result["pdu_ids"]
+        auth_event_ids = result.get("auth_chain_ids", [])
 
-        auth_chain = [
-            event_from_pdu_json(p, format_ver, outlier=True)
-            for p in result.get("auth_chain", [])
-        ]
-
-        seen_events = yield self.store.get_events(
-            [ev.event_id for ev in itertools.chain(pdus, auth_chain)]
+        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
+            destination, room_id, set(state_event_ids + auth_event_ids)
         )
 
-        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in pdus if p.event_id not in seen_events],
-            outlier=True,
-            room_version=room_version,
-        )
-        signed_pdus.extend(
-            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
-        )
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s: %s",
+                room_id,
+                failed_to_fetch,
+            )
 
-        signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True,
-            room_version=room_version,
-        )
-        signed_auth.extend(
-            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
-        )
+        event_map = {ev.event_id: ev for ev in fetched_events}
 
-        signed_auth.sort(key=lambda e: e.depth)
+        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+        auth_chain.sort(key=lambda e: e.depth)
 
-        return signed_pdus, signed_auth
+        return pdus, auth_chain
 
     @defer.inlineCallbacks
     def get_events_from_store_or_dest(self, destination, room_id, event_ids):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dc95ab2113..46dba84cac 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -39,30 +39,6 @@ class TransportLayerClient(object):
         self.client = hs.get_http_client()
 
     @log_function
-    def get_room_state(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
-        given event.
-
-        Args:
-            destination (str): The host name of the remote homeserver we want
-                to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
-
-        Returns:
-            Deferred: Results in a dict received from the remote homeserver.
-        """
-        logger.debug("get_room_state dest=%s, room=%s", destination, room_id)
-
-        path = _create_v1_path("/state/%s", room_id)
-        return self.client.get_json(
-            destination,
-            path=path,
-            args={"event_id": event_id},
-            try_trailing_slash_on_400=True,
-        )
-
-    @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
         """ Requests all state for a given room from the given server at the
         given event. Returns the state's event_id's
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 45fe13c62f..ec18a42a68 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -16,8 +16,6 @@
 import logging
 import random
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
@@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
         self._server_notices_sender = hs.get_server_notices_sender()
         self._event_serializer = hs.get_event_client_serializer()
 
-    @defer.inlineCallbacks
     @log_function
-    def get_stream(
+    async def get_stream(
         self,
         auth_user_id,
         pagin_config,
@@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
         """
 
         if room_id:
-            blocked = yield self.store.is_room_blocked(room_id)
+            blocked = await self.store.is_room_blocked(room_id)
             if blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
         # send any outstanding server notices to the user.
-        yield self._server_notices_sender.on_user_syncing(auth_user_id)
+        await self._server_notices_sender.on_user_syncing(auth_user_id)
 
         auth_user = UserID.from_string(auth_user_id)
         presence_handler = self.hs.get_presence_handler()
 
-        context = yield presence_handler.user_syncing(
+        context = await presence_handler.user_syncing(
             auth_user_id, affect_presence=affect_presence
         )
         with context:
@@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
-            events, tokens = yield self.notifier.get_events_for(
+            events, tokens = await self.notifier.get_events_for(
                 auth_user,
                 pagin_config,
                 timeout,
@@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = yield self.state.get_current_users_in_room(
+                        users = await self.state.get_current_users_in_room(
                             event.room_id
                         )
-                        states = yield presence_handler.get_states(users, as_event=True)
+                        states = await presence_handler.get_states(users, as_event=True)
                         to_add.extend(states)
                     else:
 
-                        ev = yield presence_handler.get_state(
+                        ev = await presence_handler.get_state(
                             UserID.from_string(event.state_key), as_event=True
                         )
                         to_add.append(ev)
@@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
 
             time_now = self.clock.time_msec()
 
-            chunks = yield self._event_serializer.serialize_events(
+            chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
                 as_client_event=as_client_event,
@@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
         super(EventHandler, self).__init__(hs)
         self.storage = hs.get_storage()
 
-    @defer.inlineCallbacks
-    def get_event(self, user, room_id, event_id):
+    async def get_event(self, user, room_id, event_id):
         """Retrieve a single specified event.
 
         Args:
@@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
             AuthError if the user does not have the rights to inspect this
             event.
         """
-        event = yield self.store.get_event(event_id, check_room_id=room_id)
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         if not event:
             return None
 
-        users = yield self.store.get_users_in_room(event.room_id)
+        users = await self.store.get_users_in_room(event.room_id)
         is_peeking = user.to_string() not in users
 
-        filtered = yield filter_events_for_client(
+        filtered = await filter_events_for_client(
             self.storage, user.to_string(), [event], is_peeking=is_peeking
         )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4f53a5f5dc..54fa216d83 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -363,6 +363,8 @@ class EventCreationHandler(object):
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
 
+        self.room_invite_state_types = self.hs.config.room_invite_state_types
+
         self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
 
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -916,7 +918,7 @@ class EventCreationHandler(object):
                 state_to_include_ids = [
                     e_id
                     for k, e_id in iteritems(current_state_ids)
-                    if k[0] in self.hs.config.room_invite_state_types
+                    if k[0] in self.room_invite_state_types
                     or k == (EventTypes.Member, event.sender)
                 ]
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b536d410e5..2d3b8ba73c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -22,8 +22,6 @@ from six import iteritems, itervalues
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.logging.context import LoggingContext
 from synapse.push.clientformat import format_push_rules_for_user
@@ -241,8 +239,7 @@ class SyncHandler(object):
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
-    @defer.inlineCallbacks
-    def wait_for_sync_for_user(
+    async def wait_for_sync_for_user(
         self, sync_config, since_token=None, timeout=0, full_state=False
     ):
         """Get the sync for a client if we have new data for it now. Otherwise
@@ -255,9 +252,9 @@ class SyncHandler(object):
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        yield self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(user_id)
 
-        res = yield self.response_cache.wrap(
+        res = await self.response_cache.wrap(
             sync_config.request_key,
             self._wait_for_sync_for_user,
             sync_config,
@@ -267,8 +264,9 @@ class SyncHandler(object):
         )
         return res
 
-    @defer.inlineCallbacks
-    def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
+    async def _wait_for_sync_for_user(
+        self, sync_config, since_token, timeout, full_state
+    ):
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:
@@ -283,7 +281,7 @@ class SyncHandler(object):
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result = yield self.current_sync_for_user(
+            result = await self.current_sync_for_user(
                 sync_config, since_token, full_state=full_state
             )
         else:
@@ -291,7 +289,7 @@ class SyncHandler(object):
             def current_sync_callback(before_token, after_token):
                 return self.current_sync_for_user(sync_config, since_token)
 
-            result = yield self.notifier.wait_for_events(
+            result = await self.notifier.wait_for_events(
                 sync_config.user.to_string(),
                 timeout,
                 current_sync_callback,
@@ -314,15 +312,13 @@ class SyncHandler(object):
         """
         return self.generate_sync_result(sync_config, since_token, full_state)
 
-    @defer.inlineCallbacks
-    def push_rules_for_user(self, user):
+    async def push_rules_for_user(self, user):
         user_id = user.to_string()
-        rules = yield self.store.get_push_rules_for_user(user_id)
+        rules = await self.store.get_push_rules_for_user(user_id)
         rules = format_push_rules_for_user(user, rules)
         return rules
 
-    @defer.inlineCallbacks
-    def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
+    async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
         """Get the ephemeral events for each room the user is in
         Args:
             sync_result_builder(SyncResultBuilder)
@@ -343,7 +339,7 @@ class SyncHandler(object):
             room_ids = sync_result_builder.joined_room_ids
 
             typing_source = self.event_sources.sources["typing"]
-            typing, typing_key = yield typing_source.get_new_events(
+            typing, typing_key = await typing_source.get_new_events(
                 user=sync_config.user,
                 from_key=typing_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -365,7 +361,7 @@ class SyncHandler(object):
             receipt_key = since_token.receipt_key if since_token else "0"
 
             receipt_source = self.event_sources.sources["receipt"]
-            receipts, receipt_key = yield receipt_source.get_new_events(
+            receipts, receipt_key = await receipt_source.get_new_events(
                 user=sync_config.user,
                 from_key=receipt_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -382,8 +378,7 @@ class SyncHandler(object):
 
         return now_token, ephemeral_by_room
 
-    @defer.inlineCallbacks
-    def _load_filtered_recents(
+    async def _load_filtered_recents(
         self,
         room_id,
         sync_config,
@@ -415,10 +410,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                recents = yield filter_events_for_client(
+                recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     recents,
@@ -449,14 +444,14 @@ class SyncHandler(object):
                 # Otherwise, we want to return the last N events in the room
                 # in toplogical ordering.
                 if since_key:
-                    events, end_key = yield self.store.get_room_events_stream_for_room(
+                    events, end_key = await self.store.get_room_events_stream_for_room(
                         room_id,
                         limit=load_limit + 1,
                         from_key=since_key,
                         to_key=end_key,
                     )
                 else:
-                    events, end_key = yield self.store.get_recent_events_for_room(
+                    events, end_key = await self.store.get_recent_events_for_room(
                         room_id, limit=load_limit + 1, end_token=end_key
                     )
                 loaded_recents = sync_config.filter_collection.filter_room_timeline(
@@ -468,10 +463,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                loaded_recents = yield filter_events_for_client(
+                loaded_recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     loaded_recents,
@@ -498,8 +493,7 @@ class SyncHandler(object):
             limited=limited or newly_joined_room,
         )
 
-    @defer.inlineCallbacks
-    def get_state_after_event(self, event, state_filter=StateFilter.all()):
+    async def get_state_after_event(self, event, state_filter=StateFilter.all()):
         """
         Get the room state after the given event
 
@@ -511,7 +505,7 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
@@ -519,8 +513,9 @@ class SyncHandler(object):
             state_ids[(event.type, event.state_key)] = event.event_id
         return state_ids
 
-    @defer.inlineCallbacks
-    def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
+    async def get_state_at(
+        self, room_id, stream_position, state_filter=StateFilter.all()
+    ):
         """ Get the room state at a particular stream position
 
         Args:
@@ -536,13 +531,13 @@ class SyncHandler(object):
         # get_recent_events_for_room operates by topo ordering. This therefore
         # does not reliably give you the state at the given stream position.
         # (https://github.com/matrix-org/synapse/issues/3305)
-        last_events, _ = yield self.store.get_recent_events_for_room(
+        last_events, _ = await self.store.get_recent_events_for_room(
             room_id, end_token=stream_position.room_key, limit=1
         )
 
         if last_events:
             last_event = last_events[-1]
-            state = yield self.get_state_after_event(
+            state = await self.get_state_after_event(
                 last_event, state_filter=state_filter
             )
 
@@ -551,8 +546,7 @@ class SyncHandler(object):
             state = {}
         return state
 
-    @defer.inlineCallbacks
-    def compute_summary(self, room_id, sync_config, batch, state, now_token):
+    async def compute_summary(self, room_id, sync_config, batch, state, now_token):
         """ Works out a room summary block for this room, summarising the number
         of joined members in the room, and providing the 'hero' members if the
         room has no name so clients can consistently name rooms.  Also adds
@@ -574,7 +568,7 @@ class SyncHandler(object):
         # FIXME: we could/should get this from room_stats when matthew/stats lands
 
         # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
-        last_events, _ = yield self.store.get_recent_event_ids_for_room(
+        last_events, _ = await self.store.get_recent_event_ids_for_room(
             room_id, end_token=now_token.room_key, limit=1
         )
 
@@ -582,7 +576,7 @@ class SyncHandler(object):
             return None
 
         last_event = last_events[-1]
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -590,7 +584,7 @@ class SyncHandler(object):
         )
 
         # this is heavily cached, thus: fast.
-        details = yield self.store.get_room_summary(room_id)
+        details = await self.store.get_room_summary(room_id)
 
         name_id = state_ids.get((EventTypes.Name, ""))
         canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
@@ -608,12 +602,12 @@ class SyncHandler(object):
         # calculating heroes. Empty strings are falsey, so we check
         # for the "name" value and default to an empty string.
         if name_id:
-            name = yield self.store.get_event(name_id, allow_none=True)
+            name = await self.store.get_event(name_id, allow_none=True)
             if name and name.content.get("name"):
                 return summary
 
         if canonical_alias_id:
-            canonical_alias = yield self.store.get_event(
+            canonical_alias = await self.store.get_event(
                 canonical_alias_id, allow_none=True
             )
             if canonical_alias and canonical_alias.content.get("alias"):
@@ -678,7 +672,7 @@ class SyncHandler(object):
             )
         ]
 
-        missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
+        missing_hero_state = await self.store.get_events(missing_hero_event_ids)
         missing_hero_state = missing_hero_state.values()
 
         for s in missing_hero_state:
@@ -697,8 +691,7 @@ class SyncHandler(object):
             logger.debug("found LruCache for %r", cache_key)
         return cache
 
-    @defer.inlineCallbacks
-    def compute_state_delta(
+    async def compute_state_delta(
         self, room_id, batch, sync_config, since_token, now_token, full_state
     ):
         """ Works out the difference in state between the start of the timeline
@@ -759,16 +752,16 @@ class SyncHandler(object):
 
             if full_state:
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
 
-                    state_ids = yield self.state_store.get_state_ids_for_event(
+                    state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
 
                 else:
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -783,13 +776,13 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 if batch:
-                    state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
+                    state_at_timeline_start = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
-                    state_at_timeline_start = yield self.get_state_at(
+                    state_at_timeline_start = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -807,19 +800,19 @@ class SyncHandler(object):
                 # about them).
                 state_filter = StateFilter.all()
 
-                state_at_previous_sync = yield self.get_state_at(
+                state_at_previous_sync = await self.get_state_at(
                     room_id, stream_position=since_token, state_filter=state_filter
                 )
 
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 else:
                     # Its not clear how we get here, but empirically we do
                     # (#5407). Logging has been added elsewhere to try and
                     # figure out where this state comes from.
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -843,7 +836,7 @@ class SyncHandler(object):
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = yield self.state_store.get_state_ids_for_event(
+                        state_ids = await self.state_store.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
@@ -883,7 +876,7 @@ class SyncHandler(object):
 
         state = {}
         if state_ids:
-            state = yield self.store.get_events(list(state_ids.values()))
+            state = await self.store.get_events(list(state_ids.values()))
 
         return {
             (e.type, e.state_key): e
@@ -892,10 +885,9 @@ class SyncHandler(object):
             )
         }
 
-    @defer.inlineCallbacks
-    def unread_notifs_for_room_id(self, room_id, sync_config):
+    async def unread_notifs_for_room_id(self, room_id, sync_config):
         with Measure(self.clock, "unread_notifs_for_room_id"):
-            last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
                 receipt_type="m.read",
@@ -903,7 +895,7 @@ class SyncHandler(object):
 
             notifs = []
             if last_unread_event_id:
-                notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
                     room_id, sync_config.user.to_string(), last_unread_event_id
                 )
                 return notifs
@@ -912,8 +904,9 @@ class SyncHandler(object):
         # count is whatever it was last time.
         return None
 
-    @defer.inlineCallbacks
-    def generate_sync_result(self, sync_config, since_token=None, full_state=False):
+    async def generate_sync_result(
+        self, sync_config, since_token=None, full_state=False
+    ):
         """Generates a sync result.
 
         Args:
@@ -928,7 +921,7 @@ class SyncHandler(object):
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
-        now_token = yield self.event_sources.get_current_token()
+        now_token = await self.event_sources.get_current_token()
 
         logger.info(
             "Calculating sync response for %r between %s and %s",
@@ -944,10 +937,9 @@ class SyncHandler(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            joined_room_ids = yield self.get_rooms_for_user_at(
+            joined_room_ids = await self.get_rooms_for_user_at(
                 user_id, now_token.room_stream_id
             )
-
         sync_result_builder = SyncResultBuilder(
             sync_config,
             full_state,
@@ -956,11 +948,11 @@ class SyncHandler(object):
             joined_room_ids=joined_room_ids,
         )
 
-        account_data_by_room = yield self._generate_sync_entry_for_account_data(
+        account_data_by_room = await self._generate_sync_entry_for_account_data(
             sync_result_builder
         )
 
-        res = yield self._generate_sync_entry_for_rooms(
+        res = await self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
         newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
@@ -970,13 +962,13 @@ class SyncHandler(object):
             since_token is None and sync_config.filter_collection.blocks_all_presence()
         )
         if self.hs_config.use_presence and not block_all_presence_data:
-            yield self._generate_sync_entry_for_presence(
+            await self._generate_sync_entry_for_presence(
                 sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
             )
 
-        yield self._generate_sync_entry_for_to_device(sync_result_builder)
+        await self._generate_sync_entry_for_to_device(sync_result_builder)
 
-        device_lists = yield self._generate_sync_entry_for_device_list(
+        device_lists = await self._generate_sync_entry_for_device_list(
             sync_result_builder,
             newly_joined_rooms=newly_joined_rooms,
             newly_joined_or_invited_users=newly_joined_or_invited_users,
@@ -987,11 +979,11 @@ class SyncHandler(object):
         device_id = sync_config.device_id
         one_time_key_counts = {}
         if device_id:
-            one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+            one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
 
-        yield self._generate_sync_entry_for_groups(sync_result_builder)
+        await self._generate_sync_entry_for_groups(sync_result_builder)
 
         # debug for https://github.com/matrix-org/synapse/issues/4422
         for joined_room in sync_result_builder.joined:
@@ -1015,18 +1007,17 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_groups")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_groups(self, sync_result_builder):
+    async def _generate_sync_entry_for_groups(self, sync_result_builder):
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
 
         if since_token and since_token.groups_key:
-            results = yield self.store.get_groups_changes_for_user(
+            results = await self.store.get_groups_changes_for_user(
                 user_id, since_token.groups_key, now_token.groups_key
             )
         else:
-            results = yield self.store.get_all_groups_for_user(
+            results = await self.store.get_all_groups_for_user(
                 user_id, now_token.groups_key
             )
 
@@ -1059,8 +1050,7 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_device_list")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_device_list(
+    async def _generate_sync_entry_for_device_list(
         self,
         sync_result_builder,
         newly_joined_rooms,
@@ -1108,32 +1098,32 @@ class SyncHandler(object):
             # room with by looking at all users that have left a room plus users
             # that were in a room we've left.
 
-            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            users_who_share_room = await self.store.get_users_who_share_room_with_user(
                 user_id
             )
 
             # Step 1a, check for changes in devices of users we share a room with
-            users_that_have_changed = yield self.store.get_users_whose_devices_changed(
+            users_that_have_changed = await self.store.get_users_whose_devices_changed(
                 since_token.device_list_key, users_who_share_room
             )
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
-                joined_users = yield self.state.get_current_users_in_room(room_id)
+                joined_users = await self.state.get_current_users_in_room(room_id)
                 newly_joined_or_invited_users.update(joined_users)
 
             # TODO: Check that these users are actually new, i.e. either they
             # weren't in the previous sync *or* they left and rejoined.
             users_that_have_changed.update(newly_joined_or_invited_users)
 
-            user_signatures_changed = yield self.store.get_users_whose_signatures_changed(
+            user_signatures_changed = await self.store.get_users_whose_signatures_changed(
                 user_id, since_token.device_list_key
             )
             users_that_have_changed.update(user_signatures_changed)
 
             # Now find users that we no longer track
             for room_id in newly_left_rooms:
-                left_users = yield self.state.get_current_users_in_room(room_id)
+                left_users = await self.state.get_current_users_in_room(room_id)
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
@@ -1143,8 +1133,7 @@ class SyncHandler(object):
         else:
             return DeviceLists(changed=[], left=[])
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_to_device(self, sync_result_builder):
+    async def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1165,14 +1154,14 @@ class SyncHandler(object):
             # We only delete messages when a new message comes in, but that's
             # fine so long as we delete them at some point.
 
-            deleted = yield self.store.delete_messages_for_device(
+            deleted = await self.store.delete_messages_for_device(
                 user_id, device_id, since_stream_id
             )
             logger.debug(
                 "Deleted %d to-device messages up to %d", deleted, since_stream_id
             )
 
-            messages, stream_id = yield self.store.get_new_messages_for_device(
+            messages, stream_id = await self.store.get_new_messages_for_device(
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
@@ -1190,8 +1179,7 @@ class SyncHandler(object):
         else:
             sync_result_builder.to_device = []
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_account_data(self, sync_result_builder):
+    async def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1209,25 +1197,25 @@ class SyncHandler(object):
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_updated_account_data_for_user(
+            ) = await self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key
             )
 
-            push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+            push_rules_changed = await self.store.have_push_rules_changed_for_user(
                 user_id, int(since_token.push_rules_key)
             )
 
             if push_rules_changed:
-                account_data["m.push_rules"] = yield self.push_rules_for_user(
+                account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
+            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
-            account_data["m.push_rules"] = yield self.push_rules_for_user(
+            account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
@@ -1242,8 +1230,7 @@ class SyncHandler(object):
 
         return account_data_by_room
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_presence(
+    async def _generate_sync_entry_for_presence(
         self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
     ):
         """Generates the presence portion of the sync response. Populates the
@@ -1271,7 +1258,7 @@ class SyncHandler(object):
             presence_key = None
             include_offline = False
 
-        presence, presence_key = yield presence_source.get_new_events(
+        presence, presence_key = await presence_source.get_new_events(
             user=user,
             from_key=presence_key,
             is_guest=sync_config.is_guest,
@@ -1283,12 +1270,12 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_or_invited_users)
         for room_id in newly_joined_rooms:
-            users = yield self.state.get_current_users_in_room(room_id)
+            users = await self.state.get_current_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
         if extra_users_ids:
-            states = yield self.presence_handler.get_states(extra_users_ids)
+            states = await self.presence_handler.get_states(extra_users_ids)
             presence.extend(states)
 
             # Deduplicate the presence entries so that there's at most one per user
@@ -1298,8 +1285,9 @@ class SyncHandler(object):
 
         sync_result_builder.presence = presence
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
+    async def _generate_sync_entry_for_rooms(
+        self, sync_result_builder, account_data_by_room
+    ):
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
@@ -1321,7 +1309,7 @@ class SyncHandler(object):
         if block_all_room_ephemeral:
             ephemeral_by_room = {}
         else:
-            now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+            now_token, ephemeral_by_room = await self.ephemeral_by_room(
                 sync_result_builder,
                 now_token=sync_result_builder.now_token,
                 since_token=sync_result_builder.since_token,
@@ -1333,16 +1321,16 @@ class SyncHandler(object):
         since_token = sync_result_builder.since_token
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
-                have_changed = yield self._have_rooms_changed(sync_result_builder)
+                have_changed = await self._have_rooms_changed(sync_result_builder)
                 if not have_changed:
-                    tags_by_room = yield self.store.get_updated_tags(
+                    tags_by_room = await self.store.get_updated_tags(
                         user_id, since_token.account_data_key
                     )
                     if not tags_by_room:
                         logger.debug("no-oping sync")
                         return [], [], [], []
 
-        ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
+        ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
             "m.ignored_user_list", user_id=user_id
         )
 
@@ -1352,18 +1340,18 @@ class SyncHandler(object):
             ignored_users = frozenset()
 
         if since_token:
-            res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
+            res = await self._get_rooms_changed(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms, newly_left_rooms = res
 
-            tags_by_room = yield self.store.get_updated_tags(
+            tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            res = yield self._get_all_rooms(sync_result_builder, ignored_users)
+            res = await self._get_all_rooms(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms = res
             newly_left_rooms = []
 
-            tags_by_room = yield self.store.get_tags_for_user(user_id)
+            tags_by_room = await self.store.get_tags_for_user(user_id)
 
         def handle_room_entries(room_entry):
             return self._generate_room_entry(
@@ -1376,7 +1364,7 @@ class SyncHandler(object):
                 always_include=sync_result_builder.full_state,
             )
 
-        yield concurrently_execute(handle_room_entries, room_entries, 10)
+        await concurrently_execute(handle_room_entries, room_entries, 10)
 
         sync_result_builder.invited.extend(invited)
 
@@ -1410,8 +1398,7 @@ class SyncHandler(object):
             newly_left_users,
         )
 
-    @defer.inlineCallbacks
-    def _have_rooms_changed(self, sync_result_builder):
+    async def _have_rooms_changed(self, sync_result_builder):
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
         """
@@ -1422,7 +1409,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1435,8 +1422,7 @@ class SyncHandler(object):
                 return True
         return False
 
-    @defer.inlineCallbacks
-    def _get_rooms_changed(self, sync_result_builder, ignored_users):
+    async def _get_rooms_changed(self, sync_result_builder, ignored_users):
         """Gets the the changes that have happened since the last sync.
 
         Args:
@@ -1461,7 +1447,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1499,11 +1485,11 @@ class SyncHandler(object):
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = yield self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(room_id, since_token)
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
-                    old_mem_ev = yield self.store.get_event(
+                    old_mem_ev = await self.store.get_event(
                         old_mem_ev_id, allow_none=True
                     )
 
@@ -1536,13 +1522,13 @@ class SyncHandler(object):
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = yield self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(room_id, since_token)
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
                         old_mem_ev = None
                         if old_mem_ev_id:
-                            old_mem_ev = yield self.store.get_event(
+                            old_mem_ev = await self.store.get_event(
                                 old_mem_ev_id, allow_none=True
                             )
                     if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@@ -1566,7 +1552,7 @@ class SyncHandler(object):
 
             if leave_events:
                 leave_event = leave_events[-1]
-                leave_stream_token = yield self.store.get_stream_token_for_event(
+                leave_stream_token = await self.store.get_stream_token_for_event(
                     leave_event.event_id
                 )
                 leave_token = since_token.copy_and_replace(
@@ -1603,7 +1589,7 @@ class SyncHandler(object):
         timeline_limit = sync_config.filter_collection.timeline_limit()
 
         # Get all events for rooms we're currently joined to.
-        room_to_events = yield self.store.get_room_events_stream_for_rooms(
+        room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
             to_key=now_token.room_key,
@@ -1652,8 +1638,7 @@ class SyncHandler(object):
 
         return room_entries, invited, newly_joined_rooms, newly_left_rooms
 
-    @defer.inlineCallbacks
-    def _get_all_rooms(self, sync_result_builder, ignored_users):
+    async def _get_all_rooms(self, sync_result_builder, ignored_users):
         """Returns entries for all rooms for the user.
 
         Args:
@@ -1677,7 +1662,7 @@ class SyncHandler(object):
             Membership.BAN,
         )
 
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_user_where_membership_is(
             user_id=user_id, membership_list=membership_list
         )
 
@@ -1700,7 +1685,7 @@ class SyncHandler(object):
             elif event.membership == Membership.INVITE:
                 if event.sender in ignored_users:
                     continue
-                invite = yield self.store.get_event(event.event_id)
+                invite = await self.store.get_event(event.event_id)
                 invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
             elif event.membership in (Membership.LEAVE, Membership.BAN):
                 # Always send down rooms we were banned or kicked from.
@@ -1726,8 +1711,7 @@ class SyncHandler(object):
 
         return room_entries, invited, []
 
-    @defer.inlineCallbacks
-    def _generate_room_entry(
+    async def _generate_room_entry(
         self,
         sync_result_builder,
         ignored_users,
@@ -1769,7 +1753,7 @@ class SyncHandler(object):
         since_token = room_builder.since_token
         upto_token = room_builder.upto_token
 
-        batch = yield self._load_filtered_recents(
+        batch = await self._load_filtered_recents(
             room_id,
             sync_config,
             now_token=upto_token,
@@ -1796,7 +1780,7 @@ class SyncHandler(object):
         # tag was added by synapse e.g. for server notice rooms.
         if full_state:
             user_id = sync_result_builder.sync_config.user.to_string()
-            tags = yield self.store.get_tags_for_room(user_id, room_id)
+            tags = await self.store.get_tags_for_room(user_id, room_id)
 
             # If there aren't any tags, don't send the empty tags list down
             # sync
@@ -1821,7 +1805,7 @@ class SyncHandler(object):
         ):
             return
 
-        state = yield self.compute_state_delta(
+        state = await self.compute_state_delta(
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
@@ -1844,7 +1828,7 @@ class SyncHandler(object):
             )
             or since_token is None
         ):
-            summary = yield self.compute_summary(
+            summary = await self.compute_summary(
                 room_id, sync_config, batch, state, now_token
             )
 
@@ -1861,7 +1845,7 @@ class SyncHandler(object):
             )
 
             if room_sync or always_include:
-                notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
+                notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
                 if notifs is not None:
                     unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1887,8 +1871,7 @@ class SyncHandler(object):
         else:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
-    @defer.inlineCallbacks
-    def get_rooms_for_user_at(self, user_id, stream_ordering):
+    async def get_rooms_for_user_at(self, user_id, stream_ordering):
         """Get set of joined rooms for a user at the given stream ordering.
 
         The stream ordering *must* be recent, otherwise this may throw an
@@ -1903,7 +1886,7 @@ class SyncHandler(object):
             Deferred[frozenset[str]]: Set of room_ids the user is in at given
             stream_ordering.
         """
-        joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
+        joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
 
         joined_room_ids = set()
 
@@ -1921,10 +1904,10 @@ class SyncHandler(object):
 
             logger.info("User joined room after current token: %s", room_id)
 
-            extrems = yield self.store.get_forward_extremeties_for_room(
+            extrems = await self.store.get_forward_extremeties_for_room(
                 room_id, stream_ordering
             )
-            users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
+            users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
                 joined_room_ids.add(room_id)
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 856337b7e2..6f78454322 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -313,7 +313,7 @@ class TypingNotificationEventSource(object):
 
                 events.append(self._make_event_for(room_id))
 
-            return events, handler._latest_room_serial
+            return defer.succeed((events, handler._latest_room_serial))
 
     def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/notifier.py b/synapse/notifier.py
index af161a81d7..5f5f765bea 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -304,8 +304,7 @@ class Notifier(object):
         without waking up any of the normal user event streams"""
         self.notify_replication()
 
-    @defer.inlineCallbacks
-    def wait_for_events(
+    async def wait_for_events(
         self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
     ):
         """Wait until the callback returns a non empty response or the
@@ -313,9 +312,9 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
-            current_token = yield self.event_sources.get_current_token()
+            current_token = await self.event_sources.get_current_token()
             if room_ids is None:
-                room_ids = yield self.store.get_rooms_for_user(user_id)
+                room_ids = await self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -344,11 +343,11 @@ class Notifier(object):
                         self.hs.get_reactor(),
                     )
                     with PreserveLoggingContext():
-                        yield listener.deferred
+                        await listener.deferred
 
                     current_token = user_stream.current_token
 
-                    result = yield callback(prev_token, current_token)
+                    result = await callback(prev_token, current_token)
                     if result:
                         break
 
@@ -364,12 +363,11 @@ class Notifier(object):
             # This happened if there was no timeout or if the timeout had
             # already expired.
             current_token = user_stream.current_token
-            result = yield callback(prev_token, current_token)
+            result = await callback(prev_token, current_token)
 
         return result
 
-    @defer.inlineCallbacks
-    def get_events_for(
+    async def get_events_for(
         self,
         user,
         pagination_config,
@@ -391,15 +389,14 @@ class Notifier(object):
         """
         from_token = pagination_config.from_token
         if not from_token:
-            from_token = yield self.event_sources.get_current_token()
+            from_token = await self.event_sources.get_current_token()
 
         limit = pagination_config.limit
 
-        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+        room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
         is_peeking = not is_joined
 
-        @defer.inlineCallbacks
-        def check_for_updates(before_token, after_token):
+        async def check_for_updates(before_token, after_token):
             if not after_token.is_after(before_token):
                 return EventStreamResult([], (from_token, from_token))
 
@@ -415,7 +412,7 @@ class Notifier(object):
                 if only_keys and name not in only_keys:
                     continue
 
-                new_events, new_key = yield source.get_new_events(
+                new_events, new_key = await source.get_new_events(
                     user=user,
                     from_key=getattr(from_token, keyname),
                     limit=limit,
@@ -425,7 +422,7 @@ class Notifier(object):
                 )
 
                 if name == "room":
-                    new_events = yield filter_events_for_client(
+                    new_events = await filter_events_for_client(
                         self.storage,
                         user.to_string(),
                         new_events,
@@ -461,7 +458,7 @@ class Notifier(object):
                 user_id_for_stream,
             )
 
-        result = yield self.wait_for_events(
+        result = await self.wait_for_events(
             user_id_for_stream,
             timeout,
             check_for_updates,
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 44d20c19bf..46b494b334 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -251,7 +251,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return defer.succeed(({}, {}))
 
         return self.db.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 7f5e8dce66..6acd45e9f3 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -1111,7 +1111,7 @@ class GroupServerStore(SQLBaseStore):
             user_id, from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_groups_changes_for_user_txn(txn):
             sql = """
@@ -1141,7 +1141,7 @@ class GroupServerStore(SQLBaseStore):
             from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_all_groups_changes_txn(txn):
             sql = """
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 3286804322..7b18455469 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from functools import wraps
 
@@ -64,12 +65,22 @@ def measure_func(name=None):
     def wrapper(func):
         block_name = func.__name__ if name is None else name
 
-        @wraps(func)
-        @defer.inlineCallbacks
-        def measured_func(self, *args, **kwargs):
-            with Measure(self.clock, block_name):
-                r = yield func(self, *args, **kwargs)
-            return r
+        if inspect.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = await func(self, *args, **kwargs)
+                return r
+
+        else:
+
+            @wraps(func)
+            @defer.inlineCallbacks
+            def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = yield func(self, *args, **kwargs)
+                return r
 
         return measured_func
 
@@ -80,72 +91,48 @@ class Measure(object):
     __slots__ = [
         "clock",
         "name",
-        "start_context",
+        "_logging_context",
         "start",
-        "created_context",
-        "start_usage",
     ]
 
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        self.start_context = None
+        self._logging_context = None
         self.start = None
-        self.created_context = False
 
     def __enter__(self):
-        self.start = self.clock.time()
-        self.start_context = LoggingContext.current_context()
-        if not self.start_context:
-            self.start_context = LoggingContext("Measure")
-            self.start_context.__enter__()
-            self.created_context = True
-
-        self.start_usage = self.start_context.get_resource_usage()
+        if self._logging_context:
+            raise RuntimeError("Measure() objects cannot be re-used")
 
+        self.start = self.clock.time()
+        parent_context = LoggingContext.current_context()
+        self._logging_context = LoggingContext(
+            "Measure[%s]" % (self.name,), parent_context
+        )
+        self._logging_context.__enter__()
         in_flight.register((self.name,), self._update_in_flight)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if isinstance(exc_type, Exception) or not self.start_context:
-            return
-
-        in_flight.unregister((self.name,), self._update_in_flight)
+        if not self._logging_context:
+            raise RuntimeError("Measure() block exited without being entered")
 
         duration = self.clock.time() - self.start
+        usage = self._logging_context.get_resource_usage()
 
-        block_counter.labels(self.name).inc()
-        block_timer.labels(self.name).inc(duration)
-
-        context = LoggingContext.current_context()
-
-        if context != self.start_context:
-            logger.warning(
-                "Context has unexpectedly changed from '%s' to '%s'. (%r)",
-                self.start_context,
-                context,
-                self.name,
-            )
-            return
-
-        if not context:
-            logger.warning("Expected context. (%r)", self.name)
-            return
+        in_flight.unregister((self.name,), self._update_in_flight)
+        self._logging_context.__exit__(exc_type, exc_val, exc_tb)
 
-        current = context.get_resource_usage()
-        usage = current - self.start_usage
         try:
+            block_counter.labels(self.name).inc()
+            block_timer.labels(self.name).inc(duration)
             block_ru_utime.labels(self.name).inc(usage.ru_utime)
             block_ru_stime.labels(self.name).inc(usage.ru_stime)
             block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
             block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
             block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
         except ValueError:
-            logger.warning(
-                "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
-            )
-
-        if self.created_context:
-            self.start_context.__exit__(exc_type, exc_val, exc_tb)
+            logger.warning("Failed to save metrics! Usage: %s", usage)
 
     def _update_in_flight(self, metrics):
         """Gets called when processing in flight metrics
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
 
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
 from synapse.types import UserID
 
 import tests.unittest
 import tests.utils
-from tests.utils import setup_test_homeserver
 
 
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
     """ Tests Sync Handler. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.sync_handler = SyncHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.hs = hs
+        self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_wait_for_sync_for_user_auth_blocking(self):
 
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         sync_config = self._generate_sync_config(user_id1)
 
+        self.reactor.advance(100)  # So we get not 0 time
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
 
         # Check that the happy case does not throw errors
-        yield self.store.upsert_monthly_active_user(user_id1)
-        yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.get_success(self.store.upsert_monthly_active_user(user_id1))
+        self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
 
         # Test that global lock works
         self.hs.config.hs_disabled = True
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
         self.hs.config.hs_disabled = False
 
         sync_config = self._generate_sync_config(user_id2)
 
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     def _generate_sync_config(self, user_id):
         return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f6d8660285..92b8726093 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        events = self.get_success(
+            self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/unittest.py b/tests/unittest.py
index 68d245ec9f..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
 import gc
 import hashlib
 import hmac
+import inspect
 import logging
 import time
 
@@ -25,7 +26,7 @@ from mock import Mock
 
 from canonicaljson import json
 
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -417,6 +418,8 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump(by=by)
@@ -426,6 +429,8 @@ class HomeserverTestCase(TestCase):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump()