summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py225
1 files changed, 98 insertions, 127 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index e716dc1437..fd869b934c 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
             table="devices",
             keyvalues={"user_id": user_id},
             retcols=("user_id", "device_id", "display_name"),
-            desc="get_devices_by_user"
+            desc="get_devices_by_user",
         )
 
         defer.returnValue({d["device_id"]: d for d in devices})
@@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore):
             return (now_stream_id, [])
 
         return self.runInteraction(
-            "get_devices_by_remote", self._get_devices_by_remote_txn,
-            destination, from_stream_id, now_stream_id,
+            "get_devices_by_remote",
+            self._get_devices_by_remote_txn,
+            destination,
+            from_stream_id,
+            now_stream_id,
         )
 
-    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
-                                   now_stream_id):
+    def _get_devices_by_remote_txn(
+        self, txn, destination, from_stream_id, now_stream_id
+    ):
         sql = """
             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
             GROUP BY user_id, device_id
             LIMIT 20
         """
-        txn.execute(
-            sql, (destination, from_stream_id, now_stream_id, False)
-        )
+        txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
 
         # maps (user_id, device_id) -> stream_id
         query_map = {(r[0], r[1]): r[2] for r in txn}
@@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore):
             now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+            txn,
+            query_map.keys(),
+            include_all_devices=True,
+            include_deleted_devices=True,
         )
 
         prev_sent_id_sql = """
@@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore):
         """Mark that updates have successfully been sent to the destination.
         """
         return self.runInteraction(
-            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
-            destination, stream_id,
+            "mark_as_sent_devices_by_remote",
+            self._mark_as_sent_devices_by_remote_txn,
+            destination,
+            stream_id,
         )
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
@@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
-        txn.execute(sql, (destination, stream_id,))
+        txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
         sql = """
@@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore):
             SET stream_id = ?
             WHERE destination = ? AND user_id = ?
         """
-        txn.executemany(
-            sql, ((row[1], destination, row[0],) for row in rows if row[2])
-        )
+        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
 
         sql = """
             INSERT INTO device_lists_outbound_last_success
             (destination, user_id, stream_id) VALUES (?, ?, ?)
         """
         txn.executemany(
-            sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
         )
 
         # Delete all sent outbound pokes
@@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore):
             DELETE FROM device_lists_outbound_pokes
             WHERE destination = ? AND stream_id <= ?
         """
-        txn.execute(sql, (destination, stream_id,))
+        txn.execute(sql, (destination, stream_id))
 
     def get_device_stream_token(self):
         return self._device_list_id_gen.get_current_token()
@@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
     def _get_cached_user_device(self, user_id, device_id):
         content = yield self._simple_select_one_onecol(
             table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
+            keyvalues={"user_id": user_id, "device_id": device_id},
             retcol="content",
             desc="_get_cached_user_device",
         )
@@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore):
     def _get_cached_devices_for_user(self, user_id):
         devices = yield self._simple_select_list(
             table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-            },
+            keyvalues={"user_id": user_id},
             retcols=("device_id", "content"),
             desc="_get_cached_devices_for_user",
         )
-        defer.returnValue({
-            device["device_id"]: db_to_json(device["content"])
-            for device in devices
-        })
+        defer.returnValue(
+            {device["device_id"]: db_to_json(device["content"]) for device in devices}
+        )
 
     def get_devices_with_keys_by_user(self, user_id):
         """Get all devices (with any device keys) for a user
@@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         return self.runInteraction(
             "get_devices_with_keys_by_user",
-            self._get_devices_with_keys_by_user_txn, user_id,
+            self._get_devices_with_keys_by_user_txn,
+            user_id,
         )
 
     def _get_devices_with_keys_by_user_txn(self, txn, user_id):
@@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore):
             user_devices = devices[user_id]
             results = []
             for device_id, device in iteritems(user_devices):
-                result = {
-                    "device_id": device_id,
-                }
+                result = {"device_id": device_id}
 
                 key_json = device.get("key_json", None)
                 if key_json:
@@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
         sql = """
             SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
         """
-        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        rows = yield self._execute(
+            "get_user_whose_devices_changed", None, sql, from_key
+        )
         defer.returnValue(set(row[0] for row in rows))
 
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
@@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore):
             GROUP BY user_id, destination
         """
         return self._execute(
-            "get_all_device_list_changes_for_remotes", None,
-            sql, from_key, to_key
+            "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
 
     @cached(max_entries=10000)
@@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
-                list_name="user_ids", inlineCallbacks=True)
+    @cachedList(
+        cached_method_name="get_device_list_last_stream_id_for_remote",
+        list_name="user_ids",
+        inlineCallbacks=True,
+    )
     def get_device_list_last_stream_id_for_remotes(self, user_ids):
         rows = yield self._simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
-            retcols=("user_id", "stream_id",),
+            retcols=("user_id", "stream_id"),
             desc="get_device_list_last_stream_id_for_remotes",
         )
 
         results = {user_id: None for user_id in user_ids}
-        results.update({
-            row["user_id"]: row["stream_id"] for row in rows
-        })
+        results.update({row["user_id"]: row["stream_id"] for row in rows})
 
         defer.returnValue(results)
 
@@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
         self.device_id_exists_cache = Cache(
-            name="device_id_exists",
-            keylen=2,
-            max_entries=10000,
+            name="device_id_exists", keylen=2, max_entries=10000
         )
 
-        self._clock.looping_call(
-            self._prune_old_outbound_device_pokes, 60 * 60 * 1000
-        )
+        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
 
         self.register_background_index_update(
             "device_lists_stream_idx",
@@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         )
 
     @defer.inlineCallbacks
