diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index aa68755936..d92780b642 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
from synapse.api import errors
from synapse.util import stringutils
+from synapse.types import get_domain_from_id
from twisted.internet import defer
from ._base import BaseHandler
@@ -27,6 +28,8 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
+ self.state = hs.get_state_handler()
+
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
@@ -45,29 +48,29 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
- yield self.store.store_device(
+ new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
- ignore_if_known=True,
)
+ if new_device:
+ yield self.notify_device_update(user_id, device_id)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
- try:
- device_id = stringutils.random_string(10).upper()
- yield self.store.store_device(
- user_id=user_id,
- device_id=device_id,
- initial_device_display_name=initial_device_display_name,
- ignore_if_known=False,
- )
+ device_id = stringutils.random_string(10).upper()
+ new_device = yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ )
+ if new_device:
+ yield self.notify_device_update(user_id, device_id)
defer.returnValue(device_id)
- except errors.StoreError:
- attempts += 1
+ attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -147,6 +150,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
+ yield self.notify_device_update(user_id, device_id)
+
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@@ -166,12 +171,48 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
+ yield self.notify_device_update(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
+ @defer.inlineCallbacks
+ def notify_device_update(self, user_id, device_id):
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ room_ids = [r.room_id for r in rooms]
+
+ hosts = set()
+ for room_id in room_ids:
+ users = yield self.state.get_current_user_in_room(room_id)
+ hosts.update(get_domain_from_id(u) for u in users)
+ hosts.discard(self.server_name)
+
+ position = yield self.store.add_device_change_to_streams(
+ user_id, device_id, list(hosts)
+ )
+
+ yield self.notifier.on_new_event(
+ "device_list_key", position, rooms=room_ids,
+ )
+
+ for host in hosts:
+ self.federation.send_device_messages(host)
+
+ @defer.inlineCallbacks
+ def get_device_list_changes(self, user_id, room_ids, from_key):
+ room_ids = frozenset(room_ids)
+
+ user_ids_changed = set()
+ changed = yield self.store.get_user_whose_devices_changed(from_key)
+ for other_user_id in changed:
+ other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+ if room_ids.intersection(e.room_id for e in other_rooms):
+ user_ids_changed.add(other_user_id)
+
+ defer.returnValue(user_ids_changed)
+
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|