summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6499.bugfix1
-rw-r--r--changelog.d/6501.misc1
-rw-r--r--changelog.d/6502.removal1
-rw-r--r--changelog.d/6505.misc1
-rw-r--r--changelog.d/6506.misc1
-rw-r--r--changelog.d/6514.bugfix1
-rw-r--r--synapse/event_auth.py8
-rw-r--r--synapse/federation/federation_client.py44
-rw-r--r--synapse/handlers/initial_sync.py19
-rw-r--r--synapse/logging/context.py11
-rw-r--r--synapse/storage/data_stores/main/client_ips.py8
-rw-r--r--synapse/storage/data_stores/main/events.py23
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py8
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql4
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql16
-rw-r--r--synapse/util/caches/snapshot_cache.py94
-rw-r--r--tests/storage/test_client_ips.py49
-rw-r--r--tests/util/test_logcontext.py24
-rw-r--r--tests/util/test_snapshot_cache.py63
20 files changed, 140 insertions, 238 deletions
diff --git a/changelog.d/6499.bugfix b/changelog.d/6499.bugfix
new file mode 100644
index 0000000000..299feba0f8
--- /dev/null
+++ b/changelog.d/6499.bugfix
@@ -0,0 +1 @@
+Fix support for SQLite 3.7.
diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc
new file mode 100644
index 0000000000..255f45a9c3
--- /dev/null
+++ b/changelog.d/6501.misc
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.
diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal
new file mode 100644
index 0000000000..0b72261d58
--- /dev/null
+++ b/changelog.d/6502.removal
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.
diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` to work with async/await.
diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc
new file mode 100644
index 0000000000..99d7a70bcf
--- /dev/null
+++ b/changelog.d/6506.misc
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.
diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix
new file mode 100644
index 0000000000..6dc1985c24
--- /dev/null
+++ b/changelog.d/6514.bugfix
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..73e1dda6a3 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -41,7 +39,7 @@ from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.utils import log_function
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -331,10 +329,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self.get_events_from_store_or_dest(
+            destination, room_id, desired_events
         )
 
+        failed_to_fetch = desired_events - event_map.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state/auth events for %s: %s",
@@ -342,8 +342,6 @@ class FederationClient(FederationBase):
                 failed_to_fetch,
             )
 
-        event_map = {ev.event_id: ev for ev in fetched_events}
-
         pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
         auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
 
@@ -358,23 +356,18 @@ class FederationClient(FederationBase):
         Args:
             destination (str)
             room_id (str)
-            event_ids (list)
+            event_ids (Iterable[str])
 
         Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
         """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
 
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
+        missing_events = set(event_ids) - fetched_events.keys()
 
         if not missing_events:
-            return signed_events, failed_to_fetch
+            return fetched_events
 
         logger.debug(
             "Fetching unknown state/auth events %s for room %s",
@@ -384,11 +377,8 @@ class FederationClient(FederationBase):
 
         room_version = yield self.store.get_room_version(room_id)
 
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
             deferreds = [
                 run_in_background(
                     self.get_pdu,
@@ -404,13 +394,9 @@ class FederationClient(FederationBase):
             )
             for success, result in res:
                 if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
+                    fetched_events[result.event_id] = result
 
-        return signed_events, failed_to_fetch
+        return fetched_events
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..73c110a92b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes, it it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index da1529f6ea..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1039,20 +1039,25 @@ class EventsStore(
             },
         )
 
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
+    async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
         By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
         """
 
         if self.hs.config.redaction_retention_period is None:
             return
 
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
         before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
 
         # We fetch all redactions that:
@@ -1074,15 +1079,15 @@ class EventsStore(
             LIMIT ?
         """
 
-        rows = yield self.db.execute(
+        rows = await self.db.execute(
             "_censor_redactions_fetch", None, sql, before_ts, 100
         )
 
         updates = []
 
         for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
                 event_id, allow_rejected=True, allow_none=True
             )
 
@@ -1115,7 +1120,7 @@ class EventsStore(
                     updatevalues={"have_censored": True},
                 )
 
-        yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
 
     def _censor_event_txn(self, txn, event_id, pruned_json):
         """Censor an event by replacing its JSON in the event_json table with the
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index efee17b929..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -90,6 +90,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "event_store_labels", self._event_store_labels
         )
 
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
  */
 
 ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
-    """Cache for snapshots like the response of /initialSync.
-    The response of initialSync only has to be a recent snapshot of the
-    server state. It shouldn't matter to clients if it is a few minutes out
-    of date.
-
-    This caches a deferred response. Until the deferred completes it will be
-    returned from the cache. This means that if the client retries the request
-    while the response is still being computed, that original response will be
-    used rather than trying to compute a new response.
-
-    Once the deferred completes it will removed from the cache after 5 minutes.
-    We delay removing it from the cache because a client retrying its request
-    could race with us finishing computing the response.
-
-    Rather than tracking precisely how long something has been in the cache we
-    keep two generations of completed responses. Every 5 minutes discard the
-    old generation, move the new generation to the old generation, and set the
-    new generation to be empty. This means that a result will be in the cache
-    somewhere between 5 and 10 minutes.
-    """
-
-    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.
-
-    def __init__(self):
-        self.pending_result_cache = {}  # Request that haven't finished yet.
-        self.prev_result_cache = {}  # The older requests that have finished.
-        self.next_result_cache = {}  # The newer requests that have finished.
-        self.time_last_rotated_ms = 0
-
-    def rotate(self, time_now_ms):
-        # Rotate once if the cache duration has passed since the last rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms += self.DURATION_MS
-
-        # Rotate again if the cache duration has passed twice since the last
-        # rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms = time_now_ms
-
-    def get(self, time_now_ms, key):
-        self.rotate(time_now_ms)
-        # This cache is intended to deduplicate requests, so we expect it to be
-        # missed most of the time. So we just lookup the key in all of the
-        # dictionaries rather than trying to short circuit the lookup if the
-        # key is found.
-        result = self.prev_result_cache.get(key)
-        result = self.next_result_cache.get(key, result)
-        result = self.pending_result_cache.get(key, result)
-        if result is not None:
-            return result.observe()
-        else:
-            return None
-
-    def set(self, time_now_ms, key, deferred):
-        self.rotate(time_now_ms)
-
-        result = ObservableDeferred(deferred)
-
-        self.pending_result_cache[key] = result
-
-        def shuffle_along(r):
-            # When the deferred completes we shuffle it along to the first
-            # generation of the result cache. So that it will eventually
-            # expire from the rotation of that cache.
-            self.next_result_cache[key] = result
-            self.pending_result_cache.pop(key, None)
-            return r
-
-        result.addBoth(shuffle_along)
-
-        return result.observe()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+        # an async function which retuns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = SnapshotCache()
-        self.cache.DURATION_MS = 1
-
-    def test_get_set(self):
-        # Check that getting a missing key returns None
-        self.assertEquals(self.cache.get(0, "key"), None)
-
-        # Check that setting a key with a deferred returns
-        # a deferred that resolves when the initial deferred does
-        d = Deferred()
-        set_result = self.cache.set(0, "key", d)
-        self.assertIsNotNone(set_result)
-        self.assertFalse(set_result.called)
-
-        # Check that getting the key before the deferred has resolved
-        # returns a deferred that resolves when the initial deferred does.
-        get_result_at_10 = self.cache.get(10, "key")
-        self.assertIsNotNone(get_result_at_10)
-        self.assertFalse(get_result_at_10.called)
-
-        # Check that the returned deferreds resolve when the initial deferred
-        # does.
-        d.callback("v")
-        self.assertTrue(set_result.called)
-        self.assertTrue(get_result_at_10.called)
-
-        # Check that getting the key after the deferred has resolved
-        # before the cache expires returns a resolved deferred.
-        get_result_at_11 = self.cache.get(11, "key")
-        self.assertIsNotNone(get_result_at_11)
-        if isinstance(get_result_at_11, Deferred):
-            # The cache may return the actual result rather than a deferred
-            self.assertTrue(get_result_at_11.called)
-
-        # Check that getting the key after the deferred has resolved
-        # after the cache expires returns None
-        get_result_at_12 = self.cache.get(12, "key")
-        self.assertIsNone(get_result_at_12)