summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/devices.py53
-rw-r--r--synapse/storage/end_to_end_keys.py27
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/storage/push_rule.py7
-rw-r--r--synapse/storage/roommember.py7
5 files changed, 70 insertions, 26 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ec68e39f1e..cc3cdf2ebc 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -248,17 +248,31 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        self._simple_upsert_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "content": json.dumps(content),
-            }
-        )
+        if content.get("deleted"):
+            self._simple_delete_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+            )
+
+            txn.call_after(
+                self.device_id_exists_cache.invalidate, (user_id, device_id,)
+            )
+        else:
+            self._simple_upsert_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "content": json.dumps(content),
+                }
+            )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -366,7 +380,7 @@ class DeviceStore(SQLBaseStore):
             now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
         )
 
         prev_sent_id_sql = """
@@ -393,12 +407,15 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = json.loads(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = json.loads(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
 
                 results.append(result)
 
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7ae5c65482..523b4360c3 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -64,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_e2e_device_keys(self, query_list, include_all_devices=False):
+    def get_e2e_device_keys(
+        self, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
+            include_deleted_devices (bool): whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -79,7 +85,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices,
+            query_list, include_all_devices, include_deleted_devices,
         )
 
         for user_id, device_keys in iteritems(results):
@@ -88,10 +94,19 @@ class EndToEndKeyStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         query_clauses = []
         query_params = []
 
+        if include_all_devices is False:
+            include_deleted_devices = False
+
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
@@ -119,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
 
         result = {}
         for row in rows:
+            if include_deleted_devices:
+                deleted_devices.remove((row["user_id"], row["device_id"]))
             result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                result.setdefault(user_id, {})[device_id] = None
+
         return result
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4ff0fdc4ab..bf4f3ee92a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore):
             if ctx.state_group in state_groups_map:
                 continue
 
-            state_groups_map[ctx.state_group] = ctx.current_state_ids
+            state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self)
 
         # We need to map the event_ids to their state groups. First, let's
         # check if the event is one we're persisting, in which case we can
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index d25b39ec02..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -185,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
 
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -194,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, context.current_state_ids, event=event
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
         )
+        defer.returnValue(result)
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9eb5b18144..01697ab2c9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         defer.returnValue(user_who_share_room)
 
+    @defer.inlineCallbacks
     def get_joined_users_from_context(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._get_joined_users_from_context(
-            event.room_id, state_group, context.current_state_ids,
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._get_joined_users_from_context(
+            event.room_id, state_group, current_state_ids,
             event=event,
             context=context,
         )
+        defer.returnValue(result)
 
     def get_joined_users_from_state(self, room_id, state_entry):
         state_group = state_entry.state_group