diff --git a/changelog.d/18180.feature b/changelog.d/18180.feature
new file mode 100644
index 0000000000..fbf226e51c
--- /dev/null
+++ b/changelog.d/18180.feature
@@ -0,0 +1 @@
+Add `msc4263_limit_key_queries_to_users_who_share_rooms` config option as per [MSC4263](https://github.com/matrix-org/matrix-spec-proposals/pull/4263).
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 1226eaa58a..2dc75a778e 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -560,3 +560,9 @@ class ExperimentalConfig(Config):
# MSC4076: Add `disable_badge_count`` to pusher configuration
self.msc4076_enabled: bool = experimental.get("msc4076_enabled", False)
+
+ # MSC4263: Preventing MXID enumeration via key queries
+ self.msc4263_limit_key_queries_to_users_who_share_rooms = experimental.get(
+ "msc4263_limit_key_queries_to_users_who_share_rooms",
+ False,
+ )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f2b2e30bf4..6171aaf29f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -158,7 +158,37 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
+
+ async def filter_device_key_query(
+ query: Dict[str, List[str]],
+ ) -> Dict[str, List[str]]:
+ if not self.config.experimental.msc4263_limit_key_queries_to_users_who_share_rooms:
+ # Only ignore invalid user IDs, which is the same behaviour as if
+ # the user existed but had no keys.
+ return {
+ user_id: v
+ for user_id, v in query.items()
+ if UserID.is_valid(user_id)
+ }
+
+ # Strip invalid user IDs and user IDs the requesting user does not share rooms with.
+ valid_user_ids = [
+ user_id for user_id in query.keys() if UserID.is_valid(user_id)
+ ]
+ allowed_user_ids = set(
+ await self.store.do_users_share_a_room_joined_or_invited(
+ from_user_id, valid_user_ids
+ )
+ )
+ return {
+ user_id: v
+ for user_id, v in query.items()
+ if user_id in allowed_user_ids
+ }
+
+ device_keys_query: Dict[str, List[str]] = await filter_device_key_query(
+ query_body.get("device_keys", {})
+ )
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -166,11 +196,6 @@ class E2eKeysHandler:
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if not UserID.is_valid(user_id):
- # Ignore invalid user IDs, which is the same behaviour as if
- # the user existed but had no keys.
- continue
-
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2084776543..7ca73abb83 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -871,6 +871,73 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
return {u for u, share_room in user_dict.items() if share_room}
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room_joined_or_invited(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room_joined_or_invited",
+ list_name="other_user_ids",
+ )
+ async def _do_users_share_a_room_joined_or_invited(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user via being either joined or invited.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_joined_or_invited_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+ # user and the set of other users, and then checking if there is any
+ # overlap.
+ sql = f"""
+ SELECT DISTINCT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND {clause}
+ ) AS b using (room_id)
+ """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for (u,) in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room_joined_or_invited",
+ do_users_share_a_room_joined_or_invited_txn,
+ batch_user_ids,
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room_joined_or_invited(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Set[str]:
+ """Return the set of users who share a room with the first users via being either joined or invited"""
+
+ user_dict = await self._do_users_share_a_room_joined_or_invited(
+ user_id, other_user_ids
+ )
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e67efcc17f..70fc4263e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -1896,3 +1896,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(
remaining_key_ids, {"AAAAAAAAAA", "BAAAAA", "BAAAAB", "BAAAAAAAAA"}
)
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_not_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we don't share a room
+ with returns nothing.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Do *not* pretend we're sharing a room with the user we're querying.
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(
+ query_result,
+ {
+ "device_keys": {},
+ "failures": {},
+ "master_keys": {},
+ "self_signing_keys": {},
+ "user_signing_keys": {},
+ },
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we share a room
+ with returns the cross signing keys correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `query_devices` will filter out the user ID and `_query_devices_for_destination`
+ # will return early.
+ self.store.do_users_share_a_room_joined_or_invited = mock.AsyncMock( # type: ignore[method-assign]
+ return_value=[remote_user_id]
+ )
+ self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ }
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
|