-    def store_device(self, user_id, device_id,
-                     initial_device_display_name):
+    def store_device(self, user_id, device_id, initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
 
         Args:
@@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
                 values={
                     "user_id": user_id,
                     "device_id": device_id,
-                    "display_name": initial_device_display_name
+                    "display_name": initial_device_display_name,
                 },
                 desc="store_device",
                 or_ignore=True,
@@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             self.device_id_exists_cache.prefill(key, True)
             defer.returnValue(inserted)
         except Exception as e:
-            logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
-                         " display_name=%s(%r) failed: %s",
-                         type(device_id).__name__, device_id,
-                         type(user_id).__name__, user_id,
-                         type(initial_device_display_name).__name__,
-                         initial_device_display_name, e)
+            logger.error(
+                "store_device with device_id=%s(%r) user_id=%s(%r)"
+                " display_name=%s(%r) failed: %s",
+                type(device_id).__name__,
+                device_id,
+                type(user_id).__name__,
+                user_id,
+                type(initial_device_display_name).__name__,
+                initial_device_display_name,
+                e,
+            )
             raise StoreError(500, "Problem storing device.")
 
     @defer.inlineCallbacks
@@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         """
         yield self._simple_delete(
             table="device_lists_remote_extremeties",
-            keyvalues={
-                "user_id": user_id,
-            },
+            keyvalues={"user_id": user_id},
             desc="mark_remote_user_device_list_as_unsubscribed",
         )
         self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
 
-    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
-                                              stream_id):
+    def update_remote_device_list_cache_entry(
+        self, user_id, device_id, content, stream_id
+    ):
         """Updates a single device in the cache of a remote user's devicelist.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         return self.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
-            user_id, device_id, content, stream_id,
+            user_id,
+            device_id,
+            content,
+            stream_id,
         )
 
-    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
-                                                   content, stream_id):
+    def _update_remote_device_list_cache_entry_txn(
+        self, txn, user_id, device_id, content, stream_id
+    ):
         if content.get("deleted"):
             self._simple_delete_txn(
                 txn,
                 table="device_lists_remote_cache",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                },
+                keyvalues={"user_id": user_id, "device_id": device_id},
             )
 
-            txn.call_after(
-                self.device_id_exists_cache.invalidate, (user_id, device_id,)
-            )
+            txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
         else:
             self._simple_upsert_txn(
                 txn,
                 table="device_lists_remote_cache",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                },
-                values={
-                    "content": json.dumps(content),
-                },
-
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"content": json.dumps(content)},
                 # we don't need to lock, because we assume we are the only thread
                 # updating this user's devices.
                 lock=False,
             )
 
-        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
@@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
-            keyvalues={
-                "user_id": user_id,
-            },
-            values={
-                "stream_id": stream_id,
-            },
-
+            keyvalues={"user_id": user_id},
+            values={"stream_id": stream_id},
             # again, we can assume we are the only thread updating this user's
             # extremity.
             lock=False,
@@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         return self.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
-            user_id, devices, stream_id,
+            user_id,
+            devices,
+            stream_id,
         )
 
-    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
-                                             stream_id):
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
         self._simple_delete_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-            },
+            txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
 
         self._simple_insert_many_txn(
@@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
                     "content": json.dumps(content),
                 }
                 for content in devices
-            ]
+            ],
         )
 
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
-            keyvalues={
-                "user_id": user_id,
-            },
-            values={
-                "stream_id": stream_id,
-            },
-
+            keyvalues={"user_id": user_id},
+            values={"stream_id": stream_id},
             # we don't need to lock, because we can assume we are the only thread
             # updating this user's extremity.
             lock=False,
@@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         """
         with self._device_list_id_gen.get_next() as stream_id:
             yield self.runInteraction(
-                "add_device_change_to_streams", self._add_device_change_txn,
-                user_id, device_ids, hosts, stream_id,
+                "add_device_change_to_streams",
+                self._add_device_change_txn,
+                user_id,
+                device_ids,
+                hosts,
+                stream_id,
             )
         defer.returnValue(stream_id)
 
@@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         now = self._clock.time_msec()
 
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed,
-            user_id, stream_id,
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
         )
         for host in hosts:
             txn.call_after(
                 self._device_list_federation_stream_cache.entity_has_changed,
-                host, stream_id,
+                host,
+                stream_id,
             )
 
         # Delete older entries in the table, as we really only care about
@@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids]
+            [(user_id, device_id, stream_id) for device_id in device_ids],
         )
 
         self._simple_insert_many_txn(
             txn,
             table="device_lists_stream",
             values=[
-                {
-                    "stream_id": stream_id,
-                    "user_id": user_id,
-                    "device_id": device_id,
-                }
+                {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
                 for device_id in device_ids
-            ]
+            ],
         )
 
         self._simple_insert_many_txn(
@@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
                 }
                 for destination in hosts
                 for device_id in device_ids
-            ]
+            ],
         )
 
     def _prune_old_outbound_device_pokes(self):
@@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             """
 
             txn.executemany(
-                delete_sql,
-                (
-                    (yesterday, row[0], row[1], row[2])
-                    for row in rows
-                )
+                delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
             )
 
             # Since we've deleted unsent deltas, we need to remove the entry
@@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
     def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
             txn = conn.cursor()
-            txn.execute(
-                "DROP INDEX IF EXISTS device_lists_remote_cache_id"
-            )
-            txn.execute(
-                "DROP INDEX IF EXISTS device_lists_remote_extremeties_id"
-            )
+            txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+            txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
             txn.close()
 
         yield self.runWithConnection(f)