diff --git a/changelog.d/14716.misc b/changelog.d/14716.misc
new file mode 100644
index 0000000000..ef9522e01d
--- /dev/null
+++ b/changelog.d/14716.misc
@@ -0,0 +1 @@
+Batch up replication requests to request the resyncing of remote users' devices.
\ No newline at end of file
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d4750a32e6..89864e1119 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -33,6 +34,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ InvalidAPICallError,
RequestSendFailed,
SynapseError,
)
@@ -45,6 +47,7 @@ from synapse.types import (
JsonDict,
StreamKeyType,
StreamToken,
+ UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -893,12 +896,47 @@ class DeviceListWorkerUpdater:
def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
+ ReplicationMultiUserDevicesResyncRestServlet,
ReplicationUserDevicesResyncRestServlet,
)
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
+ self._multi_user_device_resync_client = (
+ ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def multi_user_device_resync(
+ self, user_ids: List[str], mark_failed_as_stale: bool = True
+ ) -> Dict[str, Optional[JsonDict]]:
+ """
+ Like `user_device_resync` but operates on multiple users **from the same origin**
+ at once.
+
+ Returns:
+ A dict from user ID to the same dict that `user_device_resync` returns for that user.
+ """
+ # mark_failed_as_stale is not sent over replication, so assert that callers
+ # only ever pass the default (True) rather than having it silently ignored.
+ assert mark_failed_as_stale
+
+ if not user_ids:
+ # Shortcut empty requests
+ return {}
+
+ try:
+ return await self._multi_user_device_resync_client(user_ids=user_ids)
+ except SynapseError as err:
+ if not (
+ err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
+ ):
+ raise
+
+ # Fall back to single requests
+ result: Dict[str, Optional[JsonDict]] = {}
+ for user_id in user_ids:
+ result[user_id] = await self._user_device_resync_client(user_id=user_id)
+ return result
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
@@ -913,8 +951,10 @@ class DeviceListWorkerUpdater:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ or None if we were unable to fetch the device info for some reason,
+ e.g. due to a connection problem.
"""
- return await self._user_device_resync_client(user_id=user_id)
+ return (await self.multi_user_device_resync([user_id]))[user_id]
class DeviceListUpdater(DeviceListWorkerUpdater):
@@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
# Allow future calls to retry resyncing out of sync device lists.
self._resync_retry_in_progress = False
+ async def multi_user_device_resync(
+ self, user_ids: List[str], mark_failed_as_stale: bool = True
+ ) -> Dict[str, Optional[JsonDict]]:
+ """
+ Like `user_device_resync` but operates on multiple users **from the same origin**
+ at once.
+
+ Returns:
+ A dict from user ID to the same dict that `user_device_resync` returns for that user.
+ """
+ if not user_ids:
+ return {}
+
+ origins = {UserID.from_string(user_id).domain for user_id in user_ids}
+
+ if len(origins) != 1:
+ raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
+
+ result = {}
+ failed = set()
+ # TODO(Perf): Actually batch these up
+ for user_id in user_ids:
+ user_result, user_failed = await self._user_device_resync_returning_failed(
+ user_id
+ )
+ result[user_id] = user_result
+ if user_failed:
+ failed.add(user_id)
+
+ if mark_failed_as_stale:
+ await self.store.mark_remote_users_device_caches_as_stale(failed)
+
+ return result
+
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
+ result, failed = await self._user_device_resync_returning_failed(user_id)
+
+ if failed and mark_failed_as_stale:
+ # Mark the remote user's device list as stale so we know we need to retry
+ # it later.
+ await self.store.mark_remote_users_device_caches_as_stale((user_id,))
+
+ return result
+
+ async def _user_device_resync_returning_failed(
+ self, user_id: str
+ ) -> Tuple[Optional[JsonDict], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
- mark_failed_as_stale: Whether to mark the user's device list as stale
- if the attempt to resync failed.
Returns:
- A dict with device info as under the "devices" in the result of this
- request:
- https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ - A dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ or None if we were unable to fetch the device info for some reason,
+ e.g. due to a connection problem.
+ - True iff the resync failed and the device list should be marked as stale.
"""
logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
@@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
try:
result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
- return None
+ return None, True
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s",
@@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
e,
)
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
- return None
+ return None, True
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
- return None
+ return None, False
except Exception as e:
set_tag("error", True)
log_kv(
@@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
logger.exception("Failed to handle device list update for %s", user_id)
- if mark_failed_as_stale:
- # Mark the remote user's device list as stale so we know we need to retry
- # it later.
- await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
- return None
+ return None, True
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
@@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
# point.
self._seen_updates[user_id] = {stream_id}
- return result
+ return result, False
async def process_cross_signing_key_update(
self,
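
Since `DeviceListUpdater.multi_user_device_resync` above raises `InvalidAPICallError` when the supplied user IDs span more than one remote server, any caller holding a mixed batch has to partition it by origin first. The helper below is a minimal illustrative sketch, not part of this change; `group_users_by_origin` is a hypothetical name, and the splitting simply mirrors what `get_domain_from_id` does.

from collections import defaultdict
from typing import Dict, List


def group_users_by_origin(user_ids: List[str]) -> Dict[str, List[str]]:
    """Partition Matrix user IDs of the form '@localpart:domain' by domain."""
    grouped: Dict[str, List[str]] = defaultdict(list)
    for user_id in user_ids:
        # Everything after the first colon is the origin server name.
        grouped[user_id.split(":", 1)[1]].append(user_id)
    return dict(grouped)


# Hypothetical usage, one batched resync per remote server:
#     for origin_users in group_users_by_origin(stale_users).values():
#         await device_list_updater.multi_user_device_resync(origin_users)
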
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 75e89850f5..00c403db49 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -195,7 +195,7 @@ class DeviceMessageHandler:
sender_user_id,
unknown_devices,
)
- await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+ await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
# Immediately attempt a resync in the background
run_in_background(self._user_device_resync, user_id=sender_user_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5fe102e2f2..d2188ca08f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,8 +36,8 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
-from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util import json_decoder
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
@@ -238,24 +238,28 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
- await make_deferred_yieldable(
- delay_cancellation(
- defer.gatherResults(
- [
- run_in_background(
- self._query_devices_for_destination,
- results,
- cross_signing_keys,
- failures,
- destination,
- queries,
- timeout,
- )
- for destination, queries in remote_queries_not_in_cache.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ logger.debug(
+ "%d destinations to query devices for", len(remote_queries_not_in_cache)
+ )
+
+ async def _query(
+ destination_queries: Tuple[str, Dict[str, Iterable[str]]]
+ ) -> None:
+ destination, queries = destination_queries
+ return await self._query_devices_for_destination(
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
)
+
+ await concurrently_execute(
+ _query,
+ remote_queries_not_in_cache.items(),
+ 10,
+ delay_cancellation=True,
)
ret = {"device_keys": results, "failures": failures}
@@ -300,28 +304,41 @@ class E2eKeysHandler:
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
- if device_list:
- continue
+ # Perform a user device resync for each user only once, and only if:
+ # - they have an empty device_list
+ # - they are in some rooms that this server can see
+ users_to_resync_devices = {
+ user_id
+ for (user_id, device_list) in destination_query.items()
+ if (not device_list) and (await self.store.get_rooms_for_user(user_id))
+ }
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
+ logger.debug(
+ "%d users to resync devices for from destination %s",
+ len(users_to_resync_devices),
+ destination,
+ )
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- resync_results = (
- await self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
+ try:
+ user_resync_results = (
+ await self.device_handler.device_list_updater.multi_user_device_resync(
+ list(users_to_resync_devices)
)
+ )
+ for user_id in users_to_resync_devices:
+ resync_results = user_resync_results[user_id]
+
if resync_results is None:
- raise ValueError("Device resync failed")
+ # TODO: It's weird that we'll store a failure against a
+ # destination, yet continue processing users from that
+ # destination.
+ # We might want to consider changing this, but for now
+ # I'm leaving it as I found it.
+ failures[destination] = _exception_to_failure(
+ ValueError(f"Device resync failed for {user_id!r}")
+ )
+ continue
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -339,8 +356,8 @@ class E2eKeysHandler:
if self_signing_key:
cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 31df7f55cc..6df000faaf 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1423,7 +1423,7 @@ class FederationEventHandler:
"""
try:
- await self._store.mark_remote_user_device_cache_as_stale(sender)
+ await self._store.mark_remote_users_device_caches_as_stale((sender,))
# Immediately attempt a resync in the background
if self._config.worker.worker_app:
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 7c4941c3d3..ea5c08e6cf 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,12 +13,13 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, user_devices
+class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for multiple users from the same
+ remote server by contacting their server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/multi_user_device_resync
+
+ {
+ "user_ids": ["@alice:example.org", "@bob:example.org", ...]
+ }
+
+ Response is roughly equivalent to `/_matrix/federation/v1/user/devices/:user_id`,
+ but the body is a map from user ID to the per-user response, e.g.:
+
+ {
+ "@alice:example.org": {
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ },
+ ...
+ }
+ """
+
+ NAME = "multi_user_device_resync"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ from synapse.handlers.device import DeviceHandler
+
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_list_updater = handler.device_list_updater
+
+ self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override]
+ return {"user_ids": user_ids}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
+ content = parse_json_object_from_request(request)
+ user_ids: List[str] = content["user_ids"]
+
+ logger.info("Resync for %r", user_ids)
+ span = active_span()
+ if span:
+ span.set_tag("user_ids", f"{user_ids!r}")
+
+ multi_user_devices = await self.device_list_updater.multi_user_device_resync(
+ user_ids
+ )
+
+ return 200, multi_user_devices
+
+
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"""Ask master to upload keys for the user and send them out over federation to
update other servers.
@@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
+ ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
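
For illustration, a worker-side call to the new endpoint might look like the sketch below. `make_client`, the payload shape, and the URL come from the servlet added above; the helper name, its untyped `hs` argument (a `HomeServer`), and the filtering at the end are assumptions, not part of this change.

from synapse.replication.http.devices import (
    ReplicationMultiUserDevicesResyncRestServlet,
)


async def resync_remote_users(hs, user_ids):
    """Hypothetical worker-side helper; `hs` is the HomeServer."""
    client = ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
    # _serialize_payload turns the keyword argument into {"user_ids": [...]}
    # and POSTs it to /_synapse/replication/multi_user_device_resync on master.
    devices_by_user = await client(user_ids=user_ids)
    # Each value is the per-user resync result, or None if that resync failed.
    return {u: d for u, d in devices_by_user.items() if d is not None}
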
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index db877e3f13..b067664473 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
-from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
+from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
@@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {row["user_id"] for row in rows}
- async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
+ async def mark_remote_users_device_caches_as_stale(
+ self, user_ids: StrCollection
+ ) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- await self.db_pool.simple_upsert(
- table="device_lists_remote_resync",
- keyvalues={"user_id": user_id},
- values={},
- insertion_values={"added_ts": self._clock.time_msec()},
- desc="mark_remote_user_device_cache_as_stale",
+
+ def _mark_remote_users_device_caches_as_stale_txn(
+ txn: LoggingTransaction,
+ ) -> None:
+ # TODO add insertion_values support to simple_upsert_many and use
+ # that!
+ for user_id in user_ids:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ )
+
+ await self.db_pool.runInteraction(
+ "mark_remote_users_device_caches_as_stale",
+ _mark_remote_users_device_caches_as_stale_txn,
)
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
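
The TODO in the transaction above points at the missing `insertion_values` support in `simple_upsert_many`. For illustration only, a hand-rolled multi-row upsert over the same table could look roughly like this sketch; it uses plain `sqlite3` with `ON CONFLICT DO NOTHING` rather than Synapse's database layer (PostgreSQL uses `%s` placeholders instead of `?`), and nothing here is added by this PR.

import sqlite3
import time
from typing import List


def mark_stale_batched(conn: sqlite3.Connection, user_ids: List[str]) -> None:
    if not user_ids:
        return
    now_ms = int(time.time() * 1000)
    values_clause = ", ".join(["(?, ?)"] * len(user_ids))
    params: List[object] = []
    for user_id in user_ids:
        params.extend((user_id, now_ms))
    # One multi-row statement; rows that already exist keep their original added_ts.
    conn.execute(
        "INSERT INTO device_lists_remote_resync (user_id, added_ts) "
        f"VALUES {values_clause} ON CONFLICT (user_id) DO NOTHING",
        params,
    )
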
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index f2d436ddc3..0c725eb967 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
+# Collection[str] that does not include str itself; str being a Sequence[str]
+# is very misleading and results in bugs.
+StrCollection = Union[Tuple[str, ...], List[str], Set[str]]
+
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
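
The comment above is the whole motivation for `StrCollection`: `str` itself satisfies `Collection[str]`, so a single user ID passed where a collection is expected silently iterates character by character. A small self-contained illustration (not part of this change):

from typing import Collection


def mark_stale(user_ids: Collection[str]) -> None:
    for user_id in user_ids:
        print("marking", user_id)


mark_stale("@alice:example.org")     # type-checks, but marks '@', 'a', 'l', ... one by one
mark_stale(("@alice:example.org",))  # intended: marks exactly one user ID
# Annotating user_ids as StrCollection instead makes mypy reject the first call.
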
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index d24c4f68c4..01e3cd46f6 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -205,7 +205,10 @@ T = TypeVar("T")
async def concurrently_execute(
- func: Callable[[T], Any], args: Iterable[T], limit: int
+ func: Callable[[T], Any],
+ args: Iterable[T],
+ limit: int,
+ delay_cancellation: bool = False,
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@@ -215,6 +218,8 @@ async def concurrently_execute(
args: List of arguments to pass to func, each invocation of func
gets a single argument.
limit: Maximum number of concurrent executions.
+ delay_cancellation: Whether to delay cancellation until after the invocations
+ have finished.
Returns:
None, when all function invocations have finished. The return values
@@ -233,9 +238,16 @@ async def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
- await yieldable_gather_results(
- _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
- )
+ if delay_cancellation:
+ await yieldable_gather_results_delaying_cancellation(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
+ else:
+ await yieldable_gather_results(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
P = ParamSpec("P")
@@ -292,6 +304,41 @@ async def yieldable_gather_results(
raise dfe.subFailure.value from None
+async def yieldable_gather_results_delaying_cancellation(
+ func: Callable[Concatenate[T, P], Awaitable[R]],
+ iter: Iterable[T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+) -> List[R]:
+ """Executes the function with each argument concurrently.
+ Cancellation is delayed until after all the results have been gathered.
+
+ See `yieldable_gather_results`.
+
+ Args:
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
+ argument to the function
+ *args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func
+
+ Returns:
+ A list containing the results of the function
+ """
+ try:
+ return await make_deferred_yieldable(
+ delay_cancellation(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
+ consumeErrors=True,
+ )
+ )
+ )
+ except defer.FirstError as dfe:
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
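
As a usage sketch of the new `delay_cancellation` flag, mirroring how the e2e keys handler now drives its per-destination queries: everything below other than `concurrently_execute` is an assumed caller, and the query body is a placeholder.

from typing import Iterable

from synapse.util.async_helpers import concurrently_execute


async def query_all_destinations(destinations: Iterable[str]) -> None:
    async def _query_one(destination: str) -> None:
        # Placeholder for the real per-destination federation query.
        ...

    # Run at most 10 queries at once; with delay_cancellation=True, a
    # cancellation of the caller is held back until the in-flight queries
    # have finished, so their results still land in the caches.
    await concurrently_execute(
        _query_one,
        destinations,
        10,
        delay_cancellation=True,
    )
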