diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ed60d494ff..2d44f15da3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -12,18 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
+from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id, RoomStreamToken
-from twisted.internet import defer
-from ._base import BaseHandler
+from synapse.util.retryutils import NotRetryingDestination
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -34,15 +39,17 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
- self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self)
- self.federation.register_edu_handler(
+ federation_registry = hs.get_federation_registry()
+
+ federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
)
- self.federation.register_query_handler(
+ federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@@ -109,7 +116,7 @@ class DeviceHandler(BaseHandler):
user_id, device_id=None
)
- devices = device_map.values()
+ devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
@@ -152,16 +159,15 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_device(user_id, device_id)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@@ -171,12 +177,30 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
+ def delete_all_devices_for_user(self, user_id, except_device_id=None):
+ """Delete all of the user's devices
+
+ Args:
+ user_id (str):
+ except_device_id (str|None): optional device id which should not
+ be deleted
+
+ Returns:
+ defer.Deferred:
+ """
+ device_map = yield self.store.get_devices_by_user(user_id)
+ device_ids = list(device_map)
+ if except_device_id is not None:
+ device_ids = [d for d in device_ids if d != except_device_id]
+ yield self.delete_devices(user_id, device_ids)
+
+ @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
- device_ids (str): The list of device IDs to delete
+ device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:
@@ -184,7 +208,7 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_devices(user_id, device_ids)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
@@ -194,9 +218,8 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self.store.user_delete_access_tokens(
+ yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
- delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -224,7 +247,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
@@ -270,6 +293,8 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
@@ -280,11 +305,30 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ member_events = yield self.store.get_membership_changes_for_user(
+ user_id, from_token.room_key, now_token.room_key
+ )
+ rooms_changed.update(event.room_id for event in member_events)
+
stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key).stream
+ from_token.room_key
+ ).stream
possibly_changed = set(changed)
+ possibly_left = set()
for room_id in rooms_changed:
+ current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+ # The user may have left the room
+ # TODO: Check if they actually did or if we were just invited.
+ if room_id not in room_ids:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_left.add(state_key)
+ continue
+
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
@@ -295,44 +339,69 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
- current_state_ids = yield self.store.get_current_state_ids(room_id)
-
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
+ current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+ if not current_member_id:
+ continue
+
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ # Check if we've joined the room? If so we just blindly add all the users to
+ # the "possibly changed" users.
+ for state_dict in itervalues(prev_state_ids):
+ member_event = state_dict.get((EventTypes.Member, user_id), None)
+ if not member_event or member_event != current_member_id:
+ for key, event_id in iteritems(current_state_ids):
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_changed.add(state_key)
+ break
+
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.values():
+ for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
- possibly_changed.add(state_key)
+ if state_key != user_id:
+ possibly_changed.add(state_key)
break
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
+ if possibly_changed or possibly_left:
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
- # Take the intersection of the users whose devices may have changed
- # and those that actually still share a room with the user
- defer.returnValue(users_who_share_room & possibly_changed)
+ # Take the intersection of the users whose devices may have changed
+ # and those that actually still share a room with the user
+ possibly_joined = possibly_changed & users_who_share_room
+ possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+ else:
+ possibly_joined = []
+ possibly_left = []
+
+ defer.returnValue({
+ "changed": list(possibly_joined),
+ "left": list(possibly_left),
+ })
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
@@ -366,7 +435,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler):
self.store = hs.get_datastore()
- self.federation = hs.get_replication_layer()
+ self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.device_handler = device_handler
@@ -450,6 +519,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
@@ -467,7 +539,7 @@ class DeviceListEduUpdater(object):
yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
- # change (becuase of the single prev_id matching the current cache)
+ # change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
|