diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f9cc5bddbc..d4750a32e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
+ device_list_updater: "DeviceListWorkerUpdater"
+
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
+ self.device_list_updater = DeviceListWorkerUpdater(hs)
+
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
- ) -> Collection[str]:
+ ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
class DeviceHandler(DeviceWorkerHandler):
+ device_list_updater: "DeviceListUpdater"
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id
- async def get_dehydrated_device(
- self, user_id: str
- ) -> Optional[Tuple[str, JsonDict]]:
- """Retrieve the information for a dehydrated device.
-
- Args:
- user_id: the user whose dehydrated device we are looking for
- Returns:
- a tuple whose first item is the device ID, and the second item is
- the dehydrated device information
- """
- return await self.store.get_dehydrated_device(user_id)
-
async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -682,13 +688,33 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to: Set[str] = set()
try:
+ stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
while True:
self._handle_new_device_update_new_data = False
- rows = await self.store.get_uncoverted_outbound_room_pokes()
+ max_stream_id = self.store.get_device_stream_token()
+ rows = await self.store.get_uncoverted_outbound_room_pokes(
+ stream_id, room_id
+ )
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
+
+ # Advance `(stream_id, room_id)`.
+ # `max_stream_id` comes from *before* the query for unconverted
+ # rows, which means that any unconverted rows must have a larger
+ # stream ID.
+ if max_stream_id > stream_id:
+ stream_id, room_id = max_stream_id, ""
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+ else:
+ assert max_stream_id == stream_id
+ # Avoid moving `room_id` backwards.
+ pass
+
if self._handle_new_device_update_new_data:
continue
else:
@@ -718,7 +744,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
@@ -752,6 +777,12 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
+ # Advance `(stream_id, room_id)`.
+ _, _, room_id, stream_id, _ = rows[-1]
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+
finally:
self._handle_new_device_update_is_processing = False
@@ -834,7 +865,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)
@@ -858,7 +888,36 @@ def _update_device_from_client_ips(
)
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+ "Handles incoming device list updates from federation and contacts the main process over replication"
+
+ def __init__(self, hs: "HomeServer"):
+ from synapse.replication.http.devices import (
+ ReplicationUserDevicesResyncRestServlet,
+ )
+
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[JsonDict]:
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
+ if the attempt to resync failed.
+ Returns:
+ A dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ """
+ return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -937,7 +996,10 @@ class DeviceListUpdater:
# Check if we are partially joining any rooms. If so we need to store
# all device list updates so that we can handle them correctly once we
# know who is in the room.
- partial_rooms = await self.store.get_partial_state_rooms_and_servers()
+ # TODO(faster_joins): this fetches and processes a bunch of data that we don't
+ # use. Could be replaced by a tighter query e.g.
+ # SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+ partial_rooms = await self.store.get_partial_state_room_resync_info()
if partial_rooms:
await self.store.add_remote_device_list_to_pending(
user_id,
|