diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f2989cc4a2..5bf8e86387 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -100,6 +100,7 @@ class AdminHandler:
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["external_ids"] = external_ids
+ user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
return user_info_dict
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..5d1d21cdc8 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler:
last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
- events,
event_to_received_ts,
- ) = await self.store.get_all_new_events_stream(
- last_token, self.current_max, limit=100, get_prev_content=True
+ ) = await self.store.get_all_new_event_ids_stream(
+ last_token, self.current_max, limit=100
+ )
+
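+            # The stream query above only returns event IDs; resolve them to
+            # full events (including `prev_content`) in a second step.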
+ events = await self.store.get_events_as_list(
+ event_to_received_ts.keys(), get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
@@ -575,9 +578,6 @@ class ApplicationServicesHandler:
device_id,
), messages in recipient_device_to_messages.items():
for message_json in messages:
- # Remove 'message_id' from the to-device message, as it's an internal ID
- message_json.pop("message_id", None)
-
message_payload.append(
{
"to_user_id": user_id,
@@ -612,8 +612,8 @@ class ApplicationServicesHandler:
)
# Fetch the users who have modified their device list since then.
- users_with_changed_device_lists = (
- await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+ users_with_changed_device_lists = await self.store.get_all_devices_changed(
+ from_key, to_key=new_key
)
# Filter out any users the application service is not interested in
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f5f0e0e7a7..8b9ef25d29 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -38,6 +38,7 @@ from typing import (
import attr
import bcrypt
import unpaddedbase64
+from prometheus_client import Counter
from twisted.internet.defer import CancelledError
from twisted.web.server import Request
@@ -48,6 +49,7 @@ from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
LoginError,
+ NotFoundError,
StoreError,
SynapseError,
UserDeactivatedError,
@@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.registration import (
+ LoginTokenExpired,
+ LoginTokenLookupResult,
+ LoginTokenReused,
+)
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import LoginTokenAttributes
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email
@@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
+invalid_login_token_counter = Counter(
+ "synapse_user_login_invalid_login_tokens",
+ "Counts the number of rejected m.login.token on /login",
+ ["reason"],
+)
+
def convert_client_dict_legacy_fields_to_identifier(
submission: JsonDict,
@@ -883,6 +895,25 @@ class AuthHandler:
return True
+ async def create_login_token_for_user_id(
+ self,
+ user_id: str,
+ duration_ms: int = (2 * 60 * 1000),
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
+ ) -> str:
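+        """Create a short-term login token for the given user.
+
+        Args:
+            user_id: canonical ID of the user the token is for
+            duration_ms: how long (in milliseconds) the token is valid for
+            auth_provider_id: the SSO identity provider the user authenticated
+                with, if any
+            auth_provider_session_id: the session ID obtained from that SSO
+                identity provider, if any
+
+        Returns:
+            The opaque login token string.
+        """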
+ login_token = self.generate_login_token()
+ now = self._clock.time_msec()
+ expiry_ts = now + duration_ms
+ await self.store.add_login_token_to_user(
+ user_id=user_id,
+ token=login_token,
+ expiry_ts=expiry_ts,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+ return login_token
+
async def create_refresh_token_for_user_id(
self,
user_id: str,
@@ -1401,6 +1432,18 @@ class AuthHandler:
return None
return user_id
+ def generate_login_token(self) -> str:
+ """Generates an opaque string, for use as an short-term login token"""
+
+        # we use the following format for login tokens:
+ # syl_<random string>_<base62 crc check>
+
+ random_string = stringutils.random_string(20)
+ base = f"syl_{random_string}"
+
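+        # Same structure as access tokens (see `generate_access_token` below),
+        # but with a distinct `syl_` prefix so the two token types are
+        # distinguishable.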
+ crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+ return f"{base}_{crc}"
+
def generate_access_token(self, for_user: UserID) -> str:
"""Generates an opaque string, for use as an access token"""
@@ -1427,16 +1470,17 @@ class AuthHandler:
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
- async def validate_short_term_login_token(
- self, login_token: str
- ) -> LoginTokenAttributes:
+ async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
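+        """Consume a short-term login token, returning the data stored when
+        it was created.
+
+        Tokens are single use: expired, already-used and unknown tokens are
+        all rejected, and counted in `invalid_login_token_counter`.
+        """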
try:
- res = self.macaroon_gen.verify_short_term_login_token(login_token)
- except Exception:
- raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
+ return await self.store.consume_login_token(login_token)
+ except LoginTokenExpired:
+ invalid_login_token_counter.labels("expired").inc()
+ except LoginTokenReused:
+ invalid_login_token_counter.labels("reused").inc()
+ except NotFoundError:
+ invalid_login_token_counter.labels("not found").inc()
- await self.auth_blocking.check_auth_blocking(res.user_id)
- return res
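+        # All three failure modes above fall through to the same generic
+        # error, so a caller cannot tell why the token was rejected.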
+ raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
@@ -1711,7 +1755,7 @@ class AuthHandler:
)
# Create a login token
- login_token = self.macaroon_gen.generate_short_term_login_token(
+ login_token = await self.create_login_token_for_user_id(
registered_user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 816e1a6d79..d74d135c0c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
True if identity server supports removing threepids, otherwise False.
"""
+ # This can only be called on the main process.
+ assert isinstance(self._device_handler, DeviceHandler)
+
# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f9cc5bddbc..d4750a32e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
+ device_list_updater: "DeviceListWorkerUpdater"
+
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
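+
+        # On workers this proxies resync requests to the main process over
+        # replication; `DeviceHandler` replaces it with the full
+        # `DeviceListUpdater`, which writes to the database directly.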
+ self.device_list_updater = DeviceListWorkerUpdater(hs)
+
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
- ) -> Collection[str]:
+ ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
class DeviceHandler(DeviceWorkerHandler):
+ device_list_updater: "DeviceListUpdater"
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id
- async def get_dehydrated_device(
- self, user_id: str
- ) -> Optional[Tuple[str, JsonDict]]:
- """Retrieve the information for a dehydrated device.
-
- Args:
- user_id: the user whose dehydrated device we are looking for
- Returns:
- a tuple whose first item is the device ID, and the second item is
- the dehydrated device information
- """
- return await self.store.get_dehydrated_device(user_id)
-
async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -682,13 +688,33 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to: Set[str] = set()
try:
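+            # Work out where we last got to: the position of the last device
+            # list change we converted into outbound pokes. This is persisted,
+            # so processing resumes from here after a restart.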
+ stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
while True:
self._handle_new_device_update_new_data = False
- rows = await self.store.get_uncoverted_outbound_room_pokes()
+ max_stream_id = self.store.get_device_stream_token()
+ rows = await self.store.get_uncoverted_outbound_room_pokes(
+ stream_id, room_id
+ )
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
+
+ # Advance `(stream_id, room_id)`.
+ # `max_stream_id` comes from *before* the query for unconverted
+ # rows, which means that any unconverted rows must have a larger
+ # stream ID.
+ if max_stream_id > stream_id:
+ stream_id, room_id = max_stream_id, ""
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+ else:
+ assert max_stream_id == stream_id
+ # Avoid moving `room_id` backwards.
+ pass
+
if self._handle_new_device_update_new_data:
continue
else:
@@ -718,7 +744,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
@@ -752,6 +777,12 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
+            # Advance `(stream_id, room_id)` to the last row we just
+            # processed, so that the next iteration (or a restart) resumes
+            # from there.
+ _, _, room_id, stream_id, _ = rows[-1]
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+
finally:
self._handle_new_device_update_is_processing = False
@@ -834,7 +865,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)
@@ -858,7 +888,36 @@ def _update_device_from_client_ips(
)
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+ "Handles incoming device list updates from federation and contacts the main process over replication"
+
+ def __init__(self, hs: "HomeServer"):
+ from synapse.replication.http.devices import (
+ ReplicationUserDevicesResyncRestServlet,
+ )
+
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[JsonDict]:
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id: The user's id whose device_list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
+ if the attempt to resync failed.
+ Returns:
+            A dict with device info, as under the "devices" key in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ """
+ return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -937,7 +996,10 @@ class DeviceListUpdater:
# Check if we are partially joining any rooms. If so we need to store
# all device list updates so that we can handle them correctly once we
# know who is in the room.
- partial_rooms = await self.store.get_partial_state_rooms_and_servers()
+ # TODO(faster_joins): this fetches and processes a bunch of data that we don't
+ # use. Could be replaced by a tighter query e.g.
+ # SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+ partial_rooms = await self.store.get_partial_state_room_resync_info()
if partial_rooms:
await self.store.add_remote_device_list_to_pending(
user_id,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 444c08bc2e..75e89850f5 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import EduTypes, ToDeviceEventTypes
+from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
@@ -216,14 +216,24 @@ class DeviceMessageHandler:
"""
sender_user_id = requester.user.to_string()
- message_id = random_string(16)
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
-
- log_kv({"number_of_to_device_messages": len(messages)})
- set_tag("sender", sender_user_id)
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
+ set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
+ # add an opentracing log entry for each message
+ for device_id, message_content in by_device.items():
+ log_kv(
+ {
+ "event": "send_to_device_message",
+ "user_id": user_id,
+ "device_id": device_id,
+ EventContentFields.TO_DEVICE_MSGID: message_content.get(
+ EventContentFields.TO_DEVICE_MSGID
+ ),
+ }
+ )
+
# Ratelimit local cross-user key requests by the sending device.
if (
message_type == ToDeviceEventTypes.RoomKeyRequest
@@ -233,6 +243,7 @@ class DeviceMessageHandler:
requester, (sender_user_id, requester.device_id)
)
if not allowed:
+ log_kv({"message": f"dropping key requests to {user_id}"})
logger.info(
"Dropping room_key_request from %s to %s due to rate limit",
sender_user_id,
@@ -247,18 +258,11 @@ class DeviceMessageHandler:
"content": message_content,
"type": message_type,
"sender": sender_user_id,
- "message_id": message_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
- log_kv(
- {
- "user_id": user_id,
- "device_id": list(messages_by_device),
- }
- )
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
@@ -267,7 +271,11 @@ class DeviceMessageHandler:
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- log_kv({"destination": destination})
+ # The EDU contains a "message_id" property which is used for
+ # idempotence. Make up a random one.
+ message_id = random_string(16)
+ log_kv({"destination": destination, "message_id": message_id})
+
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7127d5aefc..2ea52257cb 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -16,6 +16,8 @@ import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing_extensions import Literal
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
@@ -83,7 +85,7 @@ class DirectoryHandler:
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
- servers = await self._storage_controllers.state.get_current_hosts_in_room(
+ servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
@@ -288,7 +290,7 @@ class DirectoryHandler:
Codes.NOT_FOUND,
)
- extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+ extra_servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
servers_set = set(extra_servers) | set(servers)
@@ -429,7 +431,10 @@ class DirectoryHandler:
return await self.auth.check_can_change_room_list(room_id, requester)
async def edit_published_room_list(
- self, requester: Requester, room_id: str, visibility: str
+ self,
+ requester: Requester,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> None:
"""Edit the entry of the room in the published room list.
@@ -451,9 +456,6 @@ class DirectoryHandler:
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
- if visibility not in ["public", "private"]:
- raise SynapseError(400, "Invalid visibility setting")
-
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
@@ -505,7 +507,11 @@ class DirectoryHandler:
await self.store.set_room_is_public(room_id, making_public)
async def edit_published_appservice_room_list(
- self, appservice_id: str, network_id: str, room_id: str, visibility: str
+ self,
+ appservice_id: str,
+ network_id: str,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> None:
"""Add or remove a room from the appservice/network specific public
room list.
@@ -516,9 +522,6 @@ class DirectoryHandler:
room_id
visibility: either "public" or "private"
"""
- if visibility not in ["public", "private"]:
- raise SynapseError(400, "Invalid visibility setting")
-
await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 09a2492afc..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
@@ -49,33 +49,30 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
- self._edu_updater = SigningKeyEduUpdater(hs, self)
-
federation_registry = hs.get_federation_registry()
- self._is_master = hs.config.worker.worker_app is None
- if not self._is_master:
- self._user_device_resync_client = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- else:
+ is_master = hs.config.worker.worker_app is None
+ if is_master:
+ edu_updater = SigningKeyEduUpdater(hs)
+
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
@@ -318,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- if self._is_master:
- resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ resync_results = (
+ await self.device_handler.device_list_updater.user_device_resync(
user_id
)
- else:
- resync_results = await self._user_device_resync_client(
- user_id=user_id
- )
+ )
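+            # `user_device_resync` returns None if the resync failed (on
+            # workers it is a replication call to the main process).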
+ if resync_results is None:
+ raise ValueError("Device resync failed")
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -431,13 +427,17 @@ class E2eKeysHandler:
@trace
@cancellable
async def query_local_devices(
- self, query: Mapping[str, Optional[List[str]]]
+ self,
+ query: Mapping[str, Optional[List[str]]],
+ include_displaynames: bool = True,
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
query: map from user_id to a list
of devices to query (None for all devices)
+ include_displaynames: Whether to include device displaynames in the returned
+ device details.
Returns:
A map from user_id -> device_id -> device details
@@ -469,7 +469,9 @@ class E2eKeysHandler:
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(
+ local_query, include_displaynames
+ )
# Build the result structure
for user_id, device_keys in results.items():
@@ -482,11 +484,33 @@ class E2eKeysHandler:
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
- """Handle a device key query from a federated server"""
+ """Handle a device key query from a federated server:
+
+ Handles the path: GET /_matrix/federation/v1/users/keys/query
+
+ Args:
+ query_body: The body of the query request. Should contain a key
+ "device_keys" that map to a dictionary of user ID's -> list of
+ device IDs. If the list of device IDs is empty, all devices of
+ that user will be queried.
+
+ Returns:
+ A json dictionary containing the following:
+ - device_keys: A dictionary containing the requested device information.
+ - master_keys: An optional dictionary of user ID -> master cross-signing
+ key info.
+            - self_signing_keys: An optional dictionary of user ID -> self-signing
+ key info.
+ """
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
)
- res = await self.query_local_devices(device_keys_query)
+ res = await self.query_local_devices(
+ device_keys_query,
+ include_displaynames=(
+ self.config.federation.allow_device_name_lookup_over_federation
+ ),
+ )
ret = {"device_keys": res}
# add in the cross-signing keys
@@ -576,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
@@ -703,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -794,6 +822,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@@ -841,7 +872,7 @@ class E2eKeysHandler:
- signatures of the user's master key by the user's devices.
Args:
- user_id (string): the user uploading the keys
+ user_id: the user uploading the keys
signatures (dict[string, dict]): map of devices to signed keys
Returns:
@@ -1171,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1367,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
- self.e2e_keys_handler = e2e_keys_handler
+
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@@ -1416,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""
- device_handler = self.e2e_keys_handler.device_handler
- device_list_updater = device_handler.device_list_updater
-
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1430,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = (
- await device_list_updater.process_cross_signing_key_update(
- user_id,
- master_key,
- self_signing_key,
- )
+ new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + new_device_ids
- await device_handler.notify_device_update(user_id, device_ids)
+ await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 28dc08c22a..83f53ceb88 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -377,8 +377,9 @@ class E2eRoomKeysHandler:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
- user_id(str): the user whose current backup version we're deleting
- version(str): the version id of the backup being deleted
+ user_id: the user whose current backup version we're deleting
+        version: Optional. The version ID of the backup we're deleting.
+            If missing, we delete the current backup version info.
Raises:
NotFoundError: if this backup version doesn't exist
"""
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 8249ca1ed2..f91dbbecb7 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
@@ -29,7 +29,6 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
if TYPE_CHECKING:
@@ -46,17 +45,27 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
+ self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
async def check_auth_rules_from_context(
self,
event: EventBase,
- context: EventContext,
+ batched_auth_events: Optional[Mapping[str, EventBase]] = None,
) -> None:
- """Check an event passes the auth rules at its own auth events"""
- await check_state_independent_auth_rules(self._store, event)
+ """Check an event passes the auth rules at its own auth events
+ Args:
+ event: event to be authed
+ batched_auth_events: if the event being authed is part of a batch, any events
+ from the same batch that may be necessary to auth the current event
+ """
+ await check_state_independent_auth_rules(
+ self._store, event, batched_auth_events
+ )
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
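+        # Events from the same batch may not be in the database yet, so fill
+        # in any auth events that `get_events` could not find.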
+ if batched_auth_events:
+ auth_events_by_id.update(batched_auth_events)
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(
@@ -171,17 +180,22 @@ class EventAuthHandler:
this function may return an incorrect result as we are not able to fully
track server membership in a room without full state.
"""
- if not allow_partial_state_rooms and await self._store.is_partial_state_room(
- room_id
- ):
- raise AuthError(
- 403,
- "Unable to authorise you right now; room is partial-stated here.",
- errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
- )
-
- if not await self.is_host_in_room(room_id, host):
- raise AuthError(403, "Host not in room.")
+ if await self._store.is_partial_state_room(room_id):
+ if allow_partial_state_rooms:
+ current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation(
+ room_id
+ )
+ if host not in current_hosts:
+ raise AuthError(403, "Host not in room (partial-state approx).")
+ else:
+ raise AuthError(
+ 403,
+ "Unable to authorise you right now; room is partial-stated here.",
+ errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
+ )
+ else:
+ if not await self.is_host_in_room(room_id, host):
+ raise AuthError(403, "Host not in room.")
async def check_restricted_join_rules(
self,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 986ffed3d5..b2784d7333 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
FederationError,
+ FederationPullAttemptBackoffError,
HttpResponseException,
LimitExceededError,
NotFoundError,
@@ -69,8 +70,8 @@ from synapse.replication.http.federation import (
)
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -151,6 +152,7 @@ class FederationHandler:
self._federation_event_handler = hs.get_federation_event_handler()
self._device_handler = hs.get_device_handler()
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
+ self._notifier = hs.get_notifier()
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
@@ -378,6 +380,7 @@ class FederationHandler:
filtered_extremities = await filter_events_for_server(
self._storage_controllers,
self.server_name,
+ self.server_name,
events_to_check,
redact=False,
check_history_visibility_only=True,
@@ -441,6 +444,15 @@ class FederationHandler:
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
return True
+ except NotRetryingDestination as e:
+ logger.info("_maybe_backfill_inner: %s", e)
+ continue
+ except FederationDeniedError:
+ logger.info(
+ "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist",
+ dom,
+ )
+ continue
except (SynapseError, InvalidResponseError) as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
@@ -476,15 +488,9 @@ class FederationHandler:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
- except NotRetryingDestination as e:
- logger.info(str(e))
- continue
except RequestSendFailed as e:
logger.info("Failed to get backfill from %s because %s", dom, e)
continue
- except FederationDeniedError as e:
- logger.info(e)
- continue
except Exception as e:
logger.exception("Failed to backfill from %s because %s", dom, e)
continue
@@ -631,6 +637,7 @@ class FederationHandler:
room_id=room_id,
servers=ret.servers_in_room,
device_lists_stream_id=self.store.get_device_stream_token(),
+ joined_via=origin,
)
try:
@@ -781,15 +788,27 @@ class FederationHandler:
# Send the signed event back to the room, and potentially receive some
# further information about the room in the form of partial state events
- stripped_room_state = await self.federation_client.send_knock(
- target_hosts, event
- )
+ knock_response = await self.federation_client.send_knock(target_hosts, event)
# Store any stripped room state events in the "unsigned" key of the event.
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
- event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+ stripped_room_state = (
+ knock_response.get("knock_room_state")
+ # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+ # Thus, we also check for a 'knock_state_events' to support old instances.
+ # See https://github.com/matrix-org/synapse/issues/14088.
+ or knock_response.get("knock_state_events")
+ )
+
+ if stripped_room_state is None:
+ raise KeyError(
+ "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
+ "send_knock response"
+ )
+
+ event.unsigned["knock_room_state"] = stripped_room_state
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
@@ -928,7 +947,7 @@ class FederationHandler:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- await self._event_auth_handler.check_auth_rules_from_context(event, context)
+ await self._event_auth_handler.check_auth_rules_from_context(event)
return event
async def on_invite_request(
@@ -1003,7 +1022,9 @@ class FederationHandler:
context = EventContext.for_outlier(self._storage_controllers)
- await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
+ await self._bulk_push_rule_evaluator.action_for_events_by_user(
+ [(event, context)]
+ )
try:
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
@@ -1109,7 +1130,7 @@ class FederationHandler:
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- await self._event_auth_handler.check_auth_rules_from_context(event, context)
+ await self._event_auth_handler.check_auth_rules_from_context(event)
except AuthError as e:
logger.warning("Failed to create new leave %r because %s", event, e)
raise e
@@ -1168,7 +1189,7 @@ class FederationHandler:
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
- await self._event_auth_handler.check_auth_rules_from_context(event, context)
+ await self._event_auth_handler.check_auth_rules_from_context(event)
except AuthError as e:
logger.warning("Failed to create new knock %r because %s", event, e)
raise e
@@ -1212,7 +1233,9 @@ class FederationHandler:
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
- await self._event_auth_handler.assert_host_in_room(room_id, origin)
+ # We allow partially joined rooms since in this case we are filtering out
+ # non-local events in `filter_events_for_server`.
+ await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
@@ -1233,7 +1256,7 @@ class FederationHandler:
)
events = await filter_events_for_server(
- self._storage_controllers, origin, events
+ self._storage_controllers, origin, self.server_name, events
)
return events
@@ -1264,7 +1287,7 @@ class FederationHandler:
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
events = await filter_events_for_server(
- self._storage_controllers, origin, [event]
+ self._storage_controllers, origin, self.server_name, [event]
)
event = events[0]
return event
@@ -1277,7 +1300,9 @@ class FederationHandler:
latest_events: List[str],
limit: int,
) -> List[EventBase]:
- await self._event_auth_handler.assert_host_in_room(room_id, origin)
+ # We allow partially joined rooms since in this case we are filtering out
+ # non-local events in `filter_events_for_server`.
+ await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
# Only allow up to 20 events to be retrieved per request.
limit = min(limit, 20)
@@ -1290,7 +1315,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
- self._storage_controllers, origin, missing_events
+ self._storage_controllers, origin, self.server_name, missing_events
)
return missing_events
@@ -1334,9 +1359,7 @@ class FederationHandler:
try:
validate_event_for_room_version(event)
- await self._event_auth_handler.check_auth_rules_from_context(
- event, context
- )
+ await self._event_auth_handler.check_auth_rules_from_context(event)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
@@ -1386,7 +1409,7 @@ class FederationHandler:
try:
validate_event_for_room_version(event)
- await self._event_auth_handler.check_auth_rules_from_context(event, context)
+ await self._event_auth_handler.check_auth_rules_from_context(event)
except AuthError as e:
logger.warning("Denying third party invite %r because %s", event, e)
raise e
@@ -1579,8 +1602,8 @@ class FederationHandler:
Fetch the complexity of a remote room over federation.
Args:
- remote_room_hosts (list[str]): The remote servers to ask.
- room_id (str): The room ID to ask about.
+ remote_room_hosts: The remote servers to ask.
+ room_id: The room ID to ask about.
Returns:
Dict contains the complexity
@@ -1602,13 +1625,13 @@ class FederationHandler:
"""Resumes resyncing of all partial-state rooms after a restart."""
assert not self.config.worker.worker_app
- partial_state_rooms = await self.store.get_partial_state_rooms_and_servers()
- for room_id, servers_in_room in partial_state_rooms.items():
+ partial_state_rooms = await self.store.get_partial_state_room_resync_info()
+ for room_id, resync_info in partial_state_rooms.items():
run_as_background_process(
desc="sync_partial_state_room",
func=self._sync_partial_state_room,
- initial_destination=None,
- other_destinations=servers_in_room,
+ initial_destination=resync_info.joined_via,
+ other_destinations=resync_info.servers_in_room,
room_id=room_id,
)
@@ -1637,28 +1660,12 @@ class FederationHandler:
# really leave, that might mean we have difficulty getting the room state over
# federation.
# https://github.com/matrix-org/synapse/issues/12802
- #
- # TODO(faster_joins): we need some way of prioritising which homeservers in
- # `other_destinations` to try first, otherwise we'll spend ages trying dead
- # homeservers for large rooms.
- # https://github.com/matrix-org/synapse/issues/12999
-
- if initial_destination is None and len(other_destinations) == 0:
- raise ValueError(
- f"Cannot resync state of {room_id}: no destinations provided"
- )
# Make an infinite iterator of destinations to try. Once we find a working
# destination, we'll stick with it until it flakes.
- destinations: Collection[str]
- if initial_destination is not None:
- # Move `initial_destination` to the front of the list.
- destinations = list(other_destinations)
- if initial_destination in destinations:
- destinations.remove(initial_destination)
- destinations = [initial_destination] + destinations
- else:
- destinations = other_destinations
+ destinations = _prioritise_destinations_for_partial_state_resync(
+ initial_destination, other_destinations, room_id
+ )
destination_iter = itertools.cycle(destinations)
# `destination` is the current remote homeserver we're pulling from.
@@ -1686,6 +1693,9 @@ class FederationHandler:
self._storage_controllers.state.notify_room_un_partial_stated(
room_id
)
+ # Poke the notifier so that other workers see the write to
+ # the un-partial-stated rooms stream.
+ self._notifier.notify_replication()
# TODO(faster_joins) update room stats and user directory?
# https://github.com/matrix-org/synapse/issues/12814
@@ -1708,7 +1718,22 @@ class FederationHandler:
destination, event
)
break
+ except FederationPullAttemptBackoffError as exc:
+ # Log a warning about why we failed to process the event (the error message
+ # for `FederationPullAttemptBackoffError` is pretty good)
+ logger.warning("_sync_partial_state_room: %s", exc)
+ # We do not record a failed pull attempt when we backoff fetching a missing
+ # `prev_event` because not being able to fetch the `prev_events` just means
+ # we won't be able to de-outlier the pulled event. But we can still use an
+ # `outlier` in the state/auth chain for another event. So we shouldn't stop
+ # a downstream event from trying to pull it.
+ #
+ # This avoids a cascade of backoff for all events in the DAG downstream from
+ # one event backoff upstream.
except FederationError as e:
+ # TODO: We should `record_event_failed_pull_attempt` here,
+ # see https://github.com/matrix-org/synapse/issues/13700
+
if attempt == len(destinations) - 1:
# We have tried every remote server for this event. Give up.
# TODO(faster_joins) giving up isn't the right thing to do
@@ -1741,3 +1766,29 @@ class FederationHandler:
room_id,
destination,
)
+
+
+def _prioritise_destinations_for_partial_state_resync(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+) -> Collection[str]:
+ """Work out the order in which we should ask servers to resync events.
+
+ If an `initial_destination` is given, it takes top priority. Otherwise
+ all servers are treated equally.
+
+    Raises:
+        ValueError: if no destination is provided at all.
+ """
+ if initial_destination is None and len(other_destinations) == 0:
+ raise ValueError(f"Cannot resync state of {room_id}: no destinations provided")
+
+ if initial_destination is None:
+ return other_destinations
+
+ # Move `initial_destination` to the front of the list.
+ destinations = list(other_destinations)
+ if initial_destination in destinations:
+ destinations.remove(initial_destination)
+ destinations = [initial_destination] + destinations
+ return destinations
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index da319943cc..66aca2f864 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -43,7 +43,9 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
+ EventSizeError,
FederationError,
+ FederationPullAttemptBackoffError,
HttpResponseException,
RequestSendFailed,
SynapseError,
@@ -57,7 +59,7 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.federation.federation_client import InvalidResponseError
+from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
SynapseTags,
@@ -74,7 +76,6 @@ from synapse.replication.http.federation import (
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -82,6 +83,7 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
@@ -414,7 +416,9 @@ class FederationEventHandler:
# First, precalculate the joined hosts so that the federation sender doesn't
# need to.
- await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+ await self._event_creation_handler.cache_joined_hosts_for_events(
+ [(event, context)]
+ )
await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
@@ -565,6 +569,9 @@ class FederationEventHandler:
event: partial-state event to be de-partial-stated
Raises:
+            FederationPullAttemptBackoffError if we are deliberately not attempting
+ to pull the given event over federation because we've already done so
+ recently and are backing off.
FederationError if we fail to request state from the remote server.
"""
logger.info("Updating state for %s", event.event_id)
@@ -792,9 +799,42 @@ class FederationEventHandler:
],
)
+        # Check if we already have any of these events.
+ # Note: we currently make a lookup in the database directly here rather than
+ # checking the event cache, due to:
+ # https://github.com/matrix-org/synapse/issues/13476
+ existing_events_map = await self._store._get_events_from_db(
+ [event.event_id for event in events]
+ )
+
+ new_events = []
+ for event in events:
+ event_id = event.event_id
+
+ # If we've already seen this event ID...
+ if event_id in existing_events_map:
+ existing_event = existing_events_map[event_id]
+
+ # ...and the event itself was not previously stored as an outlier...
+ if not existing_event.event.internal_metadata.is_outlier():
+ # ...then there's no need to persist it. We have it already.
+ logger.info(
+ "_process_pulled_event: Ignoring received event %s which we "
+ "have already seen",
+ event.event_id,
+ )
+ continue
+
+ # While we have seen this event before, it was stored as an outlier.
+ # We'll now persist it as a non-outlier.
+ logger.info("De-outliering event %s", event_id)
+
+ # Continue on with the events that are new to us.
+ new_events.append(event)
+
# We want to sort these by depth so we process them and
# tell clients about them in order.
- sorted_events = sorted(events, key=lambda x: x.depth)
+ sorted_events = sorted(new_events, key=lambda x: x.depth)
for ev in sorted_events:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev, backfilled=backfilled)
@@ -846,18 +886,6 @@ class FederationEventHandler:
event_id = event.event_id
- existing = await self._store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "_process_pulled_event: Ignoring received event %s which we have already seen",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
try:
self._sanity_check_event(event)
except SynapseError as err:
@@ -899,6 +927,18 @@ class FederationEventHandler:
context,
backfilled=backfilled,
)
+ except FederationPullAttemptBackoffError as exc:
+ # Log a warning about why we failed to process the event (the error message
+ # for `FederationPullAttemptBackoffError` is pretty good)
+ logger.warning("_process_pulled_event: %s", exc)
+ # We do not record a failed pull attempt when we backoff fetching a missing
+ # `prev_event` because not being able to fetch the `prev_events` just means
+ # we won't be able to de-outlier the pulled event. But we can still use an
+ # `outlier` in the state/auth chain for another event. So we shouldn't stop
+ # a downstream event from trying to pull it.
+ #
+ # This avoids a cascade of backoff for all events in the DAG downstream from
+ # one event backoff upstream.
except FederationError as e:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(e)
@@ -945,6 +985,9 @@ class FederationEventHandler:
The event context.
Raises:
+        FederationPullAttemptBackoffError if we are deliberately not attempting
+ to pull the given event over federation because we've already done so
+ recently and are backing off.
FederationError if we fail to get the state from the remote server after any
missing `prev_event`s.
"""
@@ -955,6 +998,18 @@ class FederationEventHandler:
seen = await self._store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen
+ # If we've already recently attempted to pull this missing event, don't
+ # try it again so soon. Since we have to fetch all of the prev_events, we can
+ # bail early here if we find any to ignore.
+ prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
+ room_id, missing_prevs
+ )
+ if len(prevs_to_ignore) > 0:
+ raise FederationPullAttemptBackoffError(
+ event_ids=prevs_to_ignore,
+ message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
+ )
+
if not missing_prevs:
return await self._state_handler.compute_event_context(event)
@@ -1011,10 +1066,9 @@ class FederationEventHandler:
state_res_store=StateResolutionStore(self._store),
)
- except Exception:
+ except Exception as e:
logger.warning(
- "Error attempting to resolve state at missing prev_events",
- exc_info=True,
+ "Error attempting to resolve state at missing prev_events: %s", e
)
raise FederationError(
"ERROR",
@@ -1463,8 +1517,8 @@ class FederationEventHandler:
)
async def backfill_event_id(
- self, destination: str, room_id: str, event_id: str
- ) -> EventBase:
+ self, destinations: List[str], room_id: str, event_id: str
+ ) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
@@ -1476,24 +1530,21 @@ class FederationEventHandler:
Raises:
FederationError if we are unable to find the event from the destination
"""
- logger.info(
- "backfill_event_id: event_id=%s from destination=%s", event_id, destination
- )
+ logger.info("backfill_event_id: event_id=%s", event_id)
room_version = await self._store.get_room_version(room_id)
- event_from_response = await self._federation_client.get_pdu(
- [destination],
+ pulled_pdu_info = await self._federation_client.get_pdu(
+ destinations,
event_id,
room_version,
)
- if not event_from_response:
+ if not pulled_pdu_info:
raise FederationError(
"ERROR",
404,
- "Unable to find event_id=%s from destination=%s to backfill."
- % (event_id, destination),
+ f"Unable to find event_id={event_id} from remote servers to backfill.",
affected=event_id,
)
@@ -1501,13 +1552,13 @@ class FederationEventHandler:
# and auth events to de-outlier it. This also sets up the necessary
# `state_groups` for the event.
await self._process_pulled_events(
- destination,
- [event_from_response],
+ pulled_pdu_info.pull_origin,
+ [pulled_pdu_info.pdu],
# Prevent notifications going to clients
backfilled=True,
)
- return event_from_response
+ return pulled_pdu_info
@trace
@tag_args
@@ -1530,19 +1581,19 @@ class FederationEventHandler:
async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
try:
- event = await self._federation_client.get_pdu(
+ pulled_pdu_info = await self._federation_client.get_pdu(
[destination],
event_id,
room_version,
)
- if event is None:
+ if pulled_pdu_info is None:
logger.warning(
"Server %s didn't return event %s",
destination,
event_id,
)
return
- events.append(event)
+ events.append(pulled_pdu_info.pdu)
except Exception as e:
logger.warning(
@@ -1686,6 +1737,15 @@ class FederationEventHandler:
except AuthError as e:
logger.warning("Rejecting %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
+ except EventSizeError as e:
+ if e.unpersistable:
+ # This event is completely unpersistable.
+ raise e
+ # Otherwise, we are somewhat lenient and just persist the event
+ # as rejected, for moderate compatibility with older Synapse
+ # versions.
+ logger.warning("While validating received event %r: %s", event, e)
+ context.rejected = RejectedReason.OVERSIZED_EVENT
events_and_contexts_to_persist.append((event, context))
@@ -1731,6 +1791,16 @@ class FederationEventHandler:
# TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR
return
+ except EventSizeError as e:
+ if e.unpersistable:
+ # This event is completely unpersistable.
+ raise e
+ # Otherwise, we are somewhat lenient and just persist the event
+ # as rejected, for moderate compatibility with older Synapse
+ # versions.
+ logger.warning("While validating received event %r: %s", event, e)
+ context.rejected = RejectedReason.OVERSIZED_EVENT
+ return
# next, check that we have all of the event's auth events.
#
@@ -2117,8 +2187,8 @@ class FederationEventHandler:
min_depth,
)
else:
- await self._bulk_push_rule_evaluator.action_for_event_by_user(
- event, context
+ await self._bulk_push_rule_evaluator.action_for_events_by_user(
+ [(event, context)]
)
try:
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 93d09e9939..848e46eb9b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -711,7 +711,7 @@ class IdentityHandler:
inviter_display_name: The current display name of the
inviter.
inviter_avatar_url: The URL of the inviter's avatar.
- id_access_token (str): The access token to authenticate to the identity
+ id_access_token: The access token to authenticate to the identity
server with
Returns:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
- str,
- Optional[StreamToken],
- Optional[StreamToken],
- str,
- Optional[int],
- bool,
- bool,
+ str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
@@ -154,11 +148,6 @@ class InitialSyncHandler:
public_room_ids = await self.store.get_public_room_ids()
- if pagin_config.limit is not None:
- limit = pagin_config.limit
- else:
- limit = 10
-
serializer_options = SerializeEventConfig(as_client_event=as_client_event)
async def handle_room(event: RoomsForUser) -> None:
@@ -210,7 +199,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=room_end_token,
),
deferred_room_state,
@@ -360,15 +349,11 @@ class InitialSyncHandler:
member_event_id
)
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
leave_position = await self.store.get_position_for_event(member_event_id)
stream_token = leave_position.to_room_stream_token()
messages, token = await self.store.get_recent_events_for_room(
- room_id, limit=limit, end_token=stream_token
+ room_id, limit=pagin_config.limit, end_token=stream_token
)
messages = await filter_events_for_client(
@@ -420,10 +405,6 @@ class InitialSyncHandler:
now_token = self.hs.get_event_sources().get_current_token()
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
room_members = [
m
for m in current_state.values()
@@ -467,7 +448,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=now_token.room_key,
),
),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..845f683358 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -59,7 +60,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
@@ -70,6 +70,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@@ -877,6 +878,36 @@ class EventCreationHandler:
return prev_event
return None
+ async def get_event_from_transaction(
+ self,
+ requester: Requester,
+ txn_id: str,
+ room_id: str,
+ ) -> Optional[EventBase]:
+ """For the given transaction ID and room ID, check if there is a matching event.
+ If so, fetch it and return it.
+
+ Args:
+ requester: The requester making the request in the context of which we want
+ to fetch the event.
+ txn_id: The transaction ID.
+ room_id: The room ID.
+
+ Returns:
+ An event if one could be found, None otherwise.
+ """
+ if requester.access_token_id:
+ existing_event_id = await self.store.get_event_id_from_transaction_id(
+ room_id,
+ requester.user.to_string(),
+ requester.access_token_id,
+ txn_id,
+ )
+ if existing_event_id:
+ return await self.store.get_event(existing_event_id)
+
+ return None
+
async def create_and_send_nonmember_event(
self,
requester: Requester,
@@ -956,18 +987,17 @@ class EventCreationHandler:
# extremities to pile up, which in turn leads to state resolution
# taking longer.
async with self.limiter.queue(event_dict["room_id"]):
- if txn_id and requester.access_token_id:
- existing_event_id = await self.store.get_event_id_from_transaction_id(
- event_dict["room_id"],
- requester.user.to_string(),
- requester.access_token_id,
- txn_id,
+ if txn_id:
+ event = await self.get_event_from_transaction(
+ requester, txn_id, event_dict["room_id"]
)
- if existing_event_id:
- event = await self.store.get_event(existing_event_id)
+ if event:
# we know it was persisted, so must have a stream ordering
assert event.internal_metadata.stream_ordering
- return event, event.internal_metadata.stream_ordering
+ return (
+ event,
+ event.internal_metadata.stream_ordering,
+ )
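As an illustration of what this deduplication buys: a client that retries `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` after a timeout gets the original event back instead of creating a duplicate. A minimal sketch against the standard client-server API (URL, access token, and room ID are placeholders):

    import requests

    def send_idempotently(room_id: str, txn_id: str, body: str) -> str:
        resp = requests.put(
            f"http://localhost:8008/_matrix/client/v3/rooms/{room_id}"
            f"/send/m.room.message/{txn_id}",
            headers={"Authorization": "Bearer ACCESS_TOKEN"},
            json={"msgtype": "m.text", "body": body},
        )
        return resp.json()["event_id"]

    # Two calls with the same txn_id return the same event_id: the second is
    # resolved via get_event_from_transaction instead of minting a new event.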
event, context = await self.create_event(
requester,
@@ -1106,11 +1136,13 @@ class EventCreationHandler:
)
state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
- state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+ current_state_ids = {
+ (e.type, e.state_key): e.event_id for e in state_events
+ }
# Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
- current_state_ids=state_map,
+ current_state_ids=current_state_ids,
for_verification=False,
)
@@ -1360,8 +1392,16 @@ class EventCreationHandler:
else:
try:
validate_event_for_room_version(event)
+ # If we are persisting a batch of events, the event(s) needed to auth the
+ # current event may be part of the batch and will not be in the DB yet.
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ batched_auth_events = {}
+ for event_id in event.auth_event_ids():
+ auth_event = event_id_to_event.get(event_id)
+ if auth_event:
+ batched_auth_events[event_id] = auth_event
await self._event_auth_handler.check_auth_rules_from_context(
- event, context
+ event, batched_auth_events
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
@@ -1390,7 +1430,7 @@ class EventCreationHandler:
extra_users=extra_users,
),
run_in_background(
- self.cache_joined_hosts_for_event, event, context
+ self.cache_joined_hosts_for_events, events_and_context
).addErrback(
log_failure, "cache_joined_hosts_for_event failed"
),
@@ -1425,17 +1465,9 @@ class EventCreationHandler:
a room that has been un-partial stated.
"""
- for event, context in events_and_context:
- # Skip push notification actions for historical messages
- # because we don't want to notify people about old history back in time.
- # The historical messages also do not have the proper `context.current_state_ids`
- # and `state_groups` because they have `prev_events` that aren't persisted yet
- # (historical messages persisted in reverse-chronological order).
- if not event.internal_metadata.is_historical():
- with opentracing.start_active_span("calculate_push_actions"):
- await self._bulk_push_rule_evaluator.action_for_event_by_user(
- event, context
- )
+ await self._bulk_push_rule_evaluator.action_for_events_by_user(
+ events_and_context
+ )
try:
# If we're a worker we need to hit out to the master.
@@ -1491,62 +1523,65 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
- async def cache_joined_hosts_for_event(
- self, event: EventBase, context: EventContext
+ async def cache_joined_hosts_for_events(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
- """Precalculate the joined hosts at the event, when using Redis, so that
+ """Precalculate the joined hosts at each of the given events, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
- if not self._external_cache.is_enabled():
- return
-
- # If external cache is enabled we should always have this.
- assert self._external_cache_joined_hosts_updates is not None
+ for event, _ in events_and_context:
+ if not self._external_cache.is_enabled():
+ return
- # We actually store two mappings, event ID -> prev state group,
- # state group -> joined hosts, which is much more space efficient
- # than event ID -> joined hosts.
- #
- # Note: We have to cache event ID -> prev state group, as we don't
- # store that in the DB.
- #
- # Note: We set the state group -> joined hosts cache if it hasn't been
- # set for a while, so that the expiry time is reset.
+ # If external cache is enabled we should always have this.
+ assert self._external_cache_joined_hosts_updates is not None
- state_entry = await self.state.resolve_state_groups_for_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We set the state group -> joined hosts cache if it hasn't been
+ # set for a while, so that the expiry time is reset.
- if state_entry.state_group:
- await self._external_cache.set(
- "event_to_prev_state_group",
- event.event_id,
- state_entry.state_group,
- expiry_ms=60 * 60 * 1000,
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
- if state_entry.state_group in self._external_cache_joined_hosts_updates:
- return
+ if state_entry.state_group:
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
- state = await state_entry.get_state(
- self._storage_controllers.state, StateFilter.all()
- )
- with opentracing.start_active_span("get_joined_hosts"):
- joined_hosts = await self.store.get_joined_hosts(
- event.room_id, state, state_entry
+ if state_entry.state_group in self._external_cache_joined_hosts_updates:
+ continue
+
+ state = await state_entry.get_state(
+ self._storage_controllers.state, StateFilter.all()
)
+ with opentracing.start_active_span("get_joined_hosts"):
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
- # Note that the expiry times must be larger than the expiry time in
- # _external_cache_joined_hosts_updates.
- await self._external_cache.set(
- "get_joined_hosts",
- str(state_entry.state_group),
- list(joined_hosts),
- expiry_ms=60 * 60 * 1000,
- )
+ # Note that the expiry times must be larger than the expiry time in
+ # _external_cache_joined_hosts_updates.
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
- self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+ self._external_cache_joined_hosts_updates[
+ state_entry.state_group
+ ] = None
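A sketch of the two-level mapping the comments above describe, written against a generic async cache; the namespace strings mirror those used in the diff:

    async def lookup_joined_hosts_sketch(cache, event_id: str):
        # Level 1: event ID -> prev state group. This mapping lives only in
        # the external cache, not in the database.
        state_group = await cache.get("event_to_prev_state_group", event_id)
        if state_group is None:
            return None
        # Level 2: state group -> joined hosts. Many events share a state
        # group, which is what makes this far smaller than caching
        # event ID -> joined hosts directly.
        return await cache.get("get_joined_hosts", str(state_group))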
async def _validate_canonical_alias(
self,
@@ -1705,12 +1740,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
- event.unsigned[
- "invite_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
- membership_user_id=event.sender,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "invite_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ membership_user_id=event.sender,
+ ),
)
invitee = UserID.from_string(event.state_key)
@@ -1728,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
- event.unsigned[
- "knock_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "knock_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ ),
)
if event.type == EventTypes.Redaction:
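A hedged sketch of the contract `maybe_upsert_event_field` appears to provide here (the helper itself is not part of this diff, and the size predicate is an assumption): set the field, but back out if doing so would push the event over the size limits, so oversized stripped state is dropped rather than making the invite or knock unsendable.

    def maybe_upsert_event_field_sketch(event, container: dict, key: str, value) -> bool:
        sentinel = object()
        original = container.get(key, sentinel)
        container[key] = value
        if event_exceeds_size_limits(event):  # assumed size predicate
            # Roll back: better to lose the optional field than the event.
            if original is sentinel:
                container.pop(key, None)
            else:
                container[key] = original
            return False
        return True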
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..03de6a4ba6 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import binascii
import inspect
+import json
import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+)
from urllib.parse import urlencode, urlparse
import attr
+import unpaddedbase64
from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
+from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
+C = TypeVar("C")
+
#: A JWK Set, as per RFC7517 sec 5.
class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
await oidc_provider.handle_oidc_callback(request, session_data, code)
+ async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+ This extracts the logout_token from the request and tries to figure out
+ which OpenID Provider it is coming from. This works by matching the iss claim
+ with the issuer and the aud claim with the client_id.
+
+ Since at this point we don't know who signed the JWT, we can't just
+ decode it using authlib, since that always verifies the signature. We
+ have to decode it manually without validating the signature. The actual JWT
+ verification is done in the `OidcProvider.handle_backchannel_logout` method,
+ once we have figured out which provider sent the request.
+
+ Args:
+ request: the incoming request from the browser.
+ """
+ logout_token = parse_string(request, "logout_token")
+ if logout_token is None:
+ raise SynapseError(400, "Missing logout_token in request")
+
+ # A JWT looks like this:
+ # header.payload.signature
+ # where all parts are encoded with urlsafe base64.
+ # The aud and iss claims we care about are in the payload part, which
+ # is a JSON object.
+ try:
+ # By destructuring the list after splitting, we ensure that we have
+ # exactly 3 segments
+ _, payload, _ = logout_token.split(".")
+ except ValueError:
+ raise SynapseError(400, "Invalid logout_token in request")
+
+ try:
+ payload_bytes = unpaddedbase64.decode_base64(payload)
+ claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+ except (json.JSONDecodeError, binascii.Error, UnicodeError):
+ raise SynapseError(400, "Invalid logout_token payload in request")
+
+ try:
+ # Let's extract the iss and aud claims
+ iss = claims["iss"]
+ aud = claims["aud"]
+ # The aud claim can be either a string or a list of strings. Here we
+ # normalize it as a list of strings.
+ if isinstance(aud, str):
+ aud = [aud]
+
+ # Check that we have the right types for the aud and the iss claims
+ if not isinstance(iss, str) or not isinstance(aud, list):
+ raise TypeError()
+ for a in aud:
+ if not isinstance(a, str):
+ raise TypeError()
+
+ # At this point we have properly checked the types of both claims
+ issuer: str = iss
+ audience: List[str] = aud
+ except (TypeError, KeyError):
+ raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+ # Now that we know the audience and the issuer, we can figure out from
+ # what provider it is coming from
+ oidc_provider: Optional[OidcProvider] = None
+ for provider in self._providers.values():
+ if provider.issuer == issuer and provider.client_id in audience:
+ oidc_provider = provider
+ break
+
+ if oidc_provider is None:
+ raise SynapseError(400, "Could not find the OP that issued this event")
+
+ # Ask the provider to handle the logout request.
+ await oidc_provider.handle_backchannel_logout(request, logout_token)
+
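The manual, signature-free decode described in the docstring above, as a self-contained sketch using only the standard library (the handler itself uses `unpaddedbase64` and Synapse's JSON decoder instead):

    import base64
    import json

    def peek_jwt_claims(token: str) -> dict:
        """Return a JWT's payload claims WITHOUT verifying its signature."""
        _, payload, _ = token.split(".")  # ValueError unless exactly 3 segments
        padded = payload + "=" * (-len(payload) % 4)  # re-pad urlsafe base64
        return json.loads(base64.urlsafe_b64decode(padded))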
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
@@ -275,6 +368,7 @@ class OidcProvider:
provider: OidcProviderConfig,
):
self._store = hs.get_datastores().main
+ self._clock = hs.get_clock()
self._macaroon_generaton = macaroon_generator
@@ -341,6 +435,7 @@ class OidcProvider:
self.idp_brand = provider.idp_brand
self._sso_handler = hs.get_sso_handler()
+ self._device_handler = hs.get_device_handler()
self._sso_handler.register_identity_provider(self)
@@ -399,6 +494,41 @@ class OidcProvider:
# If we're not using userinfo, we need a valid jwks to validate the ID token
m.validate_jwks_uri()
+ if self._config.backchannel_logout_enabled:
+ if not m.get("backchannel_logout_supported", False):
+ logger.warning(
+ "OIDC Back-Channel Logout is enabled for issuer %r"
+ "but it does not advertise support for it",
+ self.issuer,
+ )
+
+ elif not m.get("backchannel_logout_session_supported", False):
+ logger.warning(
+ "OIDC Back-Channel Logout is enabled and supported "
+ "by issuer %r but it might not send a session ID with "
+ "logout tokens, which is required for the logouts to work",
+ self.issuer,
+ )
+
+ if not self._config.backchannel_logout_ignore_sub:
+ # If OIDC backchannel logouts are enabled, the user mapping provider
+ # should use the `sub` claim. We verify that by mapping a dummy user and
+ # checking that we get the sub claim back.
+ user = UserInfo({"sub": "thisisasubject"})
+ try:
+ subject = self._user_mapping_provider.get_remote_user_id(user)
+ if subject != user["sub"]:
+ raise ValueError("Unexpected subject")
+ except Exception:
+ logger.warning(
+ f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+ "but it looks like the configured `user_mapping_provider` "
+ "does not use the `sub` claim as subject. If it is the case, "
+ "and you want Synapse to ignore the `sub` claim in OIDC "
+ "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+ "to `true` in the issuer config."
+ )
+
@property
def _uses_userinfo(self) -> bool:
"""Returns True if the ``userinfo_endpoint`` should be used.
@@ -414,6 +544,16 @@ class OidcProvider:
or self._user_profile_method == "userinfo_endpoint"
)
+ @property
+ def issuer(self) -> str:
+ """The issuer identifying this provider."""
+ return self._config.issuer
+
+ @property
+ def client_id(self) -> str:
+ """The client_id used when interacting with this provider."""
+ return self._config.client_id
+
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
"""Return the provider metadata.
@@ -647,7 +787,7 @@ class OidcProvider:
Must include an ``access_token`` field.
Returns:
- UserInfo: an object representing the user.
+ an object representing the user.
"""
logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
@@ -661,61 +801,99 @@ class OidcProvider:
return UserInfo(resp)
- async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
- """Return an instance of UserInfo from token's ``id_token``.
+ async def _verify_jwt(
+ self,
+ alg_values: List[str],
+ token: str,
+ claims_cls: Type[C],
+ claims_options: Optional[dict] = None,
+ claims_params: Optional[dict] = None,
+ ) -> C:
+ """Decode and validate a JWT, re-fetching the JWKS as needed.
Args:
- token: the token given by the ``token_endpoint``.
- Must include an ``id_token`` field.
- nonce: the nonce value originally sent in the initial authorization
- request. This value should match the one inside the token.
+ alg_values: list of `alg` values allowed when verifying the JWT.
+ token: the JWT.
+ claims_cls: the JWTClaims class to use to validate the claims.
+ claims_options: dict of options passed to the `claims_cls` constructor.
+ claims_params: dict of params passed to the `claims_cls` constructor.
Returns:
- The decoded claims in the ID token.
+ The decoded claims in the JWT.
"""
- metadata = await self.load_metadata()
- claims_params = {
- "nonce": nonce,
- "client_id": self._client_auth.client_id,
- }
- if "access_token" in token:
- # If we got an `access_token`, there should be an `at_hash` claim
- # in the `id_token` that we can check against.
- claims_params["access_token"] = token["access_token"]
-
- alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
- claim_options = {"iss": {"values": [metadata["issuer"]]}}
-
- id_token = token["id_token"]
- logger.debug("Attempting to decode JWT id_token %r", id_token)
+ logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
- id_token,
+ token,
key=jwk_set,
- claims_cls=CodeIDToken,
- claims_options=claim_options,
+ claims_cls=claims_cls,
+ claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
- id_token,
+ token,
key=jwk_set,
- claims_cls=CodeIDToken,
- claims_options=claim_options,
+ claims_cls=claims_cls,
+ claims_options=claims_options,
claims_params=claims_params,
)
- logger.debug("Decoded id_token JWT %r; validating", claims)
+ logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
- claims.validate(leeway=120) # allows 2 min of clock skew
+ claims.validate(
+ now=self._clock.time(), leeway=120
+ ) # allows 2 min of clock skew
+ return claims
+
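The retry shape of `_verify_jwt`, in isolation: decode against the cached JWKS first, and force one refresh on failure so that key rotation at the provider does not permanently break verification. A sketch, not the full method:

    async def decode_with_refresh(jwt, load_jwks, token: str, **decode_kwargs):
        try:
            return jwt.decode(token, key=await load_jwks(), **decode_kwargs)
        except ValueError:
            # Possibly stale keys: reload the JWKS once and retry.
            return jwt.decode(token, key=await load_jwks(force=True), **decode_kwargs)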
+ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
+ """Return an instance of UserInfo from token's ``id_token``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``id_token`` field.
+ nonce: the nonce value originally sent in the initial authorization
+ request. This value should match the one inside the token.
+
+ Returns:
+ The decoded claims in the ID token.
+ """
+ id_token = token.get("id_token")
+
+ # That has theoretically been checked by the caller, so even though
+ # assertions are not enabled in production, it is mainly here to appease mypy
+ assert id_token is not None
+
+ metadata = await self.load_metadata()
+
+ claims_params = {
+ "nonce": nonce,
+ "client_id": self._client_auth.client_id,
+ }
+ if "access_token" in token:
+ # If we got an `access_token`, there should be an `at_hash` claim
+ # in the `id_token` that we can check against.
+ claims_params["access_token"] = token["access_token"]
+
+ claims_options = {"iss": {"values": [metadata["issuer"]]}}
+
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+ claims = await self._verify_jwt(
+ alg_values=alg_values,
+ token=id_token,
+ claims_cls=CodeIDToken,
+ claims_options=claims_options,
+ claims_params=claims_params,
+ )
return claims
@@ -1036,6 +1214,146 @@ class OidcProvider:
# to be strings.
return str(remote_user_id)
+ async def handle_backchannel_logout(
+ self, request: SynapseRequest, logout_token: str
+ ) -> None:
+ """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+ The OIDC Provider posts a logout token to this endpoint when a user
+ session ends. That token is a JWT signed with the same keys as
+ ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+ validate the JWT and figure out what session to end.
+
+ Args:
+ request: The request to respond to
+ logout_token: The logout token (a JWT) extracted from the request body
+ """
+ # Back-Channel Logout can be disabled in the config, hence this check.
+ # This is not that important for now since Synapse is registered
+ # manually to the OP, so not specifying the backchannel-logout URI is
+ # as effective as disabling it here. It might make more sense if we
+ # support dynamic registration in Synapse at some point.
+ if not self._config.backchannel_logout_enabled:
+ logger.warning(
+ f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+ )
+
+ # TODO: this responds with a 400 status code, which is what the OIDC
+ # Back-Channel Logout spec expects, but spec also suggests answering with
+ # a JSON object, with the `error` and `error_description` fields set, which
+ # we are not doing here.
+ # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+ raise SynapseError(
+ 400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+ )
+
+ metadata = await self.load_metadata()
+
+ # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+ # A Logout Token MUST be signed and MAY also be encrypted. The same
+ # keys are used to sign and encrypt Logout Tokens as are used for ID
+ # Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+ # iss (issuer) claim in the JWT Header Parameters, as specified in
+ # Section 5.3 of [JWT].
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+ # As per sec. 2.6:
+ # 3. Validate the iss, aud, and iat Claims in the same way they are
+ # validated in ID Tokens.
+ # Which means the audience should contain Synapse's client_id and the
+ # issuer should be the IdP issuer
+ claims_options = {
+ "iss": {"values": [metadata["issuer"]]},
+ "aud": {"values": [self.client_id]},
+ }
+
+ try:
+ claims = await self._verify_jwt(
+ alg_values=alg_values,
+ token=logout_token,
+ claims_cls=LogoutToken,
+ claims_options=claims_options,
+ )
+ except JoseError:
+ logger.exception("Invalid logout_token")
+ raise SynapseError(400, "Invalid logout_token")
+
+ # As per sec. 2.6:
+ # 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+ # or both.
+ # 5. Verify that the Logout Token contains an events Claim whose
+ # value is JSON object containing the member name
+ # http://schemas.openid.net/event/backchannel-logout.
+ # 6. Verify that the Logout Token does not contain a nonce Claim.
+ # This is all verified by the LogoutToken claims class, so at this
+ # point the `sid` claim exists and is a string.
+ sid: str = claims.get("sid")
+
+ # If the `sub` claim was included in the logout token, we check that it
+ # matches the expected user. We can have cases where the `sub` claim is not
+ # the ID saved in database, so we let admins disable this check in config.
+ sub: Optional[str] = claims.get("sub")
+ expected_user_id: Optional[str] = None
+ if sub is not None and not self._config.backchannel_logout_ignore_sub:
+ expected_user_id = await self._store.get_user_by_external_id(
+ self.idp_id, sub
+ )
+
+ # Invalidate any running user-mapping sessions, in-flight login tokens and
+ # active devices
+ await self._sso_handler.revoke_sessions_for_provider_session_id(
+ auth_provider_id=self.idp_id,
+ auth_provider_session_id=sid,
+ expected_user_id=expected_user_id,
+ )
+
+ request.setResponseCode(200)
+ request.setHeader(b"Cache-Control", b"no-cache, no-store")
+ request.setHeader(b"Pragma", b"no-cache")
+ finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+ """
+ Holds and verifies the claims of a logout token, as per
+ https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+ """
+
+ REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+ def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+ """Validate everything in claims payload."""
+ super().validate(now, leeway)
+ self.validate_sid()
+ self.validate_events()
+ self.validate_nonce()
+
+ def validate_sid(self) -> None:
+ """Ensure the sid claim is present"""
+ sid = self.get("sid")
+ if not sid:
+ raise MissingClaimError("sid")
+
+ if not isinstance(sid, str):
+ raise InvalidClaimError("sid")
+
+ def validate_nonce(self) -> None:
+ """Ensure the nonce claim is absent"""
+ if "nonce" in self:
+ raise InvalidClaimError("nonce")
+
+ def validate_events(self) -> None:
+ """Ensure the events claim is present and with the right value"""
+ events = self.get("events")
+ if not events:
+ raise MissingClaimError("events")
+
+ if not isinstance(events, dict):
+ raise InvalidClaimError("events")
+
+ if "http://schemas.openid.net/event/backchannel-logout" not in events:
+ raise InvalidClaimError("events")
+
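A quick illustration, with hypothetical values, of what the `LogoutToken` validators accept; authlib's `JWTClaims` is dict-like and takes the payload and header:

    claims = LogoutToken(
        {
            "iss": "https://op.example.com",
            "aud": "synapse-client-id",
            "iat": 1_700_000_000,
            "jti": "logout-token-id",
            "sid": "provider-session-id",
            "events": {"http://schemas.openid.net/event/backchannel-logout": {}},
        },
        {"alg": "RS256"},
    )
    claims.validate(now=1_700_000_060, leeway=120)  # passes
    # A "nonce" claim, a missing "sid", or an "events" object without the
    # backchannel-logout member raises InvalidClaimError / MissingClaimError.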
# number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1105,6 +1423,7 @@ class JwtClientSecret:
logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
)
+ jwt = JsonWebToken(header["alg"])
self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1116,12 +1435,10 @@ class UserAttributeDict(TypedDict):
localpart: Optional[str]
confirm_localpart: bool
display_name: Optional[str]
+ picture: Optional[str] # may be omitted by older `OidcMappingProviders`
emails: List[str]
-C = TypeVar("C")
-
-
class OidcMappingProvider(Generic[C]):
"""A mapping provider maps a UserInfo object to user attributes.
@@ -1204,6 +1521,7 @@ env.filters.update(
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
subject_claim: str
+ picture_claim: str
localpart_template: Optional[Template]
display_name_template: Optional[Template]
email_template: Optional[Template]
@@ -1223,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@staticmethod
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
+ picture_claim = config.get("picture_claim", "picture")
def parse_template_config(option_name: str) -> Optional[Template]:
if option_name not in config:
@@ -1256,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
+ picture_claim=picture_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
@@ -1295,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if email:
emails.append(email)
+ picture = userinfo.get("picture")
+
return UserAttributeDict(
localpart=localpart,
display_name=display_name,
emails=emails,
+ picture=picture,
confirm_localpart=self._config.confirm_localpart,
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1f83bab836..8c8ff18a1a 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamKeyType
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -448,6 +448,12 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
+ elif pagin_config.direction == "f":
+ from_token = (
+ await self.hs.get_event_sources().get_start_token_for_pagination(
+ room_id
+ )
+ )
else:
from_token = (
await self.hs.get_event_sources().get_current_token_for_pagination(
@@ -458,11 +464,6 @@ class PaginationHandler:
# `/messages` should still work with live tokens when manually provided.
assert from_token.room_key.topological is not None
- if pagin_config.limit is None:
- # This shouldn't happen as we've set a default limit before this
- # gets called.
- raise Exception("limit not set")
-
room_token = from_token.room_key
async with self.pagination_lock.read(room_id):
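The token-selection rule added above, distilled: forward pagination without an explicit token starts at the beginning of the room's history, while backward pagination starts from the current live position.

    async def default_from_token(event_sources, room_id: str, direction: str):
        if direction == "f":
            return await event_sources.get_start_token_for_pagination(room_id)
        return await event_sources.get_current_token_for_pagination(room_id)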
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..2af90b25a3 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC):
"""Get the current presence state for multiple users.
Returns:
- dict: `user_id` -> `UserPresenceState`
+ A mapping of `user_id` -> `UserPresenceState`
"""
states = {}
missing = []
@@ -256,7 +256,7 @@ class BasePresenceHandler(abc.ABC):
with the app.
"""
- async def update_external_syncs_row(
+ async def update_external_syncs_row( # noqa: B027 (no-op by design)
self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
) -> None:
"""Update the syncing users for an external process as a delta.
@@ -272,7 +272,9 @@ class BasePresenceHandler(abc.ABC):
sync_time_msec: Time in ms when the user was last syncing
"""
- async def update_external_syncs_clear(self, process_id: str) -> None:
+ async def update_external_syncs_clear( # noqa: B027 (no-op by design)
+ self, process_id: str
+ ) -> None:
"""Marks all users that had been marked as syncing by a given process
as offline.
@@ -476,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
return _NullContextManager()
prev_state = await self.current_state_for_user(user_id)
- if prev_state != PresenceState.BUSY:
+ if prev_state.state != PresenceState.BUSY:
# We set state here but pass ignore_status_msg = True as we don't want to
# cause the status message to be cleared.
# Note that this causes last_active_ts to be incremented which is not
@@ -1596,7 +1598,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
self,
user: UserID,
from_key: Optional[int],
- limit: Optional[int] = None,
+ # Having a default limit doesn't match the EventSource API, but some
+ # callers do not provide it. It is unused in this class.
+ limit: int = 0,
room_ids: Optional[Collection[str]] = None,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
@@ -1688,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
if from_key is not None:
# First get all users that have had a presence update
- updated_users = stream_change_cache.get_all_entities_changed(from_key)
+ result = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
- if updated_users is not None:
+ if result.hit:
+ updated_users = result.entities
+
# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
@@ -1760,14 +1766,14 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
Returns:
A list of presence states for the given user to receive.
"""
+ updated_users = None
if from_key:
# Only return updates since the last sync
- updated_users = self.store.presence_stream_cache.get_all_entities_changed(
- from_key
- )
- if not updated_users:
- updated_users = []
+ result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+ if result.hit:
+ updated_users = result.entities
+ if updated_users is not None:
# Get the actual presence update for each change
users_to_state = await self.get_presence_handler().current_state_for_users(
updated_users
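Both call sites above move from an `Optional` return to a small result object; a hedged sketch of its apparent shape (the real class lives in the stream-change cache module, which is not part of this diff):

    from typing import Collection, Optional

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class AllEntitiesChangedResultSketch:
        hit: bool  # False means the cache could not answer for this key
        entities: Collection[str] = ()  # only meaningful when hit is True

    def users_to_consider(result) -> Optional[Collection[str]]:
        # A miss must fall back to the slow path; it does NOT mean "no changes".
        return result.entities if result.hit else None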
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index d8ff5289b5..4bf9a047a3 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -307,7 +307,11 @@ class ProfileHandler:
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
return True
- server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
+ host, port, media_id = parse_and_validate_mxc_uri(mxc)
+ if port is not None:
+ server_name = host + ":" + str(port)
+ else:
+ server_name = host
if server_name == self.server_name:
media_info = await self.store.get_local_media(media_id)
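A worked example of the host/port recombination above: `mxc://example.com:8448/abc` parses into three components, and the port must be folded back into the server name before comparing against this homeserver's own name.

    host, port, media_id = "example.com", 8448, "abc"  # as parsed from the URI
    server_name = f"{host}:{port}" if port is not None else host
    assert server_name == "example.com:8448"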
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4a7ec9e426..6a4fed1156 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -92,7 +92,6 @@ class ReceiptsHandler:
continue
# Check if these receipts apply to a thread.
- thread_id = None
data = user_values.get("data", {})
thread_id = data.get("thread_id")
# If the thread ID is invalid, consider it missing.
@@ -257,7 +256,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..c611efb760 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -45,8 +46,8 @@ from synapse.replication.http.register import (
ReplicationRegisterServlet,
)
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
+from synapse.types.state import StateFilter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -841,6 +842,9 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None
+ # This can only run on the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 63bc6a7aa5..e96f9999a8 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,17 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict, Requester, UserID
+from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -31,6 +35,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -66,19 +77,17 @@ class RelationsHandler:
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._event_creation_handler = hs.get_event_creation_handler()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
+ pagin_config: PaginationConfig,
+ include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[StreamToken] = None,
- to_token: Optional[StreamToken] = None,
- include_original_event: bool = False,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
@@ -88,14 +97,10 @@ class RelationsHandler:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
+ pagin_config: The pagination config rules to apply, if any.
+ include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
- include_original_event: Whether to include the parent event.
Returns:
The pagination chunk.
@@ -123,10 +128,10 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- limit=limit,
- direction=direction,
- from_token=from_token,
- to_token=to_token,
+ limit=pagin_config.limit,
+ direction=pagin_config.direction,
+ from_token=pagin_config.from_token,
+ to_token=pagin_config.to_token,
)
events = await self._main_store.get_events_as_list(
@@ -162,90 +167,167 @@ class RelationsHandler:
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
- if from_token:
- return_value["prev_batch"] = await from_token.to_string(self._main_store)
+ if pagin_config.from_token:
+ return_value["prev_batch"] = await pagin_config.from_token.to_string(
+ self._main_store
+ )
return return_value
- async def get_relations_for_event(
+ async def redact_events_related_to(
self,
+ requester: Requester,
event_id: str,
- event: EventBase,
- room_id: str,
- relation_type: str,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
- """Get a list of events which relate to an event, ordered by topological ordering.
+ initial_redaction_event: EventBase,
+ relation_types: List[str],
+ ) -> None:
+ """Redacts all events related to the given event ID with one of the given
+ relation types.
- Args:
- event_id: Fetch events that relate to this event ID.
- event: The matching EventBase to event_id.
- room_id: The room the event belongs to.
- relation_type: The type of relation.
- ignored_users: The users ignored by the requesting user.
+ This method is expected to be called when redacting the event referred to by
+ the given event ID.
- Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
- """
+ If an event cannot be redacted (e.g. because of insufficient permissions), log
+ the error and try to redact the next one.
- # Call the underlying storage method, which is cached.
- related_events, next_token = await self._main_store.get_relations_for_event(
- event_id, event, room_id, relation_type, direction="f"
+ Args:
+ requester: The requester to redact events on behalf of.
+ event_id: The event ID whose relations to look up and redact.
+ initial_redaction_event: The redaction for the event referred to by
+ event_id.
+ relation_types: The types of relations to look for.
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned
+ """
+ related_event_ids = (
+ await self._main_store.get_all_relations_for_event_with_types(
+ event_id, relation_types
+ )
)
- # Filter out ignored users and convert to the expected format.
- related_events = [
- event for event in related_events if event.sender not in ignored_users
- ]
-
- return related_events, next_token
+ for related_event_id in related_event_ids:
+ try:
+ await self._event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": initial_redaction_event.content,
+ "room_id": initial_redaction_event.room_id,
+ "sender": requester.user.to_string(),
+ "redacts": related_event_id,
+ },
+ ratelimit=False,
+ )
+ except SynapseError as e:
+ logger.warning(
+ "Failed to redact event %s (related to event %s): %s",
+ related_event_id,
+ event_id,
+ e.msg,
+ )
- async def get_annotations_for_event(
- self,
- event_id: str,
- room_id: str,
- limit: int = 5,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> List[JsonDict]:
- """Get a list of annotations on the event, grouped by event type and
+ async def get_annotations_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[JsonDict]]:
+ """Get a list of annotations to the given events, grouped by event type and
aggregation key, sorted by count.
- This is used e.g. to get the what and how many reactions have happend
+ This is used e.g. to get which reactions, and how many of each, have happened
on an event.
Args:
- event_id: Fetch events that relate to this event ID.
- room_id: The room the event belongs to.
- limit: Only fetch the `limit` groups.
+ event_ids: Fetch events that relate to these event IDs.
ignored_users: The users ignored by the requesting user.
Returns:
- List of groups of annotations that match. Each row is a dict with
- `type`, `key` and `count` fields.
+ A map of event IDs to a list of groups of annotations that match.
+ Each entry is a dict with `type`, `key` and `count` fields.
"""
# Get the base results for all users.
- full_results = await self._main_store.get_aggregation_groups_for_event(
- event_id, room_id, limit
+ full_results = await self._main_store.get_aggregation_groups_for_events(
+ event_ids
)
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in full_results.items()
+ if results
+ }
+
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
- event_id, room_id, limit, ignored_users
+ [event_id for event_id, results in full_results.items() if results],
+ ignored_users,
)
- filtered_results = []
- for result in full_results:
- key = (result["type"], result["key"])
- if key in ignored_results:
- result = result.copy()
- result["count"] -= ignored_results[key]
- if result["count"] <= 0:
- continue
- filtered_results.append(result)
+ filtered_results = {}
+ for event_id, results in full_results.items():
+ # If no annotations, skip.
+ if not results:
+ continue
+
+ # If there are no ignored results for this event, copy verbatim.
+ if event_id not in ignored_results:
+ filtered_results[event_id] = results
+ continue
+
+ # Otherwise, subtract out the ignored results.
+ event_ignored_results = ignored_results[event_id]
+ for result in results:
+ key = (result["type"], result["key"])
+ if key in event_ignored_results:
+ # Ensure to not modify the cache.
+ result = result.copy()
+ result["count"] -= event_ignored_results[key]
+ if result["count"] <= 0:
+ continue
+ filtered_results.setdefault(event_id, []).append(result)
return filtered_results
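A worked example of the subtraction above, with hypothetical counts: an annotation group loses the ignored users' contributions and disappears entirely if its count reaches zero.

    full = [
        {"type": "m.reaction", "key": "👍", "count": 3},
        {"type": "m.reaction", "key": "👎", "count": 1},
    ]
    ignored = {("m.reaction", "👍"): 1, ("m.reaction", "👎"): 1}
    filtered = []
    for result in full:
        key = (result["type"], result["key"])
        if key in ignored:
            result = dict(result, count=result["count"] - ignored[key])
            if result["count"] <= 0:
                continue  # the 👎 group drops out entirely
        filtered.append(result)
    assert filtered == [{"type": "m.reaction", "key": "👍", "count": 2}]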
+ async def get_references_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[_RelatedEvent]]:
+ """Get a list of references to the given events.
+
+ Args:
+ event_ids: Fetch events that relate to these event IDs.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ A map of event IDs to a list of related events.
+ """
+
+ related_events = await self._main_store.get_references_for_events(event_ids)
+
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in related_events.items()
+ if results
+ }
+
+ # Filter out ignored users.
+ results = {}
+ for event_id, events in related_events.items():
+ # If no references, skip.
+ if not events:
+ continue
+
+ # Filter ignored users out.
+ events = [event for event in events if event.sender not in ignored_users]
+ # If there are no events left, skip this event.
+ if not events:
+ continue
+
+ results[event_id] = events
+
+ return results
+
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
@@ -308,59 +390,66 @@ class RelationsHandler:
results = {}
for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event = summary
-
- # Subtract off the count of any ignored users.
- for ignored_user in ignored_users:
- thread_count -= ignored_results.get((event_id, ignored_user), 0)
-
- # This is gnarly, but if the latest event is from an ignored user,
- # attempt to find one that isn't from an ignored user.
- if latest_thread_event.sender in ignored_users:
- room_id = latest_thread_event.room_id
-
- # If the root event is not found, something went wrong, do
- # not include a summary of the thread.
- event = await self._event_handler.get_event(user, room_id, event_id)
- if event is None:
- continue
+ # If no thread, skip.
+ if not summary:
+ continue
- potential_events, _ = await self.get_relations_for_event(
- event_id,
- event,
- room_id,
- RelationTypes.THREAD,
- ignored_users,
- )
+ thread_count, latest_thread_event = summary
- # If all found events are from ignored users, do not include
- # a summary of the thread.
- if not potential_events:
- continue
+ # Subtract off the count of any ignored users.
+ for ignored_user in ignored_users:
+ thread_count -= ignored_results.get((event_id, ignored_user), 0)
- # The *last* event returned is the one that is cared about.
- event = await self._event_handler.get_event(
- user, room_id, potential_events[-1].event_id
- )
- # It is unexpected that the event will not exist.
- if event is None:
- logger.warning(
- "Unable to fetch latest event in a thread with event ID: %s",
- potential_events[-1].event_id,
- )
- continue
- latest_thread_event = event
-
- results[event_id] = _ThreadAggregation(
- latest_event=latest_thread_event,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=events_by_id[event_id].sender == user_id
- or participated[event_id],
+ # This is gnarly, but if the latest event is from an ignored user,
+ # attempt to find one that isn't from an ignored user.
+ if latest_thread_event.sender in ignored_users:
+ room_id = latest_thread_event.room_id
+
+ # If the root event is not found, something went wrong, do
+ # not include a summary of the thread.
+ event = await self._event_handler.get_event(user, room_id, event_id)
+ if event is None:
+ continue
+
+ # Attempt to find another event to use as the latest event.
+ potential_events, _ = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.THREAD, direction="f"
)
+ # Filter out ignored users.
+ potential_events = [
+ event
+ for event in potential_events
+ if event.sender not in ignored_users
+ ]
+
+ # If all found events are from ignored users, do not include
+ # a summary of the thread.
+ if not potential_events:
+ continue
+
+ # The *last* event returned is the one that is cared about.
+ event = await self._event_handler.get_event(
+ user, room_id, potential_events[-1].event_id
+ )
+ # It is unexpected that the event will not exist.
+ if event is None:
+ logger.warning(
+ "Unable to fetch latest event in a thread with event ID: %s",
+ potential_events[-1].event_id,
+ )
+ continue
+ latest_thread_event = event
+
+ results[event_id] = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
+ )
+
return results
@trace
@@ -438,48 +527,131 @@ class RelationsHandler:
# (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
- # Fetch other relations per event.
- for event in events_by_id.values():
- # Fetch any annotations (ie, reactions) to bundle with this event.
- annotations = await self.get_annotations_for_event(
- event.event_id, event.room_id, ignored_users=ignored_users
+ async def _fetch_annotations() -> None:
+ """Fetch any annotations (ie, reactions) to bundle with this event."""
+ annotations_by_event_id = await self.get_annotations_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
)
- if annotations:
- results.setdefault(
- event.event_id, BundledAggregations()
- ).annotations = {"chunk": annotations}
-
- # Fetch any references to bundle with this event.
- references, next_token = await self.get_relations_for_event(
- event.event_id,
- event,
- event.room_id,
- RelationTypes.REFERENCE,
- ignored_users=ignored_users,
+ for event_id, annotations in annotations_by_event_id.items():
+ if annotations:
+ results.setdefault(event_id, BundledAggregations()).annotations = {
+ "chunk": annotations
+ }
+
+ async def _fetch_references() -> None:
+ """Fetch any references to bundle with this event."""
+ references_by_event_id = await self.get_references_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
+ )
+ for event_id, references in references_by_event_id.items():
+ if references:
+ results.setdefault(event_id, BundledAggregations()).references = {
+ "chunk": [{"event_id": ev.event_id} for ev in references]
+ }
+
+ async def _fetch_edits() -> None:
+ """
+ Fetch any edits (but not for redacted events).
+
+ Note that there is no use in limiting edits by ignored users since the
+ parent event should be ignored in the first place if the user is ignored.
+ """
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Parallelize the calls for annotations, references, and edits since they
+ # are unrelated.
+ await make_deferred_yieldable(
+ gather_results(
+ (
+ run_in_background(_fetch_annotations),
+ run_in_background(_fetch_references),
+ run_in_background(_fetch_edits),
+ )
)
- if references:
- aggregations = results.setdefault(event.event_id, BundledAggregations())
- aggregations.references = {
- "chunk": [{"event_id": ev.event_id} for ev in references]
- }
-
- if next_token:
- aggregations.references["next_batch"] = await next_token.to_string(
- self._main_store
- )
-
- # Fetch any edits (but not for redacted events).
- #
- # Note that there is no use in limiting edits by ignored users since the
- # parent event should be ignored in the first place if the user is ignored.
- edits = await self._main_store.get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
)
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ Args:
+ requester: The user requesting the relations.
+ room_id: The room the event belongs to.
+ include: One of "all" or "participated" to indicate which threads should
+ be returned.
+ limit: Only fetch the most recent `limit` events.
+ from_token: Fetch rows from the given token, or from the start if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+ # Note that ignored users are not passed into get_threads
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_batch = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token
+ )
+
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+
+ now = self._clock.time_msec()
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_batch:
+ return_value["next_batch"] = str(next_batch)
+
+ return return_value
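The participation filter above, as a standalone sketch (the store method's signature is taken from this diff): thread roots the requester sent are pre-seeded as participated, so only the remainder needs a database check.

    async def filter_to_participated(store, events, user_id: str):
        participated = {ev.event_id: ev.sender == user_id for ev in events}
        participated.update(
            await store.get_threads_participated(
                [eid for eid, sent in participated.items() if not sent], user_id
            )
        )
        return [ev for ev in events if participated[ev.event_id]]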
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 57ab05ad25..f81241c2b3 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -49,7 +49,6 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
- HttpResponseException,
LimitExceededError,
NotFoundError,
StoreError,
@@ -60,11 +59,9 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_and_fixup_power_levels_contents
-from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -79,6 +76,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import stringutils
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
@@ -229,9 +227,7 @@ class RoomCreationHandler:
},
)
validate_event_for_room_version(tombstone_event)
- await self._event_auth_handler.check_auth_rules_from_context(
- tombstone_event, tombstone_context
- )
+ await self._event_auth_handler.check_auth_rules_from_context(tombstone_event)
# Upgrade the room
#
@@ -561,7 +557,6 @@ class RoomCreationHandler:
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
- ratelimit=False,
)
# Transfer membership events
@@ -755,6 +750,10 @@ class RoomCreationHandler:
)
if ratelimit:
+ # Rate limit once in advance, but don't rate limit the individual
+ # events in the room — room creation isn't atomic and it's very
+ # janky if half the events in the initial state don't make it because
+ # of rate limiting.
await self.request_ratelimiter.ratelimit(requester)
room_version_id = config.get(
@@ -915,7 +914,6 @@ class RoomCreationHandler:
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
- ratelimit=ratelimit,
)
if "name" in config:
@@ -1039,7 +1037,6 @@ class RoomCreationHandler:
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
- ratelimit: bool = True,
) -> Tuple[int, str, int]:
"""Sends the initial events into a new room. Sends the room creation, membership,
and power level events into the room sequentially, then creates and batches up the
@@ -1048,6 +1045,8 @@ class RoomCreationHandler:
`power_level_content_override` doesn't apply when initial state has
power level state event content.
+ Rate limiting should already have been applied by this point.
+
Returns:
A tuple containing the stream ID, event ID and depth of the last
event sent to the room.
@@ -1057,9 +1056,6 @@ class RoomCreationHandler:
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
- # the last event sent/persisted to the db
- last_sent_event_id: Optional[str] = None
-
# the most recently created event
prev_event: List[str] = []
# a map of event types, state keys -> event_ids. We collect these mappings this as events are
@@ -1084,6 +1080,19 @@ class RoomCreationHandler:
for_batch: bool,
**kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+ """
+ Creates an event and associated event context.
+
+            Args:
+ etype: the type of event to be created
+ content: content of the event
+            for_batch: whether the event is being created for batch persisting. If
+                true, the event is created using prev_event_ids and its context is
+                built from the state_map and current_state_group parameters, both
+                of which must therefore be provided. The resulting event and
+                context are suitable for being batched up and bulk persisted to
+                the database with other events created the same way.
+ """
nonlocal depth
nonlocal prev_event
@@ -1104,26 +1113,6 @@ class RoomCreationHandler:
return new_event, new_context
- async def send(
- event: EventBase,
- context: synapse.events.snapshot.EventContext,
- creator: Requester,
- ) -> int:
- nonlocal last_sent_event_id
-
- ev = await self.event_creation_handler.handle_new_client_event(
- requester=creator,
- events_and_context=[(event, context)],
- ratelimit=False,
- ignore_shadow_ban=True,
- )
-
- last_sent_event_id = ev.event_id
-
- # we know it was persisted, so must have a stream ordering
- assert ev.internal_metadata.stream_ordering
- return ev.internal_metadata.stream_ordering
-
try:
config = self._presets_dict[preset_config]
except KeyError:
@@ -1137,16 +1126,20 @@ class RoomCreationHandler:
)
logger.debug("Sending %s in new room", EventTypes.Member)
- await send(creation_event, creation_context, creator)
+ ev = await self.event_creation_handler.handle_new_client_event(
+ requester=creator,
+ events_and_context=[(creation_event, creation_context)],
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+ last_sent_event_id = ev.event_id
- # Room create event must exist at this point
- assert last_sent_event_id is not None
member_event_id, _ = await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
"join",
- ratelimit=ratelimit,
+ ratelimit=False,
content=creator_join_profile,
new_room=True,
prev_event_ids=[last_sent_event_id],
@@ -1159,15 +1152,24 @@ class RoomCreationHandler:
depth += 1
state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
+ # we need the state group of the membership event as it is the current state group
+ event_to_state = (
+ await self._storage_controllers.state.get_state_group_for_events(
+ [member_event_id]
+ )
+ )
+ current_state_group = event_to_state[member_event_id]
+
+ events_to_send = []
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
power_event, power_context = await create_event(
- EventTypes.PowerLevels, pl_content, False
+ EventTypes.PowerLevels, pl_content, True
)
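+            # Track the state group after this event: later batch-created events
+            # build their contexts on top of it.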
current_state_group = power_context._state_group
- await send(power_event, power_context, creator)
+ events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
"users": {creator_id: 100},
@@ -1213,12 +1215,11 @@ class RoomCreationHandler:
pl_event, pl_context = await create_event(
EventTypes.PowerLevels,
power_level_content,
- False,
+ True,
)
current_state_group = pl_context._state_group
- await send(pl_event, pl_context, creator)
+ events_to_send.append((pl_event, pl_context))
- events_to_send = []
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
@@ -1271,7 +1272,10 @@ class RoomCreationHandler:
events_to_send.append((encryption_event, encryption_context))
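+        # Persist all of the accumulated state events in a single batch.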
last_event = await self.event_creation_handler.handle_new_client_event(
- creator, events_to_send, ignore_shadow_ban=True
+ creator,
+ events_to_send,
+ ignore_shadow_ban=True,
+ ratelimit=False,
)
assert last_event.internal_metadata.stream_ordering is not None
return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
@@ -1447,7 +1451,7 @@ class RoomContextHandler:
events_before=events_before,
event=event,
events_after=events_after,
- state=await filter_evts(state_events),
+ state=state_events,
aggregations=aggregations,
start=await token.copy_and_replace(
StreamKeyType.ROOM, results.start
@@ -1493,7 +1497,12 @@ class TimestampLookupHandler:
Raises:
SynapseError if unable to find any event locally in the given direction
"""
-
+ logger.debug(
+ "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...",
+ room_id,
+ timestamp,
+ direction,
+ )
local_event_id = await self.store.get_event_id_for_timestamp(
room_id, timestamp, direction
)
@@ -1545,85 +1554,54 @@ class TimestampLookupHandler:
)
)
- # Loop through each homeserver candidate until we get a succesful response
- for domain in likely_domains:
- # We don't want to ask our own server for information we don't have
- if domain == self.server_name:
- continue
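+        # `timestamp_to_event` tries each of the candidate domains in turn and
+        # returns the first usable response, or `None` if none of them had one.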
+ remote_response = await self.federation_client.timestamp_to_event(
+ destinations=likely_domains,
+ room_id=room_id,
+ timestamp=timestamp,
+ direction=direction,
+ )
+ if remote_response is not None:
+ logger.debug(
+ "get_event_for_timestamp: remote_response=%s",
+ remote_response,
+ )
- try:
- remote_response = await self.federation_client.timestamp_to_event(
- domain, room_id, timestamp, direction
- )
- logger.debug(
- "get_event_for_timestamp: response from domain(%s)=%s",
- domain,
- remote_response,
- )
+ remote_event_id = remote_response.event_id
+ remote_origin_server_ts = remote_response.origin_server_ts
- remote_event_id = remote_response.event_id
- remote_origin_server_ts = remote_response.origin_server_ts
-
- # Backfill this event so we can get a pagination token for
- # it with `/context` and paginate `/messages` from this
- # point.
- #
- # TODO: The requested timestamp may lie in a part of the
- # event graph that the remote server *also* didn't have,
- # in which case they will have returned another event
- # which may be nowhere near the requested timestamp. In
- # the future, we may need to reconcile that gap and ask
- # other homeservers, and/or extend `/timestamp_to_event`
- # to return events on *both* sides of the timestamp to
- # help reconcile the gap faster.
- remote_event = (
- await self.federation_event_handler.backfill_event_id(
- domain, room_id, remote_event_id
- )
- )
+ # Backfill this event so we can get a pagination token for
+ # it with `/context` and paginate `/messages` from this
+ # point.
+ pulled_pdu_info = await self.federation_event_handler.backfill_event_id(
+ likely_domains, room_id, remote_event_id
+ )
+ remote_event = pulled_pdu_info.pdu
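+            # `pulled_pdu_info.pull_origin` records which server the event was
+            # actually pulled from; it may differ from the server that answered
+            # `/timestamp_to_event`.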
- # XXX: When we see that the remote server is not trustworthy,
- # maybe we should not ask them first in the future.
- if remote_origin_server_ts != remote_event.origin_server_ts:
- logger.info(
- "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
- domain,
- remote_event_id,
- remote_origin_server_ts,
- remote_event.origin_server_ts,
- )
-
- # Only return the remote event if it's closer than the local event
- if not local_event or (
- abs(remote_event.origin_server_ts - timestamp)
- < abs(local_event.origin_server_ts - timestamp)
- ):
- logger.info(
- "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
- remote_event_id,
- remote_event.origin_server_ts,
- timestamp,
- local_event.event_id if local_event else None,
- local_event.origin_server_ts if local_event else None,
- )
- return remote_event_id, remote_origin_server_ts
- except (HttpResponseException, InvalidResponseError) as ex:
- # Let's not put a high priority on some other homeserver
- # failing to respond or giving a random response
- logger.debug(
- "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
- domain,
- type(ex).__name__,
- ex,
- ex.args,
+ # XXX: When we see that the remote server is not trustworthy,
+ # maybe we should not ask them first in the future.
+ if remote_origin_server_ts != remote_event.origin_server_ts:
+ logger.info(
+ "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+ pulled_pdu_info.pull_origin,
+ remote_event_id,
+ remote_origin_server_ts,
+ remote_event.origin_server_ts,
)
- except Exception:
- # But we do want to see some exceptions in our code
- logger.warning(
- "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
- domain,
- exc_info=True,
+
+ # Only return the remote event if it's closer than the local event
+ if not local_event or (
+ abs(remote_event.origin_server_ts - timestamp)
+ < abs(local_event.origin_server_ts - timestamp)
+ ):
+ logger.info(
+ "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+ remote_event_id,
+ remote_event.origin_server_ts,
+ timestamp,
+ local_event.event_id if local_event else None,
+ local_event.origin_server_ts if local_event else None,
)
+ return remote_event_id, remote_origin_server_ts
# To appease mypy, we have to add both of these conditions to check for
# `None`. We only expect `local_event` to be `None` when
@@ -1646,7 +1624,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
self,
user: UserID,
from_key: RoomStreamToken,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6ad2b38b8f..0c39e852a1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
-from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
@@ -45,6 +44,7 @@ from synapse.types import (
create_requester,
get_domain_from_id,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 9602f0d0bb..874860d461 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -441,7 +441,7 @@ class DefaultSamlMappingProvider:
client_redirect_url: where the client wants to redirect to
Returns:
- dict: A dict containing new user attributes. Possible keys:
+ A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
* emails (list[str]): Any emails for the user
@@ -483,7 +483,7 @@ class DefaultSamlMappingProvider:
Args:
config: A dictionary containing configuration options for this provider
Returns:
- SamlConfig: A custom config object for this module
+ A custom config object for this module
"""
# Parse config options and use defaults where necessary
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index bcab98c6d5..33115ce488 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
-from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.types import Requester
if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ # This can only be instantiated on the main process.
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
async def set_password(
self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e035677b8a..44e70fc4b8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import hashlib
+import io
import logging
from typing import (
TYPE_CHECKING,
@@ -37,6 +39,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -137,6 +140,7 @@ class UserAttributes:
localpart: Optional[str]
confirm_localpart: bool = False
display_name: Optional[str] = None
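+    # If set, an HTTPS URL of an avatar image supplied by the SSO provider.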
+ picture: Optional[str] = None
emails: Collection[str] = attr.Factory(list)
@@ -191,9 +195,14 @@ class SsoHandler:
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
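+        # Used to download and store SSO avatar images. The media repository is
+        # only available if this process is configured to run one.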
+ self._media_repo = (
+ hs.get_media_repository() if hs.config.media.can_load_media_repo else None
+ )
+ self._http_client = hs.get_proxied_blacklisted_http_client()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
@@ -493,6 +502,8 @@ class SsoHandler:
await self._profile_handler.set_displayname(
user_id_obj, requester, attributes.display_name, True
)
+ if attributes.picture:
+ await self.set_avatar(user_id, attributes.picture)
await self._auth_handler.complete_sso_login(
user_id,
@@ -701,8 +712,110 @@ class SsoHandler:
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
+
+ # Set avatar, if available
+ if attributes.picture:
+ await self.set_avatar(registered_user_id, attributes.picture)
+
return registered_user_id
+ async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
+ """Set avatar of the user.
+
+ This downloads the image file from the URL provided, stores that in
+ the media repository and then sets the avatar on the user's profile.
+
+        It detects when the same image is being saved again, and bails early, by
+        storing the hash of the file in the `upload_name` of the avatar image.
+
+ Currently, it only supports server configurations which run the media repository
+ within the same process.
+
+        It fails silently (logging a warning rather than raising) if:
+        * it is unable to fetch the image (non-200 status code),
+        * the image is bigger than the maximum allowed size, or
+        * the image type is not one of the allowed image types.
+
+ Args:
+            user_id: Matrix user ID in the form @localpart:domain, as a string.
+
+            picture_https_url: HTTPS URL of the picture image file.
+
+ Returns: `True` if the user's avatar has been successfully set to the image at
+ `picture_https_url`.
+ """
+ if self._media_repo is None:
+ logger.info(
+ "failed to set user avatar because out-of-process media repositories "
+ "are not supported yet "
+ )
+ return False
+
+ try:
+ uid = UserID.from_string(user_id)
+
+ def is_allowed_mime_type(content_type: str) -> bool:
+ if (
+ self._profile_handler.allowed_avatar_mimetypes
+ and content_type
+ not in self._profile_handler.allowed_avatar_mimetypes
+ ):
+ return False
+ return True
+
+ # download picture, enforcing size limit & mime type check
+ picture = io.BytesIO()
+
+ content_length, headers, uri, code = await self._http_client.get_file(
+ url=picture_https_url,
+ output_stream=picture,
+ max_size=self._profile_handler.max_avatar_size,
+ is_allowed_content_type=is_allowed_mime_type,
+ )
+
+ if code != 200:
+ raise Exception(
+ "GET request to download sso avatar image returned {}".format(code)
+ )
+
+            # upload name includes hash of the image file's content so that we can
+            # easily check if it requires an update or not, the next time user logs in
+            picture.seek(0)  # get_file left the stream positioned at the end
+            upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
+
+ # bail if user already has the same avatar
+ profile = await self._profile_handler.get_profile(user_id)
+ if profile["avatar_url"] is not None:
+ server_name = profile["avatar_url"].split("/")[-2]
+ media_id = profile["avatar_url"].split("/")[-1]
+ if server_name == self._server_name:
+ media = await self._media_repo.store.get_local_media(media_id)
+ if media is not None and upload_name == media["upload_name"]:
+ logger.info("skipping saving the user avatar")
+ return True
+
+ # store it in media repository
+ avatar_mxc_url = await self._media_repo.create_content(
+ media_type=headers[b"Content-Type"][0].decode("utf-8"),
+ upload_name=upload_name,
+ content=picture,
+ content_length=content_length,
+ auth_user=uid,
+ )
+
+ # save it as user avatar
+ await self._profile_handler.set_avatar_url(
+ uid,
+ create_requester(uid),
+ str(avatar_mxc_url),
+ )
+
+ logger.info("successfully saved the user avatar")
+ return True
+ except Exception:
+ logger.warning("failed to save the user avatar")
+ return False
+
async def complete_sso_ui_auth_request(
self,
auth_provider_id: str,
@@ -874,7 +987,7 @@ class SsoHandler:
)
async def handle_terms_accepted(
- self, request: Request, session_id: str, terms_version: str
+ self, request: SynapseRequest, session_id: str, terms_version: str
) -> None:
"""Handle a request to the new-user 'consent' endpoint
@@ -1026,6 +1139,84 @@ class SsoHandler:
return True
+ async def revoke_sessions_for_provider_session_id(
+ self,
+ auth_provider_id: str,
+ auth_provider_session_id: str,
+ expected_user_id: Optional[str] = None,
+ ) -> None:
+ """Revoke any devices and in-flight logins tied to a provider session.
+
+ Can only be called from the main process.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+            auth_provider_session_id: The session ID from the provider to log out.
+            expected_user_id: The user we expect to log out. If set, sessions
+                belonging to other users are skipped and an error is logged.
+ """
+
+ # It is expected that this is the main process.
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "revoking SSO sessions can only be called on the main process"
+
+ # Invalidate any running user-mapping sessions
+ to_delete = []
+ for session_id, session in self._username_mapping_sessions.items():
+ if (
+ session.auth_provider_id == auth_provider_id
+ and session.auth_provider_session_id == auth_provider_session_id
+ ):
+ to_delete.append(session_id)
+
+ for session_id in to_delete:
+ logger.info("Revoking mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
+
+ # Invalidate any in-flight login tokens
+ await self._store.invalidate_login_tokens_by_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # Fetch any device(s) in the store associated with the session ID.
+ devices = await self._store.get_devices_by_auth_provider_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # We have no guarantee that all the devices of that session are for the same
+ # `user_id`. Hence, we have to iterate over the list of devices and log them out
+ # one by one.
+ for device in devices:
+ user_id = device["user_id"]
+ device_id = device["device_id"]
+
+ # If the user_id associated with that device/session is not the one we got
+            # out of the `sub` claim, skip that device and log an error.
+ if expected_user_id is not None and user_id != expected_user_id:
+ logger.error(
+ "Received a logout notification from SSO provider "
+ f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+ f"a session ID ({auth_provider_session_id!r}) which belongs to "
+ f"{user_id!r}. This may happen when the SSO provider user mapper "
+ "uses something else than the standard attribute as mapping ID. "
+ "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+ "in the provider config if that is the case."
+ )
+ continue
+
+ logger.info(
+ "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+ user_id,
+ device_id,
+ auth_provider_id,
+ auth_provider_session_id,
+ )
+ await self._device_handler.delete_devices(user_id, [device_id])
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1db5d68021..7d6a653747 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,18 +31,24 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
-from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
+from synapse.logging.opentracing import (
+ SynapseTags,
+ log_kv,
+ set_tag,
+ start_active_span,
+ trace,
+)
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
-from synapse.storage.state import StateFilter
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -54,6 +60,7 @@ from synapse.types import (
StreamToken,
UserID,
)
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
@@ -805,18 +812,6 @@ class SyncHandler:
if canonical_alias and canonical_alias.content.get("alias"):
return summary
- me = sync_config.user.to_string()
-
- joined_user_ids = [
- r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
- ]
- invited_user_ids = [
- r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
- ]
- gone_user_ids = [
- r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
- ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
-
# FIXME: only build up a member_ids list for our heroes
member_ids = {}
for membership in (
@@ -828,11 +823,8 @@ class SyncHandler:
for user_id, event_id in details.get(membership, empty_ms).members:
member_ids[user_id] = event_id
- # FIXME: order by stream ordering rather than as returned by SQL
- if joined_user_ids or invited_user_ids:
- summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
- else:
- summary["m.heroes"] = sorted(gone_user_ids)[0:5]
+ me = sync_config.user.to_string()
+ summary["m.heroes"] = extract_heroes_from_room_summary(details, me)
if not sync_config.filter_collection.lazy_load_members():
return summary
@@ -1440,14 +1432,14 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
- one_time_key_counts: JsonDict = {}
+ one_time_keys_count: JsonDict = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
- one_time_key_counts = await self.store.count_e2e_one_time_keys(
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
@@ -1477,7 +1469,7 @@ class SyncHandler:
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
- device_one_time_keys_count=one_time_key_counts,
+ device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
@@ -1542,10 +1534,12 @@ class SyncHandler:
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
- changed_users = self.store.get_cached_device_list_changes(
+ cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
- if changed_users is not None:
+ if cache_result.hit:
+ changed_users = cache_result.entities
+
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
@@ -1598,6 +1592,7 @@ class SyncHandler:
else:
return DeviceListUpdates()
+ @trace
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
) -> None:
@@ -1617,11 +1612,16 @@ class SyncHandler:
)
for message in messages:
- # We pop here as we shouldn't be sending the message ID down
- # `/sync`
- message_id = message.pop("message_id", None)
- if message_id:
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
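+            # Trace each to-device message; the client-supplied message ID, if
+            # any, is carried in the message content under
+            # `EventContentFields.TO_DEVICE_MSGID`.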
+ log_kv(
+ {
+ "event": "to_device_message",
+ "sender": message["sender"],
+ "type": message["type"],
+ EventContentFields.TO_DEVICE_MSGID: message["content"].get(
+ EventContentFields.TO_DEVICE_MSGID
+ ),
+ }
+ )
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..3f656ea4f5 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
- changed_rooms: Optional[
- Iterable[str]
- ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+ result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
- if changed_rooms is None:
+ if result.hit:
+ changed_rooms: Iterable[str] = result.entities
+ else:
changed_rooms = self._room_serials
rows = []
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
|