summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-05 17:58:25 +0000
committerErik Johnston <erik@matrix.org>2019-12-05 17:58:25 +0000
commit8437e2383ed2dffacca5395851023eeacb33d7ba (patch)
treeb16b2cac8d11126b64731cfdef918352269f96b1
parentMerge pull request #6483 from matrix-org/erikj/port_rest_v2 (diff)
downloadsynapse-8437e2383ed2dffacca5395851023eeacb33d7ba.tar.xz
Port SyncHandler to async/await
-rw-r--r--synapse/handlers/events.py30
-rw-r--r--synapse/handlers/sync.py251
-rw-r--r--synapse/notifier.py29
-rw-r--r--synapse/util/metrics.py23
-rw-r--r--tests/handlers/test_sync.py33
-rw-r--r--tests/unittest.py7
6 files changed, 182 insertions, 191 deletions
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 45fe13c62f..ec18a42a68 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -16,8 +16,6 @@
 import logging
 import random
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
@@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
         self._server_notices_sender = hs.get_server_notices_sender()
         self._event_serializer = hs.get_event_client_serializer()
 
-    @defer.inlineCallbacks
     @log_function
-    def get_stream(
+    async def get_stream(
         self,
         auth_user_id,
         pagin_config,
@@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
         """
 
         if room_id:
-            blocked = yield self.store.is_room_blocked(room_id)
+            blocked = await self.store.is_room_blocked(room_id)
             if blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
         # send any outstanding server notices to the user.
-        yield self._server_notices_sender.on_user_syncing(auth_user_id)
+        await self._server_notices_sender.on_user_syncing(auth_user_id)
 
         auth_user = UserID.from_string(auth_user_id)
         presence_handler = self.hs.get_presence_handler()
 
-        context = yield presence_handler.user_syncing(
+        context = await presence_handler.user_syncing(
             auth_user_id, affect_presence=affect_presence
         )
         with context:
@@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
-            events, tokens = yield self.notifier.get_events_for(
+            events, tokens = await self.notifier.get_events_for(
                 auth_user,
                 pagin_config,
                 timeout,
@@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = yield self.state.get_current_users_in_room(
+                        users = await self.state.get_current_users_in_room(
                             event.room_id
                         )
-                        states = yield presence_handler.get_states(users, as_event=True)
+                        states = await presence_handler.get_states(users, as_event=True)
                         to_add.extend(states)
                     else:
 
-                        ev = yield presence_handler.get_state(
+                        ev = await presence_handler.get_state(
                             UserID.from_string(event.state_key), as_event=True
                         )
                         to_add.append(ev)
@@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
 
             time_now = self.clock.time_msec()
 
-            chunks = yield self._event_serializer.serialize_events(
+            chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
                 as_client_event=as_client_event,
@@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
         super(EventHandler, self).__init__(hs)
         self.storage = hs.get_storage()
 
-    @defer.inlineCallbacks
-    def get_event(self, user, room_id, event_id):
+    async def get_event(self, user, room_id, event_id):
         """Retrieve a single specified event.
 
         Args:
@@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
             AuthError if the user does not have the rights to inspect this
             event.
         """
-        event = yield self.store.get_event(event_id, check_room_id=room_id)
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         if not event:
             return None
 
-        users = yield self.store.get_users_in_room(event.room_id)
+        users = await self.store.get_users_in_room(event.room_id)
         is_peeking = user.to_string() not in users
 
-        filtered = yield filter_events_for_client(
+        filtered = await filter_events_for_client(
             self.storage, user.to_string(), [event], is_peeking=is_peeking
         )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b536d410e5..12751fd8c0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -22,8 +22,6 @@ from six import iteritems, itervalues
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.logging.context import LoggingContext
 from synapse.push.clientformat import format_push_rules_for_user
@@ -241,8 +239,7 @@ class SyncHandler(object):
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
-    @defer.inlineCallbacks
-    def wait_for_sync_for_user(
+    async def wait_for_sync_for_user(
         self, sync_config, since_token=None, timeout=0, full_state=False
     ):
         """Get the sync for a client if we have new data for it now. Otherwise
@@ -255,9 +252,9 @@ class SyncHandler(object):
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        yield self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(user_id)
 
-        res = yield self.response_cache.wrap(
+        res = await self.response_cache.wrap(
             sync_config.request_key,
             self._wait_for_sync_for_user,
             sync_config,
@@ -267,8 +264,9 @@ class SyncHandler(object):
         )
         return res
 
-    @defer.inlineCallbacks
-    def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
+    async def _wait_for_sync_for_user(
+        self, sync_config, since_token, timeout, full_state
+    ):
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:
@@ -283,7 +281,7 @@ class SyncHandler(object):
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result = yield self.current_sync_for_user(
+            result = await self.current_sync_for_user(
                 sync_config, since_token, full_state=full_state
             )
         else:
@@ -291,7 +289,7 @@ class SyncHandler(object):
             def current_sync_callback(before_token, after_token):
                 return self.current_sync_for_user(sync_config, since_token)
 
-            result = yield self.notifier.wait_for_events(
+            result = await self.notifier.wait_for_events(
                 sync_config.user.to_string(),
                 timeout,
                 current_sync_callback,
@@ -314,15 +312,13 @@ class SyncHandler(object):
         """
         return self.generate_sync_result(sync_config, since_token, full_state)
 
-    @defer.inlineCallbacks
-    def push_rules_for_user(self, user):
+    async def push_rules_for_user(self, user):
         user_id = user.to_string()
-        rules = yield self.store.get_push_rules_for_user(user_id)
+        rules = await self.store.get_push_rules_for_user(user_id)
         rules = format_push_rules_for_user(user, rules)
         return rules
 
-    @defer.inlineCallbacks
-    def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
+    async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
         """Get the ephemeral events for each room the user is in
         Args:
             sync_result_builder(SyncResultBuilder)
@@ -343,7 +339,7 @@ class SyncHandler(object):
             room_ids = sync_result_builder.joined_room_ids
 
             typing_source = self.event_sources.sources["typing"]
-            typing, typing_key = yield typing_source.get_new_events(
+            typing, typing_key = typing_source.get_new_events(
                 user=sync_config.user,
                 from_key=typing_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -365,7 +361,7 @@ class SyncHandler(object):
             receipt_key = since_token.receipt_key if since_token else "0"
 
             receipt_source = self.event_sources.sources["receipt"]
-            receipts, receipt_key = yield receipt_source.get_new_events(
+            receipts, receipt_key = await receipt_source.get_new_events(
                 user=sync_config.user,
                 from_key=receipt_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -382,8 +378,7 @@ class SyncHandler(object):
 
         return now_token, ephemeral_by_room
 
-    @defer.inlineCallbacks
-    def _load_filtered_recents(
+    async def _load_filtered_recents(
         self,
         room_id,
         sync_config,
@@ -415,10 +410,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                recents = yield filter_events_for_client(
+                recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     recents,
@@ -449,14 +444,14 @@ class SyncHandler(object):
                 # Otherwise, we want to return the last N events in the room
                 # in toplogical ordering.
                 if since_key:
-                    events, end_key = yield self.store.get_room_events_stream_for_room(
+                    events, end_key = await self.store.get_room_events_stream_for_room(
                         room_id,
                         limit=load_limit + 1,
                         from_key=since_key,
                         to_key=end_key,
                     )
                 else:
-                    events, end_key = yield self.store.get_recent_events_for_room(
+                    events, end_key = await self.store.get_recent_events_for_room(
                         room_id, limit=load_limit + 1, end_token=end_key
                     )
                 loaded_recents = sync_config.filter_collection.filter_room_timeline(
@@ -468,10 +463,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                loaded_recents = yield filter_events_for_client(
+                loaded_recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     loaded_recents,
@@ -498,8 +493,7 @@ class SyncHandler(object):
             limited=limited or newly_joined_room,
         )
 
-    @defer.inlineCallbacks
-    def get_state_after_event(self, event, state_filter=StateFilter.all()):
+    async def get_state_after_event(self, event, state_filter=StateFilter.all()):
         """
         Get the room state after the given event
 
@@ -511,7 +505,7 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
@@ -519,8 +513,9 @@ class SyncHandler(object):
             state_ids[(event.type, event.state_key)] = event.event_id
         return state_ids
 
-    @defer.inlineCallbacks
-    def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
+    async def get_state_at(
+        self, room_id, stream_position, state_filter=StateFilter.all()
+    ):
         """ Get the room state at a particular stream position
 
         Args:
@@ -536,13 +531,13 @@ class SyncHandler(object):
         # get_recent_events_for_room operates by topo ordering. This therefore
         # does not reliably give you the state at the given stream position.
         # (https://github.com/matrix-org/synapse/issues/3305)
-        last_events, _ = yield self.store.get_recent_events_for_room(
+        last_events, _ = await self.store.get_recent_events_for_room(
             room_id, end_token=stream_position.room_key, limit=1
         )
 
         if last_events:
             last_event = last_events[-1]
-            state = yield self.get_state_after_event(
+            state = await self.get_state_after_event(
                 last_event, state_filter=state_filter
             )
 
@@ -551,8 +546,7 @@ class SyncHandler(object):
             state = {}
         return state
 
-    @defer.inlineCallbacks
-    def compute_summary(self, room_id, sync_config, batch, state, now_token):
+    async def compute_summary(self, room_id, sync_config, batch, state, now_token):
         """ Works out a room summary block for this room, summarising the number
         of joined members in the room, and providing the 'hero' members if the
         room has no name so clients can consistently name rooms.  Also adds
@@ -574,7 +568,7 @@ class SyncHandler(object):
         # FIXME: we could/should get this from room_stats when matthew/stats lands
 
         # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
-        last_events, _ = yield self.store.get_recent_event_ids_for_room(
+        last_events, _ = await self.store.get_recent_event_ids_for_room(
             room_id, end_token=now_token.room_key, limit=1
         )
 
@@ -582,7 +576,7 @@ class SyncHandler(object):
             return None
 
         last_event = last_events[-1]
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -590,7 +584,7 @@ class SyncHandler(object):
         )
 
         # this is heavily cached, thus: fast.
-        details = yield self.store.get_room_summary(room_id)
+        details = await self.store.get_room_summary(room_id)
 
         name_id = state_ids.get((EventTypes.Name, ""))
         canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
@@ -608,12 +602,12 @@ class SyncHandler(object):
         # calculating heroes. Empty strings are falsey, so we check
         # for the "name" value and default to an empty string.
         if name_id:
-            name = yield self.store.get_event(name_id, allow_none=True)
+            name = await self.store.get_event(name_id, allow_none=True)
             if name and name.content.get("name"):
                 return summary
 
         if canonical_alias_id:
-            canonical_alias = yield self.store.get_event(
+            canonical_alias = await self.store.get_event(
                 canonical_alias_id, allow_none=True
             )
             if canonical_alias and canonical_alias.content.get("alias"):
@@ -678,7 +672,7 @@ class SyncHandler(object):
             )
         ]
 
-        missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
+        missing_hero_state = await self.store.get_events(missing_hero_event_ids)
         missing_hero_state = missing_hero_state.values()
 
         for s in missing_hero_state:
@@ -697,8 +691,7 @@ class SyncHandler(object):
             logger.debug("found LruCache for %r", cache_key)
         return cache
 
-    @defer.inlineCallbacks
-    def compute_state_delta(
+    async def compute_state_delta(
         self, room_id, batch, sync_config, since_token, now_token, full_state
     ):
         """ Works out the difference in state between the start of the timeline
@@ -759,16 +752,16 @@ class SyncHandler(object):
 
             if full_state:
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
 
-                    state_ids = yield self.state_store.get_state_ids_for_event(
+                    state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
 
                 else:
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -783,13 +776,13 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 if batch:
-                    state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
+                    state_at_timeline_start = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
-                    state_at_timeline_start = yield self.get_state_at(
+                    state_at_timeline_start = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -807,19 +800,19 @@ class SyncHandler(object):
                 # about them).
                 state_filter = StateFilter.all()
 
-                state_at_previous_sync = yield self.get_state_at(
+                state_at_previous_sync = await self.get_state_at(
                     room_id, stream_position=since_token, state_filter=state_filter
                 )
 
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 else:
                     # Its not clear how we get here, but empirically we do
                     # (#5407). Logging has been added elsewhere to try and
                     # figure out where this state comes from.
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -843,7 +836,7 @@ class SyncHandler(object):
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = yield self.state_store.get_state_ids_for_event(
+                        state_ids = await self.state_store.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
@@ -883,7 +876,7 @@ class SyncHandler(object):
 
         state = {}
         if state_ids:
-            state = yield self.store.get_events(list(state_ids.values()))
+            state = await self.store.get_events(list(state_ids.values()))
 
         return {
             (e.type, e.state_key): e
@@ -892,10 +885,9 @@ class SyncHandler(object):
             )
         }
 
-    @defer.inlineCallbacks
-    def unread_notifs_for_room_id(self, room_id, sync_config):
+    async def unread_notifs_for_room_id(self, room_id, sync_config):
         with Measure(self.clock, "unread_notifs_for_room_id"):
-            last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
                 receipt_type="m.read",
@@ -903,7 +895,7 @@ class SyncHandler(object):
 
             notifs = []
             if last_unread_event_id:
-                notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
                     room_id, sync_config.user.to_string(), last_unread_event_id
                 )
                 return notifs
@@ -912,8 +904,9 @@ class SyncHandler(object):
         # count is whatever it was last time.
         return None
 
-    @defer.inlineCallbacks
-    def generate_sync_result(self, sync_config, since_token=None, full_state=False):
+    async def generate_sync_result(
+        self, sync_config, since_token=None, full_state=False
+    ):
         """Generates a sync result.
 
         Args:
@@ -928,7 +921,7 @@ class SyncHandler(object):
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
-        now_token = yield self.event_sources.get_current_token()
+        now_token = await self.event_sources.get_current_token()
 
         logger.info(
             "Calculating sync response for %r between %s and %s",
@@ -944,10 +937,9 @@ class SyncHandler(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            joined_room_ids = yield self.get_rooms_for_user_at(
+            joined_room_ids = await self.get_rooms_for_user_at(
                 user_id, now_token.room_stream_id
             )
-
         sync_result_builder = SyncResultBuilder(
             sync_config,
             full_state,
@@ -956,11 +948,11 @@ class SyncHandler(object):
             joined_room_ids=joined_room_ids,
         )
 
-        account_data_by_room = yield self._generate_sync_entry_for_account_data(
+        account_data_by_room = await self._generate_sync_entry_for_account_data(
             sync_result_builder
         )
 
-        res = yield self._generate_sync_entry_for_rooms(
+        res = await self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
         newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
@@ -970,13 +962,13 @@ class SyncHandler(object):
             since_token is None and sync_config.filter_collection.blocks_all_presence()
         )
         if self.hs_config.use_presence and not block_all_presence_data:
-            yield self._generate_sync_entry_for_presence(
+            await self._generate_sync_entry_for_presence(
                 sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
             )
 
-        yield self._generate_sync_entry_for_to_device(sync_result_builder)
+        await self._generate_sync_entry_for_to_device(sync_result_builder)
 
-        device_lists = yield self._generate_sync_entry_for_device_list(
+        device_lists = await self._generate_sync_entry_for_device_list(
             sync_result_builder,
             newly_joined_rooms=newly_joined_rooms,
             newly_joined_or_invited_users=newly_joined_or_invited_users,
@@ -987,11 +979,11 @@ class SyncHandler(object):
         device_id = sync_config.device_id
         one_time_key_counts = {}
         if device_id:
-            one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+            one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
 
-        yield self._generate_sync_entry_for_groups(sync_result_builder)
+        await self._generate_sync_entry_for_groups(sync_result_builder)
 
         # debug for https://github.com/matrix-org/synapse/issues/4422
         for joined_room in sync_result_builder.joined:
@@ -1015,18 +1007,17 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_groups")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_groups(self, sync_result_builder):
+    async def _generate_sync_entry_for_groups(self, sync_result_builder):
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
 
         if since_token and since_token.groups_key:
-            results = yield self.store.get_groups_changes_for_user(
+            results = self.store.get_groups_changes_for_user(
                 user_id, since_token.groups_key, now_token.groups_key
             )
         else:
-            results = yield self.store.get_all_groups_for_user(
+            results = await self.store.get_all_groups_for_user(
                 user_id, now_token.groups_key
             )
 
@@ -1059,8 +1050,7 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_device_list")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_device_list(
+    async def _generate_sync_entry_for_device_list(
         self,
         sync_result_builder,
         newly_joined_rooms,
@@ -1108,32 +1098,32 @@ class SyncHandler(object):
             # room with by looking at all users that have left a room plus users
             # that were in a room we've left.
 
-            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            users_who_share_room = await self.store.get_users_who_share_room_with_user(
                 user_id
             )
 
             # Step 1a, check for changes in devices of users we share a room with
-            users_that_have_changed = yield self.store.get_users_whose_devices_changed(
+            users_that_have_changed = await self.store.get_users_whose_devices_changed(
                 since_token.device_list_key, users_who_share_room
             )
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
-                joined_users = yield self.state.get_current_users_in_room(room_id)
+                joined_users = await self.state.get_current_users_in_room(room_id)
                 newly_joined_or_invited_users.update(joined_users)
 
             # TODO: Check that these users are actually new, i.e. either they
             # weren't in the previous sync *or* they left and rejoined.
             users_that_have_changed.update(newly_joined_or_invited_users)
 
-            user_signatures_changed = yield self.store.get_users_whose_signatures_changed(
+            user_signatures_changed = await self.store.get_users_whose_signatures_changed(
                 user_id, since_token.device_list_key
             )
             users_that_have_changed.update(user_signatures_changed)
 
             # Now find users that we no longer track
             for room_id in newly_left_rooms:
-                left_users = yield self.state.get_current_users_in_room(room_id)
+                left_users = await self.state.get_current_users_in_room(room_id)
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
@@ -1143,8 +1133,7 @@ class SyncHandler(object):
         else:
             return DeviceLists(changed=[], left=[])
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_to_device(self, sync_result_builder):
+    async def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1165,14 +1154,14 @@ class SyncHandler(object):
             # We only delete messages when a new message comes in, but that's
             # fine so long as we delete them at some point.
 
-            deleted = yield self.store.delete_messages_for_device(
+            deleted = await self.store.delete_messages_for_device(
                 user_id, device_id, since_stream_id
             )
             logger.debug(
                 "Deleted %d to-device messages up to %d", deleted, since_stream_id
             )
 
-            messages, stream_id = yield self.store.get_new_messages_for_device(
+            messages, stream_id = await self.store.get_new_messages_for_device(
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
@@ -1190,8 +1179,7 @@ class SyncHandler(object):
         else:
             sync_result_builder.to_device = []
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_account_data(self, sync_result_builder):
+    async def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1209,25 +1197,25 @@ class SyncHandler(object):
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_updated_account_data_for_user(
+            ) = self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key
             )
 
-            push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+            push_rules_changed = await self.store.have_push_rules_changed_for_user(
                 user_id, int(since_token.push_rules_key)
             )
 
             if push_rules_changed:
-                account_data["m.push_rules"] = yield self.push_rules_for_user(
+                account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
+            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
-            account_data["m.push_rules"] = yield self.push_rules_for_user(
+            account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
@@ -1242,8 +1230,7 @@ class SyncHandler(object):
 
         return account_data_by_room
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_presence(
+    async def _generate_sync_entry_for_presence(
         self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
     ):
         """Generates the presence portion of the sync response. Populates the
@@ -1271,7 +1258,7 @@ class SyncHandler(object):
             presence_key = None
             include_offline = False
 
-        presence, presence_key = yield presence_source.get_new_events(
+        presence, presence_key = await presence_source.get_new_events(
             user=user,
             from_key=presence_key,
             is_guest=sync_config.is_guest,
@@ -1283,12 +1270,12 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_or_invited_users)
         for room_id in newly_joined_rooms:
-            users = yield self.state.get_current_users_in_room(room_id)
+            users = await self.state.get_current_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
         if extra_users_ids:
-            states = yield self.presence_handler.get_states(extra_users_ids)
+            states = await self.presence_handler.get_states(extra_users_ids)
             presence.extend(states)
 
             # Deduplicate the presence entries so that there's at most one per user
@@ -1298,8 +1285,9 @@ class SyncHandler(object):
 
         sync_result_builder.presence = presence
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
+    async def _generate_sync_entry_for_rooms(
+        self, sync_result_builder, account_data_by_room
+    ):
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
@@ -1321,7 +1309,7 @@ class SyncHandler(object):
         if block_all_room_ephemeral:
             ephemeral_by_room = {}
         else:
-            now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+            now_token, ephemeral_by_room = await self.ephemeral_by_room(
                 sync_result_builder,
                 now_token=sync_result_builder.now_token,
                 since_token=sync_result_builder.since_token,
@@ -1333,16 +1321,16 @@ class SyncHandler(object):
         since_token = sync_result_builder.since_token
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
-                have_changed = yield self._have_rooms_changed(sync_result_builder)
+                have_changed = await self._have_rooms_changed(sync_result_builder)
                 if not have_changed:
-                    tags_by_room = yield self.store.get_updated_tags(
+                    tags_by_room = await self.store.get_updated_tags(
                         user_id, since_token.account_data_key
                     )
                     if not tags_by_room:
                         logger.debug("no-oping sync")
                         return [], [], [], []
 
-        ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
+        ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
             "m.ignored_user_list", user_id=user_id
         )
 
@@ -1352,18 +1340,18 @@ class SyncHandler(object):
             ignored_users = frozenset()
 
         if since_token:
-            res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
+            res = await self._get_rooms_changed(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms, newly_left_rooms = res
 
-            tags_by_room = yield self.store.get_updated_tags(
+            tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            res = yield self._get_all_rooms(sync_result_builder, ignored_users)
+            res = await self._get_all_rooms(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms = res
             newly_left_rooms = []
 
-            tags_by_room = yield self.store.get_tags_for_user(user_id)
+            tags_by_room = await self.store.get_tags_for_user(user_id)
 
         def handle_room_entries(room_entry):
             return self._generate_room_entry(
@@ -1376,7 +1364,7 @@ class SyncHandler(object):
                 always_include=sync_result_builder.full_state,
             )
 
-        yield concurrently_execute(handle_room_entries, room_entries, 10)
+        await concurrently_execute(handle_room_entries, room_entries, 10)
 
         sync_result_builder.invited.extend(invited)
 
@@ -1410,8 +1398,7 @@ class SyncHandler(object):
             newly_left_users,
         )
 
-    @defer.inlineCallbacks
-    def _have_rooms_changed(self, sync_result_builder):
+    async def _have_rooms_changed(self, sync_result_builder):
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
         """
@@ -1422,7 +1409,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1435,8 +1422,7 @@ class SyncHandler(object):
                 return True
         return False
 
-    @defer.inlineCallbacks
-    def _get_rooms_changed(self, sync_result_builder, ignored_users):
+    async def _get_rooms_changed(self, sync_result_builder, ignored_users):
         """Gets the the changes that have happened since the last sync.
 
         Args:
@@ -1461,7 +1447,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1499,11 +1485,11 @@ class SyncHandler(object):
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = yield self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(room_id, since_token)
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
-                    old_mem_ev = yield self.store.get_event(
+                    old_mem_ev = await self.store.get_event(
                         old_mem_ev_id, allow_none=True
                     )
 
@@ -1536,13 +1522,13 @@ class SyncHandler(object):
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = yield self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(room_id, since_token)
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
                         old_mem_ev = None
                         if old_mem_ev_id:
-                            old_mem_ev = yield self.store.get_event(
+                            old_mem_ev = await self.store.get_event(
                                 old_mem_ev_id, allow_none=True
                             )
                     if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@@ -1566,7 +1552,7 @@ class SyncHandler(object):
 
             if leave_events:
                 leave_event = leave_events[-1]
-                leave_stream_token = yield self.store.get_stream_token_for_event(
+                leave_stream_token = await self.store.get_stream_token_for_event(
                     leave_event.event_id
                 )
                 leave_token = since_token.copy_and_replace(
@@ -1603,7 +1589,7 @@ class SyncHandler(object):
         timeline_limit = sync_config.filter_collection.timeline_limit()
 
         # Get all events for rooms we're currently joined to.
-        room_to_events = yield self.store.get_room_events_stream_for_rooms(
+        room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
             to_key=now_token.room_key,
@@ -1652,8 +1638,7 @@ class SyncHandler(object):
 
         return room_entries, invited, newly_joined_rooms, newly_left_rooms
 
-    @defer.inlineCallbacks
-    def _get_all_rooms(self, sync_result_builder, ignored_users):
+    async def _get_all_rooms(self, sync_result_builder, ignored_users):
         """Returns entries for all rooms for the user.
 
         Args:
@@ -1677,7 +1662,7 @@ class SyncHandler(object):
             Membership.BAN,
         )
 
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_user_where_membership_is(
             user_id=user_id, membership_list=membership_list
         )
 
@@ -1700,7 +1685,7 @@ class SyncHandler(object):
             elif event.membership == Membership.INVITE:
                 if event.sender in ignored_users:
                     continue
-                invite = yield self.store.get_event(event.event_id)
+                invite = await self.store.get_event(event.event_id)
                 invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
             elif event.membership in (Membership.LEAVE, Membership.BAN):
                 # Always send down rooms we were banned or kicked from.
@@ -1726,8 +1711,7 @@ class SyncHandler(object):
 
         return room_entries, invited, []
 
-    @defer.inlineCallbacks
-    def _generate_room_entry(
+    async def _generate_room_entry(
         self,
         sync_result_builder,
         ignored_users,
@@ -1769,7 +1753,7 @@ class SyncHandler(object):
         since_token = room_builder.since_token
         upto_token = room_builder.upto_token
 
-        batch = yield self._load_filtered_recents(
+        batch = await self._load_filtered_recents(
             room_id,
             sync_config,
             now_token=upto_token,
@@ -1796,7 +1780,7 @@ class SyncHandler(object):
         # tag was added by synapse e.g. for server notice rooms.
         if full_state:
             user_id = sync_result_builder.sync_config.user.to_string()
-            tags = yield self.store.get_tags_for_room(user_id, room_id)
+            tags = await self.store.get_tags_for_room(user_id, room_id)
 
             # If there aren't any tags, don't send the empty tags list down
             # sync
@@ -1821,7 +1805,7 @@ class SyncHandler(object):
         ):
             return
 
-        state = yield self.compute_state_delta(
+        state = await self.compute_state_delta(
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
@@ -1844,7 +1828,7 @@ class SyncHandler(object):
             )
             or since_token is None
         ):
-            summary = yield self.compute_summary(
+            summary = await self.compute_summary(
                 room_id, sync_config, batch, state, now_token
             )
 
@@ -1861,7 +1845,7 @@ class SyncHandler(object):
             )
 
             if room_sync or always_include:
-                notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
+                notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
                 if notifs is not None:
                     unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1887,8 +1871,7 @@ class SyncHandler(object):
         else:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
-    @defer.inlineCallbacks
-    def get_rooms_for_user_at(self, user_id, stream_ordering):
+    async def get_rooms_for_user_at(self, user_id, stream_ordering):
         """Get set of joined rooms for a user at the given stream ordering.
 
         The stream ordering *must* be recent, otherwise this may throw an
@@ -1903,7 +1886,7 @@ class SyncHandler(object):
             Deferred[frozenset[str]]: Set of room_ids the user is in at given
             stream_ordering.
         """
-        joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
+        joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
 
         joined_room_ids = set()
 
@@ -1921,10 +1904,10 @@ class SyncHandler(object):
 
             logger.info("User joined room after current token: %s", room_id)
 
-            extrems = yield self.store.get_forward_extremeties_for_room(
+            extrems = await self.store.get_forward_extremeties_for_room(
                 room_id, stream_ordering
             )
-            users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
+            users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
                 joined_room_ids.add(room_id)
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index af161a81d7..5f5f765bea 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -304,8 +304,7 @@ class Notifier(object):
         without waking up any of the normal user event streams"""
         self.notify_replication()
 
-    @defer.inlineCallbacks
-    def wait_for_events(
+    async def wait_for_events(
         self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
     ):
         """Wait until the callback returns a non empty response or the
@@ -313,9 +312,9 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
-            current_token = yield self.event_sources.get_current_token()
+            current_token = await self.event_sources.get_current_token()
             if room_ids is None:
-                room_ids = yield self.store.get_rooms_for_user(user_id)
+                room_ids = await self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -344,11 +343,11 @@ class Notifier(object):
                         self.hs.get_reactor(),
                     )
                     with PreserveLoggingContext():
-                        yield listener.deferred
+                        await listener.deferred
 
                     current_token = user_stream.current_token
 
-                    result = yield callback(prev_token, current_token)
+                    result = await callback(prev_token, current_token)
                     if result:
                         break
 
@@ -364,12 +363,11 @@ class Notifier(object):
             # This happened if there was no timeout or if the timeout had
             # already expired.
             current_token = user_stream.current_token
-            result = yield callback(prev_token, current_token)
+            result = await callback(prev_token, current_token)
 
         return result
 
-    @defer.inlineCallbacks
-    def get_events_for(
+    async def get_events_for(
         self,
         user,
         pagination_config,
@@ -391,15 +389,14 @@ class Notifier(object):
         """
         from_token = pagination_config.from_token
         if not from_token:
-            from_token = yield self.event_sources.get_current_token()
+            from_token = await self.event_sources.get_current_token()
 
         limit = pagination_config.limit
 
-        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+        room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
         is_peeking = not is_joined
 
-        @defer.inlineCallbacks
-        def check_for_updates(before_token, after_token):
+        async def check_for_updates(before_token, after_token):
             if not after_token.is_after(before_token):
                 return EventStreamResult([], (from_token, from_token))
 
@@ -415,7 +412,7 @@ class Notifier(object):
                 if only_keys and name not in only_keys:
                     continue
 
-                new_events, new_key = yield source.get_new_events(
+                new_events, new_key = await source.get_new_events(
                     user=user,
                     from_key=getattr(from_token, keyname),
                     limit=limit,
@@ -425,7 +422,7 @@ class Notifier(object):
                 )
 
                 if name == "room":
-                    new_events = yield filter_events_for_client(
+                    new_events = await filter_events_for_client(
                         self.storage,
                         user.to_string(),
                         new_events,
@@ -461,7 +458,7 @@ class Notifier(object):
                 user_id_for_stream,
             )
 
-        result = yield self.wait_for_events(
+        result = await self.wait_for_events(
             user_id_for_stream,
             timeout,
             check_for_updates,
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 3286804322..63ddaaba87 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from functools import wraps
 
@@ -64,12 +65,22 @@ def measure_func(name=None):
     def wrapper(func):
         block_name = func.__name__ if name is None else name
 
-        @wraps(func)
-        @defer.inlineCallbacks
-        def measured_func(self, *args, **kwargs):
-            with Measure(self.clock, block_name):
-                r = yield func(self, *args, **kwargs)
-            return r
+        if inspect.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = await func(self, *args, **kwargs)
+                return r
+
+        else:
+
+            @wraps(func)
+            @defer.inlineCallbacks
+            def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = yield func(self, *args, **kwargs)
+                return r
 
         return measured_func
 
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
 
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
 from synapse.types import UserID
 
 import tests.unittest
 import tests.utils
-from tests.utils import setup_test_homeserver
 
 
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
     """ Tests Sync Handler. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.sync_handler = SyncHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.hs = hs
+        self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_wait_for_sync_for_user_auth_blocking(self):
 
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         sync_config = self._generate_sync_config(user_id1)
 
+        self.reactor.advance(100)  # So we get not 0 time
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
 
         # Check that the happy case does not throw errors
-        yield self.store.upsert_monthly_active_user(user_id1)
-        yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.get_success(self.store.upsert_monthly_active_user(user_id1))
+        self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
 
         # Test that global lock works
         self.hs.config.hs_disabled = True
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
         self.hs.config.hs_disabled = False
 
         sync_config = self._generate_sync_config(user_id2)
 
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     def _generate_sync_config(self, user_id):
         return SyncConfig(
diff --git a/tests/unittest.py b/tests/unittest.py
index 295573bc46..a1bdd963e6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
 import gc
 import hashlib
 import hmac
+import inspect
 import logging
 import time
 
@@ -25,7 +26,7 @@ from mock import Mock
 
 from canonicaljson import json
 
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump(by=by)
@@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump()