diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 79a58df591..52dc838dd5 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
@@ -28,7 +30,12 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import Cache, SQLBaseStore, db_to_json
+from synapse.storage._base import (
+ Cache,
+ SQLBaseStore,
+ db_to_json,
+ make_in_list_sql_clause,
+)
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -42,7 +49,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
- """Retrieve a device.
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -54,14 +62,15 @@ class DeviceWorkerStore(SQLBaseStore):
"""
return self._simple_select_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
+ """Retrieve all of a user's registered devices. Only returns devices
+ that are not marked as hidden.
Args:
user_id (str):
@@ -72,7 +81,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
devices = yield self._simple_select_list(
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
)
@@ -448,11 +457,14 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
- AND user_id IN (%s)
+ AND
"""
for chunk in batch_iter(to_check, 100):
- txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
changes.update(user_id for user_id, in txn)
return changes
@@ -512,17 +524,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
-class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
+class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
-
- # Map of (user_id, device_id) -> bool. If there is an entry that implies
- # the device exists.
- self.device_id_exists_cache = Cache(
- name="device_id_exists", keylen=2, max_entries=10000
- )
-
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+ super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_lists_stream_idx",
@@ -556,6 +560,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
+ def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ def f(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+ txn.close()
+
+ yield self.runWithConnection(f)
+ yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
+ return 1
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(DeviceStore, self).__init__(db_conn, hs)
+
+ # Map of (user_id, device_id) -> bool. If there is an entry that implies
+ # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists", keylen=2, max_entries=10000
+ )
+
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@@ -567,6 +596,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
+ Raises:
+ StoreError: if the device is already in use
"""
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
@@ -579,12 +610,25 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name,
+ "hidden": False,
},
desc="store_device",
or_ignore=True,
)
+ if not inserted:
+ # if the device already exists, check if it's a real device, or
+ # if the device ID is reserved by something else
+ hidden = yield self._simple_select_one_onecol(
+ "devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="hidden",
+ )
+ if hidden:
+ raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
self.device_id_exists_cache.prefill(key, True)
return inserted
+ except StoreError:
+ raise
except Exception as e:
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -611,7 +655,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
yield self._simple_delete_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
)
@@ -631,14 +675,15 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
table="devices",
column="device_id",
iterable=device_ids,
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
desc="delete_devices",
)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
def update_device(self, user_id, device_id, new_display_name=None):
- """Update a device.
+ """Update a device. Only updates the device if it is not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -657,7 +702,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return defer.succeed(None)
return self._simple_update_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
desc="update_device",
)
@@ -910,15 +955,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes",
_prune_txn,
)
-
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
- txn.close()
-
- yield self.runWithConnection(f)
- yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- return 1
|