diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 17920d4480..b594f501f9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import ujson as json
from twisted.internet import defer
@@ -33,17 +34,13 @@ class DeviceStore(SQLBaseStore):
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
- device
- ignore_if_known (bool): ignore integrity errors which mean the
- device is already known
+ device. Ignored if device exists.
Returns:
- defer.Deferred
- Raises:
- StoreError: if ignore_if_known is False and the device was already
- known
+ defer.Deferred: boolean whether the device was inserted or an
+ existing device existed with that ID.
"""
try:
- yield self._simple_insert(
+ inserted = yield self._simple_insert(
"devices",
values={
"user_id": user_id,
@@ -51,8 +48,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name
},
desc="store_device",
- or_ignore=ignore_if_known,
+ or_ignore=True,
)
+ defer.returnValue(inserted)
except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s",
@@ -139,3 +137,156 @@ class DeviceStore(SQLBaseStore):
)
defer.returnValue({d["device_id"]: d for d in devices})
+
+ def get_devices_by_remote(self, destination, from_stream_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ has_changed = self._device_list_stream_cache.has_entity_changed(
+ destination, int(from_stream_id)
+ )
+ if not has_changed:
+ defer.returnValue((now_stream_id, []))
+
+ return self.runInteraction(
+ "get_devices_by_remote", self._get_devices_by_remote_txn,
+ destination, from_stream_id, now_stream_id,
+ )
+
+ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+ now_stream_id):
+ sql = """
+ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id > ? AND stream_id <= ? AND sent = ?
+ GROUP BY user_id, device_id
+ """
+ txn.execute(
+ sql, (destination, from_stream_id, now_stream_id, False)
+ )
+ rows = txn.fetchall()
+
+ if not rows:
+ return now_stream_id, []
+
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in rows}
+ devices = self._get_e2e_device_keys_txn(
+ txn, query_map.keys(), include_all_devices=True
+ )
+
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_pokes
+ WHERE destination = ? AND user_id = ? AND sent = ?
+ """
+
+ results = []
+ for user_id, user_devices in devices.iteritems():
+ txn.execute(prev_sent_id_sql, (destination, user_id, True))
+ rows = txn.fetchall()
+ prev_id = rows[0][0]
+ for device_id, result in user_devices.iteritems():
+ stream_id = query_map[(user_id, device_id)]
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": prev_id,
+ "stream_id": stream_id,
+ }
+
+ prev_id = stream_id
+
+ key_json = result.get("key_json", None)
+ if key_json:
+ result["keys"] = json.loads(key_json)
+ device_display_name = result.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.setdefault(user_id, {})[device_id] = result
+
+ return now_stream_id, results
+
+ def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ return self.runInteraction(
+ "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+ destination, stream_id,
+ )
+
+ @defer.inlineCallbacks
+ def get_user_whose_devices_changed(self, from_key):
+ from_key = int(from_key)
+ changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+ if changed is not None:
+ defer.returnValue(set(changed))
+
+ sql = """
+ SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+ """
+ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+ defer.returnValue(set(row["user_id"] for row in rows))
+
+ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id < (
+ SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id <= ?
+ )
+ """
+ txn.execute(sql, (destination, destination, stream_id,))
+
+ sql = """
+ UPDATE device_lists_outbound_pokes SET sent = ?
+ WHERE destination = ? AND stream_id <= ?
+ """
+ txn.execute(sql, (destination, True,))
+
+ @defer.inlineCallbacks
+ def add_device_change_to_streams(self, user_id, device_id, hosts):
+ # device_lists_stream
+ # device_lists_outbound_pokes
+ with self._device_list_id_gen.get_next() as stream_id:
+ yield self.runInteraction(
+ "add_device_change_to_streams", self._add_device_change_txn,
+ user_id, device_id, hosts, stream_id,
+ )
+ defer.returnValue(stream_id)
+
+ def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
+ txn.call_after(
+ self._device_list_stream_cache.entity_has_changed,
+ user_id, stream_id,
+ )
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host, stream_id,
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="device_lists_stream",
+ values={
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="device_lists_outbound_pokes",
+ values=[
+ {
+ "destination": destination,
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ "sent": False,
+ }
+ for destination in hosts
+ ]
+ )
+
+ def get_device_stream_token(self):
+ return self._device_list_id_gen.get_current_token()
|