diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index 89e944bc17..37cc3d3ff5 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -118,10 +118,10 @@ class AccountHandler:
}
if self._use_account_validity_in_account_status:
- status["org.matrix.expired"] = (
- await self._account_validity_handler.is_user_expired(
- user_id.to_string()
- )
+ status[
+ "org.matrix.expired"
+ ] = await self._account_validity_handler.is_user_expired(
+ user_id.to_string()
)
return status
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 97a463d8d0..228132db48 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -33,7 +33,7 @@ from synapse.replication.http.account_data import (
ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
-from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -253,7 +253,7 @@ class AccountDataHandler:
return response["max_stream_id"]
async def add_tag_to_room(
- self, user_id: str, room_id: str, tag: str, content: JsonDict
+ self, user_id: str, room_id: str, tag: str, content: JsonMapping
) -> int:
"""Add a tag to a room for a user.
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 7004d95a0f..e40ca3e73f 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,6 @@
#
#
-import email.mime.multipart
-import email.utils
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -40,18 +38,13 @@ class AccountValidityHandler:
self.hs = hs
self.config = hs.config
self.store = hs.get_datastores().main
- self.send_email_handler = hs.get_send_email_handler()
self.clock = hs.get_clock()
- self._app_name = hs.config.email.email_app_name
self._module_api_callbacks = hs.get_module_api_callbacks().account_validity
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
)
- self._account_validity_renew_by_email_enabled = (
- hs.config.account_validity.account_validity_renew_by_email_enabled
- )
self._account_validity_period = None
if self._account_validity_enabled:
@@ -59,21 +52,6 @@ class AccountValidityHandler:
hs.config.account_validity.account_validity_period
)
- if (
- self._account_validity_enabled
- and self._account_validity_renew_by_email_enabled
- ):
- # Don't do email-specific configuration if renewal by email is disabled.
- self._template_html = hs.config.email.account_validity_template_html
- self._template_text = hs.config.email.account_validity_template_text
- self._renew_email_subject = (
- hs.config.account_validity.account_validity_renew_email_subject
- )
-
- # Check the renewal emails to send and send them every 30min.
- if hs.config.worker.run_background_tasks:
- self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
-
async def is_user_expired(self, user_id: str) -> bool:
"""Checks if a user has expired against third-party modules.
@@ -120,125 +98,6 @@ class AccountValidityHandler:
for callback in self._module_api_callbacks.on_user_login_callbacks:
await callback(user_id, auth_provider_type, auth_provider_id)
- @wrap_as_background_process("send_renewals")
- async def _send_renewal_emails(self) -> None:
- """Gets the list of users whose account is expiring in the amount of time
- configured in the ``renew_at`` parameter from the ``account_validity``
- configuration, and sends renewal emails to all of these users as long as they
- have an email 3PID attached to their account.
- """
- expiring_users = await self.store.get_users_expiring_soon()
-
- if expiring_users:
- for user_id, expiration_ts_ms in expiring_users:
- await self._send_renewal_email(
- user_id=user_id, expiration_ts=expiration_ts_ms
- )
-
- async def send_renewal_email_to_user(self, user_id: str) -> None:
- """
- Send a renewal email for a specific user.
-
- Args:
- user_id: The user ID to send a renewal email for.
-
- Raises:
- SynapseError if the user is not set to renew.
- """
- # If a module supports sending a renewal email from here, do that, otherwise do
- # the legacy dance.
- if self._module_api_callbacks.on_legacy_send_mail_callback is not None:
- await self._module_api_callbacks.on_legacy_send_mail_callback(user_id)
- return
-
- if not self._account_validity_renew_by_email_enabled:
- raise AuthError(
- 403, "Account renewal via email is disabled on this server."
- )
-
- expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
-
- # If this user isn't set to be expired, raise an error.
- if expiration_ts is None:
- raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
-
- await self._send_renewal_email(user_id, expiration_ts)
-
- async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
- """Sends out a renewal email to every email address attached to the given user
- with a unique link allowing them to renew their account.
-
- Args:
- user_id: ID of the user to send email(s) to.
- expiration_ts: Timestamp in milliseconds for the expiration date of
- this user's account (used in the email templates).
- """
- addresses = await self._get_email_addresses_for_user(user_id)
-
- # Stop right here if the user doesn't have at least one email address.
- # In this case, they will have to ask their server admin to renew their
- # account manually.
- # We don't need to do a specific check to make sure the account isn't
- # deactivated, as a deactivated account isn't supposed to have any
- # email address attached to it.
- if not addresses:
- return
-
- try:
- user_display_name = await self.store.get_profile_displayname(
- UserID.from_string(user_id)
- )
- if user_display_name is None:
- user_display_name = user_id
- except StoreError:
- user_display_name = user_id
-
- renewal_token = await self._get_renewal_token(user_id)
- url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
- self.hs.config.server.public_baseurl,
- renewal_token,
- )
-
- template_vars = {
- "display_name": user_display_name,
- "expiration_ts": expiration_ts,
- "url": url,
- }
-
- html_text = self._template_html.render(**template_vars)
- plain_text = self._template_text.render(**template_vars)
-
- for address in addresses:
- raw_to = email.utils.parseaddr(address)[1]
-
- await self.send_email_handler.send_email(
- email_address=raw_to,
- subject=self._renew_email_subject,
- app_name=self._app_name,
- html=html_text,
- text=plain_text,
- )
-
- await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
-
- async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
- """Retrieve the list of email addresses attached to a user's account.
-
- Args:
- user_id: ID of the user to lookup email addresses for.
-
- Returns:
- Email addresses for this account.
- """
- threepids = await self.store.user_get_threepids(user_id)
-
- addresses = []
- for threepid in threepids:
- if threepid.medium == "email":
- addresses.append(threepid.address)
-
- return addresses
-
async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b44e862493..5467d129bd 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -21,13 +21,34 @@
import abc
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
import attr
-from synapse.api.constants import Direction, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
-from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
+from synapse.types import (
+ JsonMapping,
+ Requester,
+ RoomStreamToken,
+ ScheduledTask,
+ StateMap,
+ TaskStatus,
+ UserID,
+ UserInfo,
+ create_requester,
+)
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -35,6 +56,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+REDACT_ALL_EVENTS_ACTION_NAME = "redact_all_events"
+
class AdminHandler:
def __init__(self, hs: "HomeServer"):
@@ -43,6 +66,22 @@ class AdminHandler:
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._task_scheduler = hs.get_task_scheduler()
+
+ self._task_scheduler.register_action(
+ self._redact_all_events, REDACT_ALL_EVENTS_ACTION_NAME
+ )
+
+ self.hs = hs
+
+ async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]:
+ """Get the current status of an active redaction process
+
+ Args:
+ redact_id: redact_id returned by start_redact_events.
+ """
+ return await self._task_scheduler.get_task(redact_id)
async def get_whois(self, user: UserID) -> JsonMapping:
connections = []
@@ -85,6 +124,7 @@ class AdminHandler:
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
+ "suspended": user_info.suspended,
}
if self._msc3866_enabled:
@@ -93,7 +133,6 @@ class AdminHandler:
# Add additional user metadata
profile = await self._store.get_profileinfo(user)
- threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
for auth_provider, external_id in await self._store.get_external_ids_by_user(
@@ -102,7 +141,6 @@ class AdminHandler:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
- user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
@@ -197,14 +235,16 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
- events, _ = (
- await self._store.paginate_room_events_by_topological_ordering(
- room_id=room_id,
- from_key=from_key,
- to_key=to_key,
- limit=100,
- direction=Direction.FORWARDS,
- )
+ (
+ events,
+ _,
+ _,
+ ) = await self._store.paginate_room_events_by_topological_ordering(
+ room_id=room_id,
+ from_key=from_key,
+ to_key=to_key,
+ limit=100,
+ direction=Direction.FORWARDS,
)
if not events:
break
@@ -311,6 +351,155 @@ class AdminHandler:
return writer.finished()
+ async def start_redact_events(
+ self,
+ user_id: str,
+        rooms: List[str],
+ requester: JsonMapping,
+ reason: Optional[str],
+ limit: Optional[int],
+ ) -> str:
+ """
+ Start a task redacting the events of the given user in the given rooms
+
+ Args:
+ user_id: the user ID of the user whose events should be redacted
+ rooms: the rooms in which to redact the user's events
+            requester: the user requesting the redaction
+            reason: the reason for requesting the redaction, e.g. spam
+ limit: limit on the number of events in each room to redact
+
+ Returns:
+ a unique ID which can be used to query the status of the task
+ """
+ active_tasks = await self._task_scheduler.get_tasks(
+ actions=[REDACT_ALL_EVENTS_ACTION_NAME],
+ resource_id=user_id,
+ statuses=[TaskStatus.ACTIVE],
+ )
+
+ if len(active_tasks) > 0:
+ raise SynapseError(
+ 400, "Redact already in progress for user %s" % (user_id,)
+ )
+
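+        # If the caller didn't specify a limit, cap the number of events
+        # redacted per room at 1000.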
+ if not limit:
+ limit = 1000
+
+ redact_id = await self._task_scheduler.schedule_task(
+ REDACT_ALL_EVENTS_ACTION_NAME,
+ resource_id=user_id,
+ params={
+ "rooms": rooms,
+ "requester": requester,
+ "user_id": user_id,
+ "reason": reason,
+ "limit": limit,
+ },
+ )
+
+ logger.info(
+ "starting redact events with redact_id %s",
+ redact_id,
+ )
+
+ return redact_id
+
+ async def _redact_all_events(
+ self, task: ScheduledTask
+ ) -> Tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
+ """
+        Task to redact all of a user's events in the given rooms, tracking
+        which events, if any, failed to be redacted.
+ """
+
+ assert task.params is not None
+ rooms = task.params.get("rooms")
+ assert rooms is not None
+
+ r = task.params.get("requester")
+ assert r is not None
+ admin = Requester.deserialize(self._store, r)
+
+ user_id = task.params.get("user_id")
+ assert user_id is not None
+
+ # puppet the user if they're ours, otherwise use admin to redact
+ requester = create_requester(
+ user_id if self.hs.is_mine_id(user_id) else admin.user.to_string(),
+ authenticated_entity=admin.user.to_string(),
+ )
+
+ reason = task.params.get("reason")
+ limit = task.params.get("limit")
+ assert limit is not None
+
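+        # Resume from any partial result persisted by a previous run of this
+        # task, so failures recorded before a restart aren't lost.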
+ result: Mapping[str, Any] = (
+ task.result if task.result else {"failed_redactions": {}}
+ )
+ for room in rooms:
+ room_version = await self._store.get_room_version(room)
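+            # Fetch the member, message, and encrypted events sent by the user;
+            # non-join membership events are filtered out below.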
+ event_ids = await self._store.get_events_sent_by_user_in_room(
+ user_id,
+ room,
+ limit,
+ ["m.room.member", "m.room.message", "m.room.encrypted"],
+ )
+ if not event_ids:
+ # nothing to redact in this room
+ continue
+
+ events = await self._store.get_events_as_list(event_ids)
+ for event in events:
+ # we care about join events but not other membership events
+                if event.type == "m.room.member":
+                    content = event.content
+                    if content and content.get("membership") != Membership.JOIN:
+                        continue
+ relations = await self._store.get_relations_for_event(
+ room, event.event_id, event, event_type=EventTypes.Redaction
+ )
+
+                # get_relations_for_event returns (events, next batch token);
+                # any existing redaction relation means we've already
+                # successfully redacted this event, so skip it
+ if relations[0]:
+ continue
+
+ event_dict = {
+ "type": EventTypes.Redaction,
+ "content": {"reason": reason} if reason else {},
+ "room_id": room,
+ "sender": requester.user.to_string(),
+ }
+ if room_version.updated_redaction_rules:
+ event_dict["content"]["redacts"] = event.event_id
+ else:
+ event_dict["redacts"] = event.event_id
+
+ try:
+ # set the prev event to the offending message to allow for redactions
+ # to be processed in the case where the user has been kicked/banned before
+ # redactions are requested
+ (
+ redaction,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ event_dict,
+ prev_event_ids=[event.event_id],
+ ratelimit=False,
+ )
+ except Exception as ex:
+                    logger.info(
+                        "Redaction of event %s failed due to: %s",
+                        event.event_id,
+                        ex,
+                    )
+ result["failed_redactions"][event.event_id] = str(ex)
+ await self._task_scheduler.update_task(task.id, result=result)
+
+ return TaskStatus.COMPLETE, result, None
+
class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data."""
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4b33e1330d..b7d1033351 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -896,10 +896,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
- run_in_background(
+ run_in_background( # type: ignore[call-overload]
self.appservice_api.claim_client_keys,
# We know this must be an app service.
- self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
+ self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
@@ -952,10 +952,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
- run_in_background(
+ run_in_background( # type: ignore[call-overload]
self.appservice_api.query_keys,
# We know this must be an app service.
- self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
+ self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a1fab99f6b..d37324cc46 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -79,9 +79,7 @@ from synapse.storage.databases.main.registration import (
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode
-from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
@@ -153,42 +151,9 @@ def convert_client_dict_legacy_fields_to_identifier(
return identifier
-def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
- """
- Convert a phone login identifier type to a generic threepid identifier.
-
- Args:
- identifier: Login identifier dict of type 'm.id.phone'
-
- Returns:
- An equivalent m.id.thirdparty identifier dict
- """
- if "country" not in identifier or (
- # The specification requires a "phone" field, while Synapse used to require a "number"
- # field. Accept both for backwards compatibility.
- "phone" not in identifier
- and "number" not in identifier
- ):
- raise SynapseError(
- 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
- )
-
- # Accept both "phone" and "number" as valid keys in m.id.phone
- phone_number = identifier.get("phone", identifier["number"])
-
- # Convert user-provided phone number to a consistent representation
- msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
-
- return {
- "type": "m.id.thirdparty",
- "medium": "msisdn",
- "address": msisdn,
- }
-
-
@attr.s(slots=True, auto_attribs=True)
class SsoLoginExtraAttributes:
- """Data we track about SAML2 sessions"""
+ """Data we track about SAML2 sessions""" # Not other SSO types...?
# time the session was created, in milliseconds
creation_time: int
@@ -1195,70 +1160,11 @@ class AuthHandler:
-        # convert phone type identifiers to generic threepids
        if identifier_dict["type"] == "m.id.phone":
-            identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+            # Phone (MSISDN) login identifiers are no longer supported.
+            raise SynapseError(
+                400, "Third party identifiers are not supported on this server."
+            )
-        # convert threepid identifiers to user IDs
+        # Third-party (3PID) identifiers are likewise rejected.
if identifier_dict["type"] == "m.id.thirdparty":
- address = identifier_dict.get("address")
- medium = identifier_dict.get("medium")
-
- if medium is None or address is None:
- raise SynapseError(400, "Invalid thirdparty identifier")
-
- # For emails, canonicalise the address.
- # We store all email addresses canonicalised in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- if medium == "email":
- try:
- address = canonicalise_email(address)
- except ValueError as e:
- raise SynapseError(400, str(e))
-
- # We also apply account rate limiting using the 3PID as a key, as
- # otherwise using 3PID bypasses the ratelimiting based on user ID.
- if ratelimit:
- await self._failed_login_attempts_ratelimiter.ratelimit(
- None, (medium, address), update=False
- )
-
- # Check for login providers that support 3pid login types
- if login_type == LoginType.PASSWORD:
- # we've already checked that there is a (valid) password field
- assert isinstance(password, str)
- (
- canonical_user_id,
- callback_3pid,
- ) = await self.check_password_provider_3pid(medium, address, password)
- if canonical_user_id:
- # Authentication through password provider and 3pid succeeded
- return canonical_user_id, callback_3pid
-
- # No password providers were able to handle this 3pid
- # Check local store
- user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
- medium, address
- )
- if not user_id:
- logger.warning(
- "unknown 3pid identifier medium %s, address %r", medium, address
- )
- # We mark that we've failed to log in here, as
- # `check_password_provider_3pid` might have returned `None` due
- # to an incorrect password, rather than the account not
- # existing.
- #
- # If it returned None but the 3PID was bound then we won't hit
- # this code path, which is fine as then the per-user ratelimit
- # will kick in below.
- if ratelimit:
- await self._failed_login_attempts_ratelimiter.can_do_action(
- None, (medium, address)
- )
- raise LoginError(
- 403, msg=INVALID_USERNAME_OR_PASSWORD, errcode=Codes.FORBIDDEN
- )
-
- identifier_dict = {"type": "m.id.user", "user": user_id}
+            raise SynapseError(
+                400, "Third party identifiers are not supported on this server."
+            )
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@@ -1548,83 +1454,6 @@ class AuthHandler:
user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
- async def add_threepid(
- self, user_id: str, medium: str, address: str, validated_at: int
- ) -> None:
- """
- Adds an association between a user's Matrix ID and a third-party ID (email,
- phone number).
-
- Args:
- user_id: The ID of the user to associate.
- medium: The medium of the third-party ID (email, msisdn).
- address: The address of the third-party ID (i.e. an email address).
- validated_at: The timestamp in ms of when the validation that the user owns
- this third-party ID occurred.
- """
- # check if medium has a valid value
- if medium not in ["email", "msisdn"]:
- raise SynapseError(
- code=400,
- msg=("'%s' is not a valid value for 'medium'" % (medium,)),
- errcode=Codes.INVALID_PARAM,
- )
-
- # 'Canonicalise' email addresses down to lower case.
- # We've now moving towards the homeserver being the entity that
- # is responsible for validating threepids used for resetting passwords
- # on accounts, so in future Synapse will gain knowledge of specific
- # types (mediums) of threepid. For now, we still use the existing
- # infrastructure, but this is the start of synapse gaining knowledge
- # of specific types of threepid (and fixes the fact that checking
- # for the presence of an email address during password reset was
- # case sensitive).
- if medium == "email":
- address = canonicalise_email(address)
-
- await self.store.user_add_threepid(
- user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
- )
-
- # Inform Synapse modules that a 3PID association has been created.
- await self._third_party_rules.on_add_user_third_party_identifier(
- user_id, medium, address
- )
-
- # Deprecated method for informing Synapse modules that a 3PID association
- # has successfully been created.
- await self._third_party_rules.on_threepid_bind(user_id, medium, address)
-
- async def delete_local_threepid(
- self, user_id: str, medium: str, address: str
- ) -> None:
- """Deletes an association between a third-party ID and a user ID from the local
- database. This method does not unbind the association from any identity servers.
-
- If `medium` is 'email' and a pusher is associated with this third-party ID, the
- pusher will also be deleted.
-
- Args:
- user_id: ID of user to remove the 3pid from.
- medium: The medium of the 3pid being removed: "email" or "msisdn".
- address: The 3pid address to remove.
- """
- # 'Canonicalise' email addresses as per above
- if medium == "email":
- address = canonicalise_email(address)
-
- await self.store.user_delete_threepid(user_id, medium, address)
-
- # Inform Synapse modules that a 3PID association has been deleted.
- await self._third_party_rules.on_remove_user_third_party_identifier(
- user_id, medium, address
- )
-
- if medium == "email":
- await self.store.delete_pusher_by_app_id_pushkey_user_id(
- app_id="m.email", pushkey=address, user_id=user_id
- )
-
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
deleted file mode 100644
index cc3d641b7d..0000000000
--- a/synapse/handlers/cas.py
+++ /dev/null
@@ -1,412 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import logging
-import urllib.parse
-from typing import TYPE_CHECKING, Dict, List, Optional
-from xml.etree import ElementTree as ET
-
-import attr
-
-from twisted.web.client import PartialDownloadError
-
-from synapse.api.errors import HttpResponseException
-from synapse.handlers.sso import MappingException, UserAttributes
-from synapse.http.site import SynapseRequest
-from synapse.types import UserID, map_username_to_mxid_localpart
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class CasError(Exception):
- """Used to catch errors when validating the CAS ticket."""
-
- def __init__(self, error: str, error_description: Optional[str] = None):
- self.error = error
- self.error_description = error_description
-
- def __str__(self) -> str:
- if self.error_description:
- return f"{self.error}: {self.error_description}"
- return self.error
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class CasResponse:
- username: str
- attributes: Dict[str, List[Optional[str]]]
-
-
-class CasHandler:
- """
- Utility class for to handle the response from a CAS SSO service.
-
- Args:
- hs
- """
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self._hostname = hs.hostname
- self._store = hs.get_datastores().main
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
-
- self._cas_server_url = hs.config.cas.cas_server_url
- self._cas_service_url = hs.config.cas.cas_service_url
- self._cas_protocol_version = hs.config.cas.cas_protocol_version
- self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
- self._cas_required_attributes = hs.config.cas.cas_required_attributes
- self._cas_enable_registration = hs.config.cas.cas_enable_registration
- self._cas_allow_numeric_ids = hs.config.cas.cas_allow_numeric_ids
- self._cas_numeric_ids_prefix = hs.config.cas.cas_numeric_ids_prefix
-
- self._http_client = hs.get_proxied_http_client()
-
- # identifier for the external_ids table
- self.idp_id = "cas"
-
- # user-facing name of this auth provider
- self.idp_name = hs.config.cas.idp_name
-
- # MXC URI for icon for this auth provider
- self.idp_icon = hs.config.cas.idp_icon
-
- # optional brand identifier for this auth provider
- self.idp_brand = hs.config.cas.idp_brand
-
- self._sso_handler = hs.get_sso_handler()
-
- self._sso_handler.register_identity_provider(self)
-
- def _build_service_param(self, args: Dict[str, str]) -> str:
- """
- Generates a value to use as the "service" parameter when redirecting or
- querying the CAS service.
-
- Args:
- args: Additional arguments to include in the final redirect URL.
-
- Returns:
- The URL to use as a "service" parameter.
- """
- return "%s?%s" % (
- self._cas_service_url,
- urllib.parse.urlencode(args),
- )
-
- async def _validate_ticket(
- self, ticket: str, service_args: Dict[str, str]
- ) -> CasResponse:
- """
- Validate a CAS ticket with the server, and return the parsed the response.
-
- Args:
- ticket: The CAS ticket from the client.
- service_args: Additional arguments to include in the service URL.
- Should be the same as those passed to `handle_redirect_request`.
-
- Raises:
- CasError: If there's an error parsing the CAS response.
-
- Returns:
- The parsed CAS response.
- """
- if self._cas_protocol_version == 3:
- uri = self._cas_server_url + "/p3/proxyValidate"
- else:
- uri = self._cas_server_url + "/proxyValidate"
- args = {
- "ticket": ticket,
- "service": self._build_service_param(service_args),
- }
- try:
- body = await self._http_client.get_raw(uri, args)
- except PartialDownloadError as pde:
- # Twisted raises this error if the connection is closed,
- # even if that's being used old-http style to signal end-of-data
- # Assertion is for mypy's benefit. Error.response is Optional[bytes],
- # but a PartialDownloadError should always have a non-None response.
- assert pde.response is not None
- body = pde.response
- except HttpResponseException as e:
- description = (
- 'Authorization server responded with a "{status}" error '
- "while exchanging the authorization code."
- ).format(status=e.code)
- raise CasError("server_error", description) from e
-
- return self._parse_cas_response(body)
-
- def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
- """
- Retrieve the user and other parameters from the CAS response.
-
- Args:
- cas_response_body: The response from the CAS query.
-
- Raises:
- CasError: If there's an error parsing the CAS response.
-
- Returns:
- The parsed CAS response.
- """
-
- # Ensure the response is valid.
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise CasError(
- "missing_service_response",
- "root of CAS response is not serviceResponse",
- )
-
- success = root[0].tag.endswith("authenticationSuccess")
- if not success:
- raise CasError("unsucessful_response", "Unsuccessful CAS response")
-
- # Iterate through the nodes and pull out the user and any extra attributes.
- user = None
- attributes: Dict[str, List[Optional[str]]] = {}
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- # if numeric user IDs are allowed and username is numeric then we add the prefix so Synapse can handle it
- if self._cas_allow_numeric_ids and user is not None and user.isdigit():
- user = f"{self._cas_numeric_ids_prefix}{user}"
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes.setdefault(tag, []).append(attribute.text)
-
- # Ensure a user was found.
- if user is None:
- raise CasError("no_user", "CAS response does not contain user")
-
- return CasResponse(user, attributes)
-
- async def handle_redirect_request(
- self,
- request: SynapseRequest,
- client_redirect_url: Optional[bytes],
- ui_auth_session_id: Optional[str] = None,
- ) -> str:
- """Generates a URL for the CAS server where the client should be redirected.
-
- Args:
- request: the incoming HTTP request
- client_redirect_url: the URL that we should redirect the
- client to after login (or None for UI Auth).
- ui_auth_session_id: The session ID of the ongoing UI Auth (or
- None if this is a login).
-
- Returns:
- URL to redirect to
- """
-
- if ui_auth_session_id:
- service_args = {"session": ui_auth_session_id}
- else:
- assert client_redirect_url
- service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
-
- args = urllib.parse.urlencode(
- {"service": self._build_service_param(service_args)}
- )
-
- return "%s/login?%s" % (self._cas_server_url, args)
-
- async def handle_ticket(
- self,
- request: SynapseRequest,
- ticket: str,
- client_redirect_url: Optional[str],
- session: Optional[str],
- ) -> None:
- """
- Called once the user has successfully authenticated with the SSO.
- Validates a CAS ticket sent by the client and completes the auth process.
-
- If the user interactive authentication session is provided, marks the
- UI Auth session as complete, then returns an HTML page notifying the
- user they are done.
-
- Otherwise, this registers the user if necessary, and then returns a
- redirect (with a login token) to the client.
-
- Args:
- request: the incoming request from the browser. We'll
- respond to it with a redirect or an HTML page.
-
- ticket: The CAS ticket provided by the client.
-
- client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
- This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
-
- session: The session parameter from the `/cas/ticket` HTTP request, if given.
- This should be the UI Auth session id.
- """
- args = {}
- if client_redirect_url:
- args["redirectUrl"] = client_redirect_url
- if session:
- args["session"] = session
-
- try:
- cas_response = await self._validate_ticket(ticket, args)
- except CasError as e:
- logger.exception("Could not validate ticket")
- self._sso_handler.render_error(request, e.error, e.error_description, 401)
- return
-
- await self._handle_cas_response(
- request, cas_response, client_redirect_url, session
- )
-
- async def _handle_cas_response(
- self,
- request: SynapseRequest,
- cas_response: CasResponse,
- client_redirect_url: Optional[str],
- session: Optional[str],
- ) -> None:
- """Handle a CAS response to a ticket request.
-
- Assumes that the response has been validated. Maps the user onto an MXID,
- registering them if necessary, and returns a response to the browser.
-
- Args:
- request: the incoming request from the browser. We'll respond to it with an
- HTML page or a redirect
-
- cas_response: The parsed CAS response.
-
- client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
- This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
-
- session: The session parameter from the `/cas/ticket` HTTP request, if given.
- This should be the UI Auth session id.
- """
-
- # first check if we're doing a UIA
- if session:
- return await self._sso_handler.complete_sso_ui_auth_request(
- self.idp_id,
- cas_response.username,
- session,
- request,
- )
-
- # otherwise, we're handling a login request.
-
- # Ensure that the attributes of the logged in user meet the required
- # attributes.
- if not self._sso_handler.check_required_attributes(
- request, cas_response.attributes, self._cas_required_attributes
- ):
- return
-
- # Call the mapper to register/login the user
-
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url is not None
-
- try:
- await self._complete_cas_login(cas_response, request, client_redirect_url)
- except MappingException as e:
- logger.exception("Could not map user")
- self._sso_handler.render_error(request, "mapping_error", str(e))
-
- async def _complete_cas_login(
- self,
- cas_response: CasResponse,
- request: SynapseRequest,
- client_redirect_url: str,
- ) -> None:
- """
- Given a CAS response, complete the login flow
-
- Retrieves the remote user ID, registers the user if necessary, and serves
- a redirect back to the client with a login-token.
-
- Args:
- cas_response: The parsed CAS response.
- request: The request to respond to
- client_redirect_url: The redirect URL passed in by the client.
-
- Raises:
- MappingException if there was a problem mapping the response to a user.
- RedirectException: some mapping providers may raise this if they need
- to redirect to an interstitial page.
- """
- # Note that CAS does not support a mapping provider, so the logic is hard-coded.
- localpart = map_username_to_mxid_localpart(cas_response.username)
-
- async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
- """
- Map from CAS attributes to user attributes.
- """
- # Due to the grandfathering logic matching any previously registered
- # mxids it isn't expected for there to be any failures.
- if failures:
- raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
-
- # Arbitrarily use the first attribute found.
- display_name = cas_response.attributes.get(
- self._cas_displayname_attribute, [None]
- )[0]
-
- return UserAttributes(localpart=localpart, display_name=display_name)
-
- async def grandfather_existing_users() -> Optional[str]:
- # Since CAS did not always use the user_external_ids table, always
- # to attempt to map to existing users.
- user_id = UserID(localpart, self._hostname).to_string()
-
- logger.debug(
- "Looking for existing account based on mapped %s",
- user_id,
- )
-
- users = await self._store.get_users_by_id_case_insensitive(user_id)
- if users:
- registered_user_id = list(users.keys())[0]
- logger.info("Grandfathering mapping to %s", registered_user_id)
- return registered_user_id
-
- return None
-
- await self._sso_handler.complete_sso_login_request(
- self.idp_id,
- cas_response.username,
- request,
- client_redirect_url,
- cas_response_to_user_attributes,
- grandfather_existing_users,
- registration_enabled=self._cas_enable_registration,
- )
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 12a7cace55..2c4991c6e5 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -43,7 +43,6 @@ class DeactivateAccountHandler:
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
- self._identity_handler = hs.get_identity_handler()
self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
@@ -82,7 +81,7 @@ class DeactivateAccountHandler:
by_admin: Whether this change was made by an administrator.
Returns:
- True if identity server supports removing threepids, otherwise False.
+            True, unconditionally: there is no longer an identity server unbind step that can fail.
"""
# This can only be called on the main process.
@@ -96,40 +95,6 @@ class DeactivateAccountHandler:
403, "Deactivation of this user is forbidden", Codes.FORBIDDEN
)
- # FIXME: Theoretically there is a race here wherein user resets
- # password using threepid.
-
- # delete threepids first. We remove these from the IS so if this fails,
- # leave the user still active so they can try again.
- # Ideally we would prevent password resets and then do this in the
- # background thread.
-
- # This will be set to false if the identity server doesn't support
- # unbinding
- identity_server_supports_unbinding = True
-
- # Attempt to unbind any known bound threepids to this account from identity
- # server(s).
- bound_threepids = await self.store.user_get_bound_threepids(user_id)
- for medium, address in bound_threepids:
- try:
- result = await self._identity_handler.try_unbind_threepid(
- user_id, medium, address, id_server
- )
- except Exception:
- # Do we want this to be a fatal error or should we carry on?
- logger.exception("Failed to remove threepid from ID server")
- raise SynapseError(400, "Failed to remove threepid from ID server")
-
- identity_server_supports_unbinding &= result
-
- # Remove any local threepid associations for this account.
- local_threepids = await self.store.user_get_threepids(user_id)
- for local_threepid in local_threepids:
- await self._auth_handler.delete_local_threepid(
- user_id, local_threepid.medium, local_threepid.address
- )
-
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
await self._device_handler.delete_all_devices_for_user(user_id)
@@ -194,7 +159,7 @@ class DeactivateAccountHandler:
by_admin,
)
- return identity_server_supports_unbinding
+ return True
async def _reject_pending_invites_and_knocks_for_user(self, user_id: str) -> None:
"""Reject pending invites and knocks addressed to a given user ID.
diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py
new file mode 100644
index 0000000000..cb2a34ff73
--- /dev/null
+++ b/synapse/handlers/delayed_events.py
@@ -0,0 +1,545 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple
+
+from twisted.internet.interfaces import IDelayedCall
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import ShadowBanError
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
+from synapse.logging.opentracing import set_tag
+from synapse.metrics import event_processing_positions
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.delayed_events import (
+ ReplicationAddedDelayedEventRestServlet,
+)
+from synapse.storage.databases.main.delayed_events import (
+ DelayedEventDetails,
+ DelayID,
+ EventType,
+ StateKey,
+ Timestamp,
+ UserLocalpart,
+)
+from synapse.storage.databases.main.state_deltas import StateDelta
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomID,
+ UserID,
+ create_requester,
+)
+from synapse.util.events import generate_fake_event_id
+from synapse.util.metrics import Measure
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class DelayedEventsHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
+ self._config = hs.config
+ self._clock = hs.get_clock()
+ self._event_creation_handler = hs.get_event_creation_handler()
+ self._room_member_handler = hs.get_room_member_handler()
+
+ self._request_ratelimiter = hs.get_request_ratelimiter()
+
+ # Ratelimiter for management of existing delayed events,
+ # keyed by the sending user ID & device ID.
+ self._delayed_event_mgmt_ratelimiter = Ratelimiter(
+ store=self._store,
+ clock=self._clock,
+ cfg=self._config.ratelimiting.rc_delayed_event_mgmt,
+ )
+
+ self._next_delayed_event_call: Optional[IDelayedCall] = None
+
+ # The current position in the current_state_delta stream
+ self._event_pos: Optional[int] = None
+
+ # Guard to ensure we only process event deltas one at a time
+ self._event_processing = False
+
+ if hs.config.worker.worker_app is None:
+ self._repl_client = None
+
+ async def _schedule_db_events() -> None:
+ # We kick this off to pick up outstanding work from before the last restart.
+ # Block until we're up to date.
+ await self._unsafe_process_new_event()
+ hs.get_notifier().add_replication_callback(self.notify_new_event)
+ # Kick off again (without blocking) to catch any missed notifications
+ # that may have fired before the callback was added.
+ self._clock.call_later(0, self.notify_new_event)
+
+ # Delayed events that are already marked as processed on startup might not have been
+ # sent properly on the last run of the server, so unmark them to send them again.
+ # Caveat: this will double-send delayed events that successfully persisted, but failed
+ # to be removed from the DB table of delayed events.
+ # TODO: To avoid double-sending, scan the timeline to find which of these events were
+ # already sent. To do so, must store delay_ids in sent events to retrieve them later.
+ await self._store.unprocess_delayed_events()
+
+ events, next_send_ts = await self._store.process_timeout_delayed_events(
+ self._get_current_ts()
+ )
+
+ if next_send_ts:
+ self._schedule_next_at(next_send_ts)
+
+            # The events can now be sent in the background, since they have
+            # already been marked as processed above.
+ run_as_background_process(
+ "_send_events",
+ self._send_events,
+ events,
+ )
+
+ self._initialized_from_db = run_as_background_process(
+ "_schedule_db_events", _schedule_db_events
+ )
+ else:
+ self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs)
+
+ @property
+ def _is_master(self) -> bool:
+ return self._repl_client is None
+
+ def notify_new_event(self) -> None:
+ """
+ Called when there may be more state event deltas to process,
+ which should cancel pending delayed events for the same state.
+ """
+ if self._event_processing:
+ return
+
+ self._event_processing = True
+
+ async def process() -> None:
+ try:
+ await self._unsafe_process_new_event()
+ finally:
+ self._event_processing = False
+
+ run_as_background_process("delayed_events.notify_new_event", process)
+
+ async def _unsafe_process_new_event(self) -> None:
+        # If self._event_pos is None, it means we haven't fetched it from the DB yet.
+ if self._event_pos is None:
+ self._event_pos = await self._store.get_delayed_events_stream_pos()
+ room_max_stream_ordering = self._store.get_room_max_stream_ordering()
+ if self._event_pos > room_max_stream_ordering:
+ # apparently, we've processed more events than exist in the database!
+ # this can happen if events are removed with history purge or similar.
+ logger.warning(
+ "Event stream ordering appears to have gone backwards (%i -> %i): "
+ "rewinding delayed events processor",
+ self._event_pos,
+ room_max_stream_ordering,
+ )
+ self._event_pos = room_max_stream_ordering
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self._clock, "delayed_events_delta"):
+ room_max_stream_ordering = self._store.get_room_max_stream_ordering()
+ if self._event_pos == room_max_stream_ordering:
+ return
+
+ logger.debug(
+ "Processing delayed events %s->%s",
+ self._event_pos,
+ room_max_stream_ordering,
+ )
+ (
+ max_pos,
+ deltas,
+ ) = await self._storage_controllers.state.get_current_state_deltas(
+ self._event_pos, room_max_stream_ordering
+ )
+
+ logger.debug(
+ "Handling %d state deltas for delayed events processing",
+ len(deltas),
+ )
+ await self._handle_state_deltas(deltas)
+
+ self._event_pos = max_pos
+
+ # Expose current event processing position to prometheus
+ event_processing_positions.labels("delayed_events").set(max_pos)
+
+ await self._store.update_delayed_events_stream_pos(max_pos)
+
+ async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
+ """
+ Process current state deltas to cancel other users' pending delayed events
+ that target the same state.
+ """
+ for delta in deltas:
+ if delta.event_id is None:
+ logger.debug(
+ "Not handling delta for deleted state: %r %r",
+ delta.event_type,
+ delta.state_key,
+ )
+ continue
+
+ logger.debug(
+ "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+ )
+
+            event = await self._store.get_event(
+                delta.event_id,
+                check_room_id=delta.room_id,
+                allow_rejected=True,
+                allow_none=True,
+            )
+
+ if event is None or event.rejected_reason is not None:
+                # The event is missing or was rejected, so it shouldn't cancel
+                # anyone's delayed events.
+ continue
+
+ sender = UserID.from_string(event.sender)
+
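+            # Cancel other users' delayed events that target this piece of
+            # state. If the new state event was sent by a local user, their own
+            # delayed events are left alone (hence not_from_localpart).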
+ next_send_ts = await self._store.cancel_delayed_state_events(
+ room_id=delta.room_id,
+ event_type=delta.event_type,
+ state_key=delta.state_key,
+ not_from_localpart=(
+ sender.localpart
+ if sender.domain == self._config.server.server_name
+ else ""
+ ),
+ )
+
+ if self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at_or_none(next_send_ts)
+
+ async def add(
+ self,
+ requester: Requester,
+ *,
+ room_id: str,
+ event_type: str,
+ state_key: Optional[str],
+ origin_server_ts: Optional[int],
+ content: JsonDict,
+ delay: int,
+ ) -> str:
+ """
+ Creates a new delayed event and schedules its delivery.
+
+ Args:
+ requester: The requester of the delayed event, who will be its owner.
+            room_id: The ID of the room to send the event to.
+ event_type: The type of event to be sent.
+ state_key: The state key of the event to be sent, or None if it is not a state event.
+ origin_server_ts: The custom timestamp to send the event with.
+ If None, the timestamp will be the actual time when the event is sent.
+ content: The content of the event to be sent.
+ delay: How long (in milliseconds) to wait before automatically sending the event.
+
+ Returns: The ID of the added delayed event.
+
+ Raises:
+ SynapseError: if the delayed event fails validation checks.
+ """
+ # Use standard request limiter for scheduling new delayed events.
+ # TODO: Instead apply ratelimiting based on the scheduled send time.
+ # See https://github.com/element-hq/synapse/issues/18021
+ await self._request_ratelimiter.ratelimit(requester)
+
+ self._event_creation_handler.validator.validate_builder(
+ self._event_creation_handler.event_builder_factory.for_room_version(
+ await self._store.get_room_version(room_id),
+ {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": str(requester.user),
+ **({"state_key": state_key} if state_key is not None else {}),
+ },
+ )
+ )
+
+ creation_ts = self._get_current_ts()
+
+ delay_id, next_send_ts = await self._store.add_delayed_event(
+ user_localpart=requester.user.localpart,
+ device_id=requester.device_id,
+ creation_ts=creation_ts,
+ room_id=room_id,
+ event_type=event_type,
+ state_key=state_key,
+ origin_server_ts=origin_server_ts,
+ content=content,
+ delay=delay,
+ )
+
+ if self._repl_client is not None:
+ # NOTE: If this throws, the delayed event will remain in the DB and
+ # will be picked up once the main worker gets another delayed event.
+ await self._repl_client(
+ instance_name=MAIN_PROCESS_INSTANCE_NAME,
+ next_send_ts=next_send_ts,
+ )
+ elif self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at(next_send_ts)
+
+ return delay_id
+
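+    # Called via replication when a worker schedules a delayed event, so that
+    # the main process can update its next-send timer.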
+ def on_added(self, next_send_ts: int) -> None:
+ next_send_ts = Timestamp(next_send_ts)
+ if self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at(next_send_ts)
+
+ async def cancel(self, requester: Requester, delay_id: str) -> None:
+ """
+ Cancels the scheduled delivery of the matching delayed event.
+
+ Args:
+ requester: The owner of the delayed event to act on.
+ delay_id: The ID of the delayed event to act on.
+
+ Raises:
+ NotFoundError: if no matching delayed event could be found.
+ """
+ assert self._is_master
+ await self._delayed_event_mgmt_ratelimiter.ratelimit(
+ requester,
+ (requester.user.to_string(), requester.device_id),
+ )
+ await self._initialized_from_db
+
+ next_send_ts = await self._store.cancel_delayed_event(
+ delay_id=delay_id,
+ user_localpart=requester.user.localpart,
+ )
+
+ if self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at_or_none(next_send_ts)
+
+ async def restart(self, requester: Requester, delay_id: str) -> None:
+ """
+ Restarts the scheduled delivery of the matching delayed event.
+
+ Args:
+ requester: The owner of the delayed event to act on.
+ delay_id: The ID of the delayed event to act on.
+
+ Raises:
+ NotFoundError: if no matching delayed event could be found.
+ """
+ assert self._is_master
+ await self._delayed_event_mgmt_ratelimiter.ratelimit(
+ requester,
+ (requester.user.to_string(), requester.device_id),
+ )
+ await self._initialized_from_db
+
+ next_send_ts = await self._store.restart_delayed_event(
+ delay_id=delay_id,
+ user_localpart=requester.user.localpart,
+ current_ts=self._get_current_ts(),
+ )
+
+ if self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at(next_send_ts)
+
+ async def send(self, requester: Requester, delay_id: str) -> None:
+ """
+ Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
+
+ Args:
+ requester: The owner of the delayed event to act on.
+ delay_id: The ID of the delayed event to act on.
+
+ Raises:
+ NotFoundError: if no matching delayed event could be found.
+ """
+ assert self._is_master
+ # Use standard request limiter for sending delayed events on-demand,
+ # as an on-demand send is similar to sending a regular event.
+ await self._request_ratelimiter.ratelimit(requester)
+ await self._initialized_from_db
+
+ event, next_send_ts = await self._store.process_target_delayed_event(
+ delay_id=delay_id,
+ user_localpart=requester.user.localpart,
+ )
+
+ if self._next_send_ts_changed(next_send_ts):
+ self._schedule_next_at_or_none(next_send_ts)
+
+ await self._send_event(
+ DelayedEventDetails(
+ delay_id=DelayID(delay_id),
+ user_localpart=UserLocalpart(requester.user.localpart),
+ room_id=event.room_id,
+ type=event.type,
+ state_key=event.state_key,
+ origin_server_ts=event.origin_server_ts,
+ content=event.content,
+ device_id=event.device_id,
+ )
+ )
+
+ async def _send_on_timeout(self) -> None:
+ self._next_delayed_event_call = None
+
+ events, next_send_ts = await self._store.process_timeout_delayed_events(
+ self._get_current_ts()
+ )
+
+ if next_send_ts:
+ self._schedule_next_at(next_send_ts)
+
+ await self._send_events(events)
+
+ async def _send_events(self, events: List[DelayedEventDetails]) -> None:
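+        # Send at most one event per (room, type, state_key): once a piece of
+        # state has been sent, any further delayed events targeting it are
+        # skipped and cleaned up below.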
+ sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set()
+ for event in events:
+ if event.state_key is not None:
+ state_info = (event.room_id, event.type, event.state_key)
+ if state_info in sent_state:
+ continue
+ else:
+ state_info = None
+ try:
+ # TODO: send in background if message event or non-conflicting state event
+ await self._send_event(event)
+ if state_info is not None:
+ sent_state.add(state_info)
+ except Exception:
+ logger.exception("Failed to send delayed event")
+
+ for room_id, event_type, state_key in sent_state:
+ await self._store.delete_processed_delayed_state_events(
+ room_id=str(room_id),
+ event_type=event_type,
+ state_key=state_key,
+ )
+
+ def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None:
+ if next_send_ts is not None:
+ self._schedule_next_at(next_send_ts)
+ elif self._next_delayed_event_call is not None:
+ self._next_delayed_event_call.cancel()
+ self._next_delayed_event_call = None
+
+ def _schedule_next_at(self, next_send_ts: Timestamp) -> None:
+ delay = next_send_ts - self._get_current_ts()
+ delay_sec = delay / 1000 if delay > 0 else 0
+
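+        # Reuse the pending delayed call if there is one, rather than
+        # cancelling it and registering a new one.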
+ if self._next_delayed_event_call is None:
+ self._next_delayed_event_call = self._clock.call_later(
+ delay_sec,
+ run_as_background_process,
+ "_send_on_timeout",
+ self._send_on_timeout,
+ )
+ else:
+ self._next_delayed_event_call.reset(delay_sec)
+
+ async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
+ """Return all pending delayed events requested by the given user."""
+ await self._delayed_event_mgmt_ratelimiter.ratelimit(
+ requester,
+ (requester.user.to_string(), requester.device_id),
+ )
+ return await self._store.get_all_delayed_events_for_user(
+ requester.user.localpart
+ )
+
+ async def _send_event(
+ self,
+ event: DelayedEventDetails,
+ txn_id: Optional[str] = None,
+ ) -> None:
+ user_id = UserID(event.user_localpart, self._config.server.server_name)
+ user_id_str = user_id.to_string()
+ # Create a new requester from what data is currently available
+ requester = create_requester(
+ user_id,
+ is_guest=await self._store.is_guest(user_id_str),
+ device_id=event.device_id,
+ )
+
+ try:
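+            # Membership events must go through the room member handler; all
+            # other events go through the generic event-creation path.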
+ if event.state_key is not None and event.type == EventTypes.Member:
+ membership = event.content.get("membership")
+ assert membership is not None
+ event_id, _ = await self._room_member_handler.update_membership(
+ requester,
+ target=UserID.from_string(event.state_key),
+ room_id=event.room_id.to_string(),
+ action=membership,
+ content=event.content,
+ origin_server_ts=event.origin_server_ts,
+ )
+ else:
+ event_dict: JsonDict = {
+ "type": event.type,
+ "content": event.content,
+ "room_id": event.room_id.to_string(),
+ "sender": user_id_str,
+ }
+
+ if event.origin_server_ts is not None:
+ event_dict["origin_server_ts"] = event.origin_server_ts
+
+ if event.state_key is not None:
+ event_dict["state_key"] = event.state_key
+
+ (
+ sent_event,
+ _,
+ ) = await self._event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ event_dict,
+ txn_id=txn_id,
+ )
+ event_id = sent_event.event_id
+ except ShadowBanError:
+ event_id = generate_fake_event_id()
+ finally:
+ # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
+ try:
+ await self._store.delete_processed_delayed_event(
+ event.delay_id, event.user_localpart
+ )
+ except Exception:
+ logger.exception("Failed to delete processed delayed event")
+
+ set_tag("event_id", event_id)
+
+ def _get_current_ts(self) -> Timestamp:
+ return Timestamp(self._clock.time_msec())
+
+ def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool:
+ # The DB alone knows if the next send time changed after adding/modifying
+ # a delayed event, but if we were to ever miss updating our delayed call's
+        # firing time, we may miss other updates. So, keep track of changes to
+        # the next send time here instead of in the DB.
+ cached_next_send_ts = (
+ int(self._next_delayed_event_call.getTime() * 1000)
+ if self._next_delayed_event_call is not None
+ else None
+ )
+ return next_send_ts != cached_next_send_ts
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4fc6fcd7ae..f8b547bbed 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -20,10 +20,21 @@
#
#
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple
+from threading import Lock
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
from synapse.api import errors
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import (
Codes,
FederationDeniedError,
@@ -38,6 +49,8 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
+from synapse.storage.databases.main.roommember import EventIdMembership
+from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
DeviceListUpdates,
JsonDict,
@@ -151,6 +164,8 @@ class DeviceWorkerHandler:
raise errors.NotFoundError()
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
+
+ device = dict(device)
_update_device_from_client_ips(device, ips)
set_tag("device", str(device))
@@ -211,7 +226,6 @@ class DeviceWorkerHandler:
return changed
@trace
- @measure_func("device.get_user_ids_changed")
@cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
@@ -222,129 +236,113 @@ class DeviceWorkerHandler:
set_tag("user_id", user_id)
set_tag("from_token", str(from_token))
- now_room_key = self.store.get_room_max_token()
-
- room_ids = await self.store.get_rooms_for_user(user_id)
- changed = await self.get_device_changes_in_shared_rooms(
- user_id, room_ids, from_token
- )
+ now_token = self._event_sources.get_current_token()
- # Then work out if any users have since joined
- rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ # We need to work out all the different membership changes for the user
+        # and the users they share a room with, to pass to
+ # `generate_sync_entry_for_device_list`. See its docstring for details
+ # on the data required.
- member_events = await self.store.get_membership_changes_for_user(
- user_id, from_token.room_key, now_room_key
- )
- rooms_changed.update(event.room_id for event in member_events)
-
- stream_ordering = from_token.room_key.stream
-
- possibly_changed = set(changed)
- possibly_left = set()
- for room_id in rooms_changed:
- # Check if the forward extremities have changed. If not then we know
- # the current state won't have changed, and so we can skip this room.
- try:
- if not await self.store.have_room_forward_extremities_changed_since(
- room_id, stream_ordering
- ):
- continue
- except errors.StoreError:
- pass
+ joined_room_ids = await self.store.get_rooms_for_user(user_id)
- current_state_ids = await self._state_storage.get_current_state_ids(
- room_id, await_full_state=False
+ # Get the set of rooms that the user has joined/left
+ membership_changes = (
+ await self.store.get_current_state_delta_membership_changes_for_user(
+ user_id, from_key=from_token.room_key, to_key=now_token.room_key
)
+ )
- # The user may have left the room
- # TODO: Check if they actually did or if we were just invited.
- if room_id not in room_ids:
- for etype, state_key in current_state_ids.keys():
- if etype != EventTypes.Member:
- continue
- possibly_left.add(state_key)
- continue
-
- # Fetch the current state at the time.
- try:
- event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
- room_id, stream_ordering=stream_ordering
- )
- except errors.StoreError:
- # we have purged the stream_ordering index since the stream
- # ordering: treat it the same as a new room
- event_ids = []
-
- # special-case for an empty prev state: include all members
- # in the changed list
- if not event_ids:
- log_kv(
- {"event": "encountered empty previous state", "room_id": room_id}
- )
- for etype, state_key in current_state_ids.keys():
- if etype != EventTypes.Member:
- continue
- possibly_changed.add(state_key)
- continue
-
- current_member_id = current_state_ids.get((EventTypes.Member, user_id))
- if not current_member_id:
+        # Check for newly joined or left rooms. We need to make sure a room
+        # still counts as newly joined when its membership goes from
+        # join -> leave -> join again.
+ newly_joined_rooms: Set[str] = set()
+ newly_left_rooms: Set[str] = set()
+ for change in membership_changes:
+ # We check for changes in "joinedness", i.e. if the membership has
+ # changed to or from JOIN.
+ if change.membership == Membership.JOIN:
+ if change.prev_membership != Membership.JOIN:
+ newly_joined_rooms.add(change.room_id)
+ newly_left_rooms.discard(change.room_id)
+ elif change.prev_membership == Membership.JOIN:
+ newly_joined_rooms.discard(change.room_id)
+ newly_left_rooms.add(change.room_id)
+
+ # We now work out if any other users have since joined or left the rooms
+ # the user is currently in.
+
+ # List of membership changes per room
+ room_to_deltas: Dict[str, List[StateDelta]] = {}
+ # The set of event IDs of membership events (so we can fetch their
+ # associated membership).
+ memberships_to_fetch: Set[str] = set()
+
+ # TODO: Only pull out membership events?
+ state_changes = await self.store.get_current_state_deltas_for_rooms(
+ joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key
+ )
+ for delta in state_changes:
+ if delta.event_type != EventTypes.Member:
continue
- # mapping from event_id -> state_dict
- prev_state_ids = await self._state_storage.get_state_ids_for_events(
- event_ids,
- await_full_state=False,
+ room_to_deltas.setdefault(delta.room_id, []).append(delta)
+ if delta.event_id:
+ memberships_to_fetch.add(delta.event_id)
+ if delta.prev_event_id:
+ memberships_to_fetch.add(delta.prev_event_id)
+
+ # Fetch all the memberships for the membership events
+ event_id_to_memberships: Mapping[str, Optional[EventIdMembership]] = {}
+ if memberships_to_fetch:
+ event_id_to_memberships = await self.store.get_membership_from_event_ids(
+ memberships_to_fetch
)
- # Check if we've joined the room? If so we just blindly add all the users to
- # the "possibly changed" users.
- for state_dict in prev_state_ids.values():
- member_event = state_dict.get((EventTypes.Member, user_id), None)
- if not member_event or member_event != current_member_id:
- for etype, state_key in current_state_ids.keys():
- if etype != EventTypes.Member:
- continue
- possibly_changed.add(state_key)
- break
-
- # If there has been any change in membership, include them in the
- # possibly changed list. We'll check if they are joined below,
- # and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.items():
- etype, state_key = key
- if etype != EventTypes.Member:
- continue
-
- # check if this member has changed since any of the extremities
- # at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.values():
- prev_event_id = state_dict.get(key, None)
- if not prev_event_id or prev_event_id != event_id:
- if state_key != user_id:
- possibly_changed.add(state_key)
- break
-
- if possibly_changed or possibly_left:
- possibly_joined = possibly_changed
- possibly_left = possibly_changed | possibly_left
-
- # Double check if we still share rooms with the given user.
- users_rooms = await self.store.get_rooms_for_users(possibly_left)
- for changed_user_id, entries in users_rooms.items():
- if any(rid in room_ids for rid in entries):
- possibly_left.discard(changed_user_id)
- else:
- possibly_joined.discard(changed_user_id)
-
- else:
- possibly_joined = set()
- possibly_left = set()
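+        # The membership states under which we should be tracking a user's
+        # device list.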
+ joined_invited_knocked = (
+ Membership.JOIN,
+ Membership.INVITE,
+ Membership.KNOCK,
+ )
- device_list_updates = DeviceListUpdates(
- changed=possibly_joined,
- left=possibly_left,
+        # We now want to find any users that have newly joined, been invited or
+        # knocked, or newly left, similarly to above.
+ newly_joined_or_invited_or_knocked_users: Set[str] = set()
+ newly_left_users: Set[str] = set()
+ for _, deltas in room_to_deltas.items():
+ for delta in deltas:
+ # Get the prev/new memberships for the delta
+ new_membership = None
+ prev_membership = None
+ if delta.event_id:
+ m = event_id_to_memberships.get(delta.event_id)
+ if m is not None:
+ new_membership = m.membership
+ if delta.prev_event_id:
+ m = event_id_to_memberships.get(delta.prev_event_id)
+ if m is not None:
+ prev_membership = m.membership
+
+            # Check if the user has newly joined, been invited or knocked, or left.
+ if new_membership in joined_invited_knocked:
+ if prev_membership not in joined_invited_knocked:
+ newly_joined_or_invited_or_knocked_users.add(delta.state_key)
+ newly_left_users.discard(delta.state_key)
+ elif prev_membership in joined_invited_knocked:
+ newly_joined_or_invited_or_knocked_users.discard(delta.state_key)
+ newly_left_users.add(delta.state_key)
+
+ # Now we actually calculate the device list entry with the information
+ # calculated above.
+ device_list_updates = await self.generate_sync_entry_for_device_list(
+ user_id=user_id,
+ since_token=from_token,
+ now_token=now_token,
+ joined_room_ids=joined_room_ids,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
)
log_kv(
@@ -356,6 +354,87 @@ class DeviceWorkerHandler:
return device_list_updates
+ async def generate_sync_entry_for_device_list(
+ self,
+ user_id: str,
+ since_token: StreamToken,
+ now_token: StreamToken,
+ joined_room_ids: AbstractSet[str],
+ newly_joined_rooms: AbstractSet[str],
+ newly_joined_or_invited_or_knocked_users: AbstractSet[str],
+ newly_left_rooms: AbstractSet[str],
+ newly_left_users: AbstractSet[str],
+ ) -> DeviceListUpdates:
+ """Generate the DeviceListUpdates section of sync
+
+ Args:
+            user_id: The user for whom the device list updates are being computed
+            since_token: The point in the streams to return updates from
+            now_token: The point in the streams to return updates up to
+            joined_room_ids: Set of room IDs the user is currently joined to
+ newly_joined_rooms: Set of rooms user has joined since previous sync
+ newly_joined_or_invited_or_knocked_users: Set of users that have joined,
+ been invited to a room or are knocking on a room since
+ previous sync.
+ newly_left_rooms: Set of rooms user has left since previous sync
+ newly_left_users: Set of users that have left a room we're in since
+ previous sync
+ """
+ # Take a copy since these fields will be mutated later.
+ newly_joined_or_invited_or_knocked_users = set(
+ newly_joined_or_invited_or_knocked_users
+ )
+ newly_left_users = set(newly_left_users)
+
+ # We want to figure out what user IDs the client should refetch
+ # device keys for, and which users we aren't going to track changes
+ # for anymore.
+ #
+ # For the first step we check:
+ # a. if any users we share a room with have updated their devices,
+ # and
+ # b. we also check if we've joined any new rooms, or if a user has
+ # joined a room we're in.
+ #
+ # For the second step we just find any users we no longer share a
+ # room with by looking at all users that have left a room plus users
+ # that were in a room we've left.
+
+        # Step 1a: check for changes in devices of users we share a room with.
+ users_that_have_changed = await self.get_device_changes_in_shared_rooms(
+ user_id,
+ joined_room_ids,
+ from_token=since_token,
+ now_token=now_token,
+ )
+
+ # Step 1b, check for newly joined rooms
+ for room_id in newly_joined_rooms:
+ joined_users = await self.store.get_users_in_room(room_id)
+ newly_joined_or_invited_or_knocked_users.update(joined_users)
+
+ # TODO: Check that these users are actually new, i.e. either they
+ # weren't in the previous sync *or* they left and rejoined.
+ users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
+
+ user_signatures_changed = await self.store.get_users_whose_signatures_changed(
+ user_id, since_token.device_list_key
+ )
+ users_that_have_changed.update(user_signatures_changed)
+
+ # Now find users that we no longer track
+ for room_id in newly_left_rooms:
+ left_users = await self.store.get_users_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # Remove any users that we still share a room with.
+ left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
+        for left_user_id, entries in left_users_rooms.items():
+            if any(rid in joined_room_ids for rid in entries):
+                newly_left_users.discard(left_user_id)
+
+ return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users)
+
async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
if not self.hs.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -653,6 +732,40 @@ class DeviceHandler(DeviceWorkerHandler):
await self.notify_device_update(user_id, device_ids)
+ async def upsert_device(
+ self, user_id: str, device_id: str, display_name: Optional[str] = None
+ ) -> bool:
+ """Create or update a device
+
+ Args:
+ user_id: The user to update devices of.
+ device_id: The device to update.
+ display_name: The new display name for this device.
+
+ Returns:
+ True if the device was created, False if it was updated.
+
+ """
+
+ # Reject a new displayname which is too long.
+ self._check_device_name_length(display_name)
+
+ created = await self.store.store_device(
+ user_id,
+ device_id,
+ initial_device_display_name=display_name,
+ )
+
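+        # store_device leaves an existing device untouched (returning False),
+        # so in that case apply the new display name with an explicit update.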
+ if not created:
+ await self.store.update_device(
+ user_id,
+ device_id,
+ new_display_name=display_name,
+ )
+
+ await self.notify_device_update(user_id, [device_id])
+ return created
+
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
"""Update the given device
@@ -1125,7 +1238,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
# Attempt to resync out of sync device lists every 30s.
- self._resync_retry_in_progress = False
+ self._resync_retry_lock = Lock()
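+        # Guards against more than one retry running at a time; it is only ever
+        # acquired non-blocking, so retries are skipped rather than queued.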
self.clock.looping_call(
run_as_background_process,
30 * 1000,
@@ -1307,13 +1420,10 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
- if self._resync_retry_in_progress:
+        # If the lock cannot be acquired, return immediately instead of blocking here.
+ if not self._resync_retry_lock.acquire(blocking=False):
return
-
try:
- # Prevent another call of this function to retry resyncing device lists so
- # we don't send too many requests.
- self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = await self.store.get_user_ids_requiring_device_list_resync()
@@ -1354,8 +1464,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
e,
)
finally:
- # Allow future calls to retry resyncinc out of sync device lists.
- self._resync_retry_in_progress = False
+ self._resync_retry_lock.release()
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad2b0f5fcc..48c7d411d5 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -21,9 +21,7 @@
import logging
import string
-from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
-
-from typing_extensions import Literal
+from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Sequence
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
@@ -265,9 +263,9 @@ class DirectoryHandler:
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
- result: Optional[RoomAliasMapping] = (
- await self.get_association_from_room_alias(room_alias)
- )
+ result: Optional[
+ RoomAliasMapping
+ ] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@@ -512,11 +510,9 @@ class DirectoryHandler:
raise SynapseError(403, "Not allowed to publish room")
# Check if publishing is blocked by a third party module
- allowed_by_third_party_rules = (
- await (
- self._third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
- )
+ allowed_by_third_party_rules = await (
+ self._third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
)
)
if not allowed_by_third_party_rules:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f78e66ad0a..6171aaf29f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -39,6 +39,8 @@ from synapse.replication.http.devices import ReplicationUploadKeysForUserRestSer
from synapse.types import (
JsonDict,
JsonMapping,
+ ScheduledTask,
+ TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -70,6 +72,7 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
+ self._task_scheduler = hs.get_task_scheduler()
federation_registry = hs.get_federation_registry()
@@ -116,6 +119,10 @@ class E2eKeysHandler:
hs.config.experimental.msc3984_appservice_key_query
)
+ self._task_scheduler.register_action(
+ self._delete_old_one_time_keys_task, "delete_old_otks"
+ )
+
@trace
@cancellable
async def query_devices(
@@ -151,7 +158,37 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
+
+ async def filter_device_key_query(
+ query: Dict[str, List[str]],
+ ) -> Dict[str, List[str]]:
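+                # Strip out user IDs the requester may not query: invalid user
+                # IDs are always dropped, and when MSC4263 is enabled, so are
+                # users who don't share a room with the requester.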
+ if not self.config.experimental.msc4263_limit_key_queries_to_users_who_share_rooms:
+ # Only ignore invalid user IDs, which is the same behaviour as if
+ # the user existed but had no keys.
+ return {
+ user_id: v
+ for user_id, v in query.items()
+ if UserID.is_valid(user_id)
+ }
+
+ # Strip invalid user IDs and user IDs the requesting user does not share rooms with.
+ valid_user_ids = [
+ user_id for user_id in query.keys() if UserID.is_valid(user_id)
+ ]
+ allowed_user_ids = set(
+ await self.store.do_users_share_a_room_joined_or_invited(
+ from_user_id, valid_user_ids
+ )
+ )
+ return {
+ user_id: v
+ for user_id, v in query.items()
+ if user_id in allowed_user_ids
+ }
+
+ device_keys_query: Dict[str, List[str]] = await filter_device_key_query(
+ query_body.get("device_keys", {})
+ )
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -159,11 +196,6 @@ class E2eKeysHandler:
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
- if not UserID.is_valid(user_id):
- # Ignore invalid user IDs, which is the same behaviour as if
- # the user existed but had no keys.
- continue
-
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
@@ -615,7 +647,7 @@ class E2eKeysHandler:
3. Attempt to fetch fallback keys from the database.
Args:
- local_query: An iterable of tuples of (user ID, device ID, algorithm).
+ local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys.
Returns:
@@ -1156,7 +1188,7 @@ class E2eKeysHandler:
devices = devices[user_id]
except SynapseError as e:
failure = _exception_to_failure(e)
- failures[user_id] = {device: failure for device in signatures.keys()}
+ failures[user_id] = dict.fromkeys(signatures.keys(), failure)
return signature_list, failures
for device_id, device in signatures.items():
@@ -1296,7 +1328,7 @@ class E2eKeysHandler:
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
- failures[user] = {device_id: failure for device_id in devicemap.keys()}
+ failures[user] = dict.fromkeys(devicemap.keys(), failure)
return signature_list, failures
for target_user, devicemap in signatures.items():
@@ -1337,9 +1369,7 @@ class E2eKeysHandler:
# other devices were signed -- mark those as failures
logger.debug("upload signature: too many devices specified")
failure = _exception_to_failure(NotFoundError("Unknown device"))
- failures[target_user] = {
- device: failure for device in other_devices
- }
+ failures[target_user] = dict.fromkeys(other_devices, failure)
if user_signing_key_id in master_key.get("signatures", {}).get(
user_id, {}
@@ -1360,9 +1390,7 @@ class E2eKeysHandler:
except SynapseError as e:
failure = _exception_to_failure(e)
if device_id is None:
- failures[target_user] = {
- device_id: failure for device_id in devicemap.keys()
- }
+ failures[target_user] = dict.fromkeys(devicemap.keys(), failure)
else:
failures.setdefault(target_user, {})[device_id] = failure
@@ -1574,6 +1602,45 @@ class E2eKeysHandler:
return True
return False
+ async def _delete_old_one_time_keys_task(
+ self, task: ScheduledTask
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ """Scheduler task to delete old one time keys.
+
+ Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility
+ that it could still have old OTKs that the client has dropped. This task is scheduled exactly once
+ by a database schema delta file, and it clears out old one-time-keys that look like they came from libolm.
+ """
+ last_user = task.result.get("from_user", "") if task.result else ""
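+        # Resume from where a previous run left off: progress is stored in the
+        # task's result under "from_user".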
+ while True:
+ # We process users in batches of 100
+ users, rowcount = await self.store.delete_old_otks_for_next_user_batch(
+ last_user, 100
+ )
+ if len(users) == 0:
+ # We're done!
+ return TaskStatus.COMPLETE, None, None
+
+ logger.debug(
+ "Deleted %i old one-time-keys for users '%s'..'%s'",
+ rowcount,
+ users[0],
+ users[-1],
+ )
+ last_user = users[-1]
+
+ # Store our progress
+ await self._task_scheduler.update_task(
+ task.id, result={"from_user": last_user}
+ )
+
+ # Sleep a little before doing the next user.
+ #
+ # matrix.org has about 15M users in the e2e_one_time_keys_json table
+ # (comprising 20M devices). We want this to take about a week, so we need
+ # to do about one batch of 100 users every 4 seconds.
+ await self.clock.sleep(4)
+
def _check_cross_signing_key(
key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f397911f28..623fd33f13 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -20,9 +20,7 @@
#
import logging
-from typing import TYPE_CHECKING, Dict, Optional, cast
-
-from typing_extensions import Literal
+from typing import TYPE_CHECKING, Dict, Literal, Optional, cast
from synapse.api.errors import (
Codes,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 299588e476..ff751d25f6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -78,6 +78,7 @@ from synapse.replication.http.federation import (
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.invite_rule import InviteRule
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
@@ -210,7 +211,7 @@ class FederationHandler:
@tag_args
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int, record_time: bool = True
- ) -> bool:
+ ) -> None:
"""Checks the database to see if we should backfill before paginating,
and if so do.
@@ -224,8 +225,6 @@ class FederationHandler:
should back paginate.
record_time: Whether to record the time it takes to backfill.
- Returns:
- True if we actually tried to backfill something, otherwise False.
"""
# Starting the processing time here so we can include the room backfill
# linearizer lock queue in the timing
@@ -251,7 +250,7 @@ class FederationHandler:
limit: int,
*,
processing_start_time: Optional[int],
- ) -> bool:
+ ) -> None:
"""
Checks whether the `current_depth` is at or approaching any backfill
points in the room and if so, will backfill. We only care about
@@ -325,7 +324,7 @@ class FederationHandler:
limit=1,
)
if not have_later_backfill_points:
- return False
+ return None
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
@@ -345,15 +344,15 @@ class FederationHandler:
)
-        # We return `False` because we're backfilling in the background and there is
-        # no new events immediately for the caller to know about yet.
-        return False
+        # We return immediately because we're backfilling in the background and
+        # there are no new events for the caller to know about yet.
+        return None
# Even after recursing with `MAX_DEPTH`, we didn't find any
# backward extremities to backfill from.
if not sorted_backfill_points:
logger.debug(
- "_maybe_backfill_inner: Not backfilling as no backward extremeties found."
+ "_maybe_backfill_inner: Not backfilling as no backward extremities found."
)
- return False
+ return None
# If we're approaching an extremity we trigger a backfill, otherwise we
# no-op.
@@ -372,7 +371,7 @@ class FederationHandler:
current_depth,
limit,
)
- return False
+ return None
# For performance's sake, we only want to paginate from a particular extremity
# if we can actually see the events we'll get. Otherwise, we'd just spend a lot
@@ -440,7 +439,7 @@ class FederationHandler:
logger.debug(
"_maybe_backfill_inner: found no extremities which would be visible"
)
- return False
+ return None
logger.debug(
"_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
@@ -463,7 +462,7 @@ class FederationHandler:
)
)
- async def try_backfill(domains: StrCollection) -> bool:
+ async def try_backfill(domains: StrCollection) -> None:
# TODO: Should we try multiple of these at a time?
# Number of contacted remote homeservers that have denied our backfill
@@ -486,7 +485,7 @@ class FederationHandler:
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
- return True
+ return None
except NotRetryingDestination as e:
logger.info("_maybe_backfill_inner: %s", e)
continue
@@ -510,7 +509,7 @@ class FederationHandler:
)
denied_count += 1
if denied_count >= max_denied_count:
- return False
+ return None
continue
logger.info("Failed to backfill from %s because %s", dom, e)
@@ -526,7 +525,7 @@ class FederationHandler:
)
denied_count += 1
if denied_count >= max_denied_count:
- return False
+ return None
continue
logger.info("Failed to backfill from %s because %s", dom, e)
@@ -538,7 +537,7 @@ class FederationHandler:
logger.exception("Failed to backfill from %s because %s", dom, e)
continue
- return False
+ return None
# If we have the `processing_start_time`, then we can make an
# observation. We wouldn't have the `processing_start_time` in the case
@@ -550,14 +549,9 @@ class FederationHandler:
(processing_end_time - processing_start_time) / 1000
)
- success = await try_backfill(likely_domains)
- if success:
- return True
-
# TODO: we could also try servers which were previously in the room, but
# are no longer.
-
- return False
+ return await try_backfill(likely_domains)
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
@@ -880,6 +874,9 @@ class FederationHandler:
if stripped_room_state is None:
raise KeyError("Missing 'knock_room_state' field in send_knock response")
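+        # `knock_room_state` comes from a remote server, so check its type
+        # before trusting it.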
+ if not isinstance(stripped_room_state, list):
+ raise TypeError("'knock_room_state' has wrong type")
+
event.unsigned["knock_room_state"] = stripped_room_state
context = EventContext.for_outlier(self._storage_controllers)
@@ -1001,11 +998,11 @@ class FederationHandler:
)
if include_auth_user_id:
- event_content[EventContentFields.AUTHORISING_USER] = (
- await self._event_auth_handler.get_user_which_could_invite(
- room_id,
- state_ids,
- )
+ event_content[
+ EventContentFields.AUTHORISING_USER
+ ] = await self._event_auth_handler.get_user_which_could_invite(
+ room_id,
+ state_ids,
)
builder = self.event_builder_factory.for_room_version(
@@ -1086,6 +1083,20 @@ class FederationHandler:
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
+ # check the invitee's configuration and apply rules
+ invite_config = await self.store.get_invite_config_for_user(event.state_key)
+ rule = invite_config.get_invite_rule(event.sender)
+ if rule == InviteRule.BLOCK:
+            logger.info(
+                "Automatically rejecting invite from %s due to the invite filtering rules of %s",
+                event.sender,
+                event.state_key,
+            )
+ raise SynapseError(
+ 403,
+ "You are not permitted to invite this user.",
+ errcode=Codes.INVITE_BLOCKED,
+ )
+ # InviteRule.IGNORE is handled at the sync layer
+
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by
@@ -1309,9 +1320,9 @@ class FederationHandler:
if state_key is not None:
# the event was not rejected (get_event raises a NotFoundError for rejected
# events) so the state at the event should include the event itself.
- assert (
- state_map.get((event.type, state_key)) == event.event_id
- ), "State at event did not include event itself"
+ assert state_map.get((event.type, state_key)) == event.event_id, (
+ "State at event did not include event itself"
+ )
# ... but we need the state *before* that event
if "replaces_state" in event.unsigned:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c85deaed56..1e738f484f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -151,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
+ self._state_store = hs.get_datastores().state
+ self._state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -580,7 +582,9 @@ class FederationEventHandler:
room_version.identifier,
state_maps_to_resolve,
event_map=None,
- state_res_store=StateResolutionStore(self._store),
+ state_res_store=StateResolutionStore(
+ self._store, self._state_deletion_store
+ ),
)
)
else:
@@ -1179,7 +1183,9 @@ class FederationEventHandler:
room_version,
state_maps,
event_map={event_id: event},
- state_res_store=StateResolutionStore(self._store),
+ state_res_store=StateResolutionStore(
+ self._store, self._state_deletion_store
+ ),
)
except Exception as e:
@@ -1874,7 +1880,9 @@ class FederationEventHandler:
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
- state_res_store=StateResolutionStore(self._store),
+ state_res_store=StateResolutionStore(
+ self._store, self._state_deletion_store
+ ),
)
)
else:
@@ -2014,7 +2022,9 @@ class FederationEventHandler:
room_version,
state_sets,
event_map=None,
- state_res_store=StateResolutionStore(self._store),
+ state_res_store=StateResolutionStore(
+ self._store, self._state_deletion_store
+ ),
)
)
else:
@@ -2272,8 +2282,9 @@ class FederationEventHandler:
event_and_contexts, backfilled=backfilled
)
- # After persistence we always need to notify replication there may
- # be new data.
+ # After persistence, we never notify clients (wake up `/sync` streams) about
+ # backfilled events but it's important to let all the workers know about any
+ # new event (backfilled or not) because TODO
self._notifier.notify_replication()
if self._ephemeral_messages_enabled:
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
deleted file mode 100644
index cb31d65aa9..0000000000
--- a/synapse/handlers/identity.py
+++ /dev/null
@@ -1,811 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-"""Utilities for interacting with Identity Servers"""
-import logging
-import urllib.parse
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
-
-import attr
-
-from synapse.api.errors import (
- CodeMessageException,
- Codes,
- HttpResponseException,
- SynapseError,
-)
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.http import RequestTimedOutError
-from synapse.http.client import SimpleHttpClient
-from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, Requester
-from synapse.util import json_decoder
-from synapse.util.hash import sha256_and_url_safe_base64
-from synapse.util.stringutils import (
- assert_valid_client_secret,
- random_string,
- valid_id_server_location,
-)
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-id_server_scheme = "https://"
-
-
-class IdentityHandler:
- def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
- # An HTTP client for contacting trusted URLs.
- self.http_client = SimpleHttpClient(hs)
- # An HTTP client for contacting identity servers specified by clients.
- self._http_client = SimpleHttpClient(
- hs,
- ip_blocklist=hs.config.server.federation_ip_range_blocklist,
- ip_allowlist=hs.config.server.federation_ip_range_allowlist,
- )
- self.federation_http_client = hs.get_federation_http_client()
- self.hs = hs
-
- self._web_client_location = hs.config.email.invite_client_location
-
- # Ratelimiters for `/requestToken` endpoints.
- self._3pid_validation_ratelimiter_ip = Ratelimiter(
- store=self.store,
- clock=hs.get_clock(),
- cfg=hs.config.ratelimiting.rc_3pid_validation,
- )
- self._3pid_validation_ratelimiter_address = Ratelimiter(
- store=self.store,
- clock=hs.get_clock(),
- cfg=hs.config.ratelimiting.rc_3pid_validation,
- )
-
- async def ratelimit_request_token_requests(
- self,
- request: SynapseRequest,
- medium: str,
- address: str,
- ) -> None:
- """Used to ratelimit requests to `/requestToken` by IP and address.
-
- Args:
- request: The associated request
- medium: The type of threepid, e.g. "msisdn" or "email"
- address: The actual threepid ID, e.g. the phone number or email address
- """
-
- await self._3pid_validation_ratelimiter_ip.ratelimit(
- None, (medium, request.getClientAddress().host)
- )
- await self._3pid_validation_ratelimiter_address.ratelimit(
- None, (medium, address)
- )
-
- async def threepid_from_creds(
- self, id_server: str, creds: Dict[str, str]
- ) -> Optional[JsonDict]:
- """
- Retrieve and validate a threepid identifier from a "credentials" dictionary against a
- given identity server
-
- Args:
- id_server: The identity server to validate 3PIDs against. Must be a
- complete URL including the protocol (http(s)://)
- creds: Dictionary containing the following keys:
- * client_secret|clientSecret: A unique secret str provided by the client
- * sid: The ID of the validation session
-
- Returns:
- A dictionary consisting of response params to the /getValidated3pid
- endpoint of the Identity Service API, or None if the threepid was not found
- """
- client_secret = creds.get("client_secret") or creds.get("clientSecret")
- if not client_secret:
- raise SynapseError(
- 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM
- )
- assert_valid_client_secret(client_secret)
-
- session_id = creds.get("sid")
- if not session_id:
- raise SynapseError(
- 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM
- )
-
- query_params = {"sid": session_id, "client_secret": client_secret}
-
- url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
-
- try:
- data = await self.http_client.get_json(url, query_params)
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
- except HttpResponseException as e:
- logger.info(
- "%s returned %i for threepid validation for: %s",
- id_server,
- e.code,
- creds,
- )
- return None
-
- # Old versions of Sydent return a 200 http code even on a failed validation
- # check. Thus, in addition to the HttpResponseException check above (which
- # checks for non-200 errors), we need to make sure validation_session isn't
- # actually an error, identified by the absence of a "medium" key
- # See https://github.com/matrix-org/sydent/issues/215 for details
- if "medium" in data:
- return data
-
- logger.info("%s reported non-validated threepid: %s", id_server, creds)
- return None
-
- async def bind_threepid(
- self,
- client_secret: str,
- sid: str,
- mxid: str,
- id_server: str,
- id_access_token: str,
- ) -> JsonDict:
- """Bind a 3PID to an identity server
-
- Args:
- client_secret: A unique secret provided by the client
- sid: The ID of the validation session
- mxid: The MXID to bind the 3PID to
- id_server: The domain of the identity server to query
- id_access_token: The access token to authenticate to the identity
- server with
-
- Raises:
- SynapseError: On any of the following conditions
- - the supplied id_server is not a valid identity server name
- - we failed to contact the supplied identity server
-
- Returns:
- The response from the identity server
- """
- logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
-
- if not valid_id_server_location(id_server):
- raise SynapseError(
- 400,
- "id_server must be a valid hostname with optional port and path components",
- )
-
- bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- headers = {"Authorization": create_id_access_token_header(id_access_token)}
-
- try:
- # Use the blacklisting http client as this call is only to identity servers
- # provided by a client
- data = await self._http_client.post_json_get_json(
- bind_url, bind_data, headers=headers
- )
-
- # Remember where we bound the threepid
- await self.store.add_user_bound_threepid(
- user_id=mxid,
- medium=data["medium"],
- address=data["address"],
- id_server=id_server,
- )
-
- return data
- except HttpResponseException as e:
- logger.error("3PID bind failed with Matrix error: %r", e)
- raise e.to_synapse_error()
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
- except CodeMessageException as e:
- data = json_decoder.decode(e.msg) # XXX WAT?
- return data
-
- async def try_unbind_threepid(
- self, mxid: str, medium: str, address: str, id_server: Optional[str]
- ) -> bool:
- """Attempt to remove a 3PID from one or more identity servers.
-
- Args:
- mxid: Matrix user ID of binding to be removed
- medium: The medium of the third-party ID.
- address: The address of the third-party ID.
- id_server: An identity server to attempt to unbind from. If None,
- attempt to remove the association from all identity servers
- known to potentially have it.
-
- Raises:
- SynapseError: If we failed to contact one or more identity servers.
-
- Returns:
- True on success, otherwise False if the identity server doesn't
- support unbinding (or no identity server to contact was found).
- """
- if id_server:
- id_servers = [id_server]
- else:
- id_servers = await self.store.get_id_servers_user_bound(
- mxid, medium, address
- )
-
- # We don't know where to unbind, so we don't have a choice but to return
- if not id_servers:
- return False
-
- changed = True
- for id_server in id_servers:
- changed &= await self._try_unbind_threepid_with_id_server(
- mxid, medium, address, id_server
- )
-
- return changed
-
- async def _try_unbind_threepid_with_id_server(
- self, mxid: str, medium: str, address: str, id_server: str
- ) -> bool:
- """Removes a binding from an identity server
-
- Args:
- mxid: Matrix user ID of binding to be removed
- medium: The medium of the third-party ID
- address: The address of the third-party ID
- id_server: Identity server to unbind from
-
- Raises:
- SynapseError: On any of the following conditions
- - the supplied id_server is not a valid identity server name
- - we failed to contact the supplied identity server
-
- Returns:
- True on success, otherwise False if the identity
- server doesn't support unbinding
- """
-
- if not valid_id_server_location(id_server):
- raise SynapseError(
- 400,
- "id_server must be a valid hostname with optional port and path components",
- )
-
- url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
- url_bytes = b"/_matrix/identity/v2/3pid/unbind"
-
- content = {
- "mxid": mxid,
- "threepid": {"medium": medium, "address": address},
- }
-
- # we abuse the federation http client to sign the request, but we have to send it
- # using the normal http client since we don't want the SRV lookup and want normal
- # 'browser-like' HTTPS.
- auth_headers = self.federation_http_client.build_auth_headers(
- destination=None,
- method=b"POST",
- url_bytes=url_bytes,
- content=content,
- destination_is=id_server.encode("ascii"),
- )
- headers = {b"Authorization": auth_headers}
-
- try:
- # Use the blacklisting http client as this call is only to identity servers
- # provided by a client
- await self._http_client.post_json_get_json(url, content, headers)
- changed = True
- except HttpResponseException as e:
- changed = False
- if e.code in (400, 404, 501):
- # The remote server probably doesn't support unbinding (yet)
- logger.warning("Received %d response while unbinding threepid", e.code)
- else:
- logger.error("Failed to unbind threepid on identity server: %s", e)
- raise SynapseError(500, "Failed to contact identity server")
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
-
- await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
-
- return changed
-
- async def send_threepid_validation(
- self,
- email_address: str,
- client_secret: str,
- send_attempt: int,
- send_email_func: Callable[[str, str, str, str], Awaitable],
- next_link: Optional[str] = None,
- ) -> str:
- """Send a threepid validation email for password reset or
- registration purposes
-
- Args:
- email_address: The user's email address
- client_secret: The provided client secret
- send_attempt: Which send attempt this is
- send_email_func: A function that takes an email address, token,
- client_secret and session_id, sends an email
- and returns an Awaitable.
- next_link: The URL to redirect the user to after validation
-
- Returns:
- The new session_id upon success
-
- Raises:
- SynapseError is an error occurred when sending the email
- """
- # Check that this email/client_secret/send_attempt combo is new or
- # greater than what we've seen previously
- session = await self.store.get_threepid_validation_session(
- "email", client_secret, address=email_address, validated=False
- )
-
- # Check to see if a session already exists and that it is not yet
- # marked as validated
- if session and session.validated_at is None:
- session_id = session.session_id
- last_send_attempt = session.last_send_attempt
-
- # Check that the send_attempt is higher than previous attempts
- if send_attempt <= last_send_attempt:
- # If not, just return a success without sending an email
- return session_id
- else:
- # An non-validated session does not exist yet.
- # Generate a session id
- session_id = random_string(16)
-
- if next_link:
- # Manipulate the next_link to add the sid, because the caller won't get
- # it until we send a response, by which time we've sent the mail.
- if "?" in next_link:
- next_link += "&"
- else:
- next_link += "?"
- next_link += "sid=" + urllib.parse.quote(session_id)
-
- # Generate a new validation token
- token = random_string(32)
-
- # Send the mail with the link containing the token, client_secret
- # and session_id
- try:
- await send_email_func(email_address, token, client_secret, session_id)
- except Exception:
- logger.exception(
- "Error sending threepid validation email to %s", email_address
- )
- raise SynapseError(500, "An error was encountered when sending the email")
-
- token_expires = (
- self.hs.get_clock().time_msec()
- + self.hs.config.email.email_validation_token_lifetime
- )
-
- await self.store.start_or_continue_validation_session(
- "email",
- email_address,
- session_id,
- client_secret,
- send_attempt,
- next_link,
- token,
- token_expires,
- )
-
- return session_id
-
- async def requestMsisdnToken(
- self,
- id_server: str,
- country: str,
- phone_number: str,
- client_secret: str,
- send_attempt: int,
- next_link: Optional[str] = None,
- ) -> JsonDict:
- """
- Request an external server send an SMS message on our behalf for the purposes of
- threepid validation.
- Args:
- id_server: The identity server to proxy to
- country: The country code of the phone number
- phone_number: The number to send the message to
- client_secret: The unique client_secret sends by the user
- send_attempt: Which attempt this is
- next_link: A link to redirect the user to once they submit the token
-
- Returns:
- The json response body from the server
- """
- params = {
- "country": country,
- "phone_number": phone_number,
- "client_secret": client_secret,
- "send_attempt": send_attempt,
- }
- if next_link:
- params["next_link"] = next_link
-
- try:
- data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
- params,
- )
- except HttpResponseException as e:
- logger.info("Proxied requestToken failed: %r", e)
- raise e.to_synapse_error()
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
-
- # we need to tell the client to send the token back to us, since it doesn't
- # otherwise know where to send it, so add submit_url response parameter
- # (see also MSC2078)
- data["submit_url"] = (
- self.hs.config.server.public_baseurl
- + "_matrix/client/unstable/add_threepid/msisdn/submit_token"
- )
- return data
-
- async def validate_threepid_session(
- self, client_secret: str, sid: str
- ) -> Optional[JsonDict]:
- """Validates a threepid session with only the client secret and session ID
- Tries validating against any configured account_threepid_delegates as well as locally.
-
- Args:
- client_secret: A secret provided by the client
- sid: The ID of the session
-
- Returns:
- The json response if validation was successful, otherwise None
- """
- # XXX: We shouldn't need to keep wrapping and unwrapping this value
- threepid_creds = {"client_secret": client_secret, "sid": sid}
-
- # We don't actually know which medium this 3PID is. Thus we first assume it's email,
- # and if validation fails we try msisdn
-
- # Try to validate as email
- if self.hs.config.email.can_verify_email:
- # Get a validated session matching these details
- validation_session = await self.store.get_threepid_validation_session(
- "email", client_secret, sid=sid, validated=True
- )
- if validation_session:
- return attr.asdict(validation_session)
-
- # Try to validate as msisdn
- if self.hs.config.registration.account_threepid_delegate_msisdn:
- # Ask our delegated msisdn identity server
- return await self.threepid_from_creds(
- self.hs.config.registration.account_threepid_delegate_msisdn,
- threepid_creds,
- )
-
- return None
-
- async def proxy_msisdn_submit_token(
- self, id_server: str, client_secret: str, sid: str, token: str
- ) -> JsonDict:
- """Proxy a POST submitToken request to an identity server for verification purposes
-
- Args:
- id_server: The identity server URL to contact
- client_secret: Secret provided by the client
- sid: The ID of the session
- token: The verification token
-
- Raises:
- SynapseError: If we failed to contact the identity server
-
- Returns:
- The response dict from the identity server
- """
- body = {"client_secret": client_secret, "sid": sid, "token": token}
-
- try:
- return await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
- body,
- )
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
- except HttpResponseException as e:
- logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
- raise SynapseError(400, "Error contacting the identity server")
-
- async def lookup_3pid(
- self, id_server: str, medium: str, address: str, id_access_token: str
- ) -> Optional[str]:
- """Looks up a 3pid in the passed identity server.
-
- Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
- medium: The type of the third party identifier (e.g. "email").
- address: The third party identifier (e.g. "foo@example.com").
- id_access_token: The access token to authenticate to the identity
- server with
-
- Returns:
- the matrix ID of the 3pid, or None if it is not recognized.
- """
-
- try:
- results = await self._lookup_3pid_v2(
- id_server, id_access_token, medium, address
- )
- return results
- except Exception as e:
- logger.warning("Error when looking up hashing details: %s", e)
- return None
-
- async def _lookup_3pid_v2(
- self, id_server: str, id_access_token: str, medium: str, address: str
- ) -> Optional[str]:
- """Looks up a 3pid in the passed identity server using v2 lookup.
-
- Args:
- id_server: The server name (including port, if required)
- of the identity server to use.
- id_access_token: The access token to authenticate to the identity server with
- medium: The type of the third party identifier (e.g. "email").
- address: The third party identifier (e.g. "foo@example.com").
-
- Returns:
- the matrix ID of the 3pid, or None if it is not recognised.
- """
- # Check what hashing details are supported by this identity server
- try:
- hash_details = await self._http_client.get_json(
- "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
- {"access_token": id_access_token},
- )
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
-
- if not isinstance(hash_details, dict):
- logger.warning(
- "Got non-dict object when checking hash details of %s%s: %s",
- id_server_scheme,
- id_server,
- hash_details,
- )
- raise SynapseError(
- 400,
- "Non-dict object from %s%s during v2 hash_details request: %s"
- % (id_server_scheme, id_server, hash_details),
- )
-
- # Extract information from hash_details
- supported_lookup_algorithms = hash_details.get("algorithms")
- lookup_pepper = hash_details.get("lookup_pepper")
- if (
- not supported_lookup_algorithms
- or not isinstance(supported_lookup_algorithms, list)
- or not lookup_pepper
- or not isinstance(lookup_pepper, str)
- ):
- raise SynapseError(
- 400,
- "Invalid hash details received from identity server %s%s: %s"
- % (id_server_scheme, id_server, hash_details),
- )
-
- # Check if any of the supported lookup algorithms are present
- if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
- # Perform a hashed lookup
- lookup_algorithm = LookupAlgorithm.SHA256
-
- # Hash address, medium and the pepper with sha256
- to_hash = "%s %s %s" % (address, medium, lookup_pepper)
- lookup_value = sha256_and_url_safe_base64(to_hash)
-
- elif LookupAlgorithm.NONE in supported_lookup_algorithms:
- # Perform a non-hashed lookup
- lookup_algorithm = LookupAlgorithm.NONE
-
- # Combine together plaintext address and medium
- lookup_value = "%s %s" % (address, medium)
-
- else:
- logger.warning(
- "None of the provided lookup algorithms of %s are supported: %s",
- id_server,
- supported_lookup_algorithms,
- )
- raise SynapseError(
- 400,
- "Provided identity server does not support any v2 lookup "
- "algorithms that this homeserver supports.",
- )
-
- # Authenticate with identity server given the access token from the client
- headers = {"Authorization": create_id_access_token_header(id_access_token)}
-
- try:
- lookup_results = await self._http_client.post_json_get_json(
- "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
- {
- "addresses": [lookup_value],
- "algorithm": lookup_algorithm,
- "pepper": lookup_pepper,
- },
- headers=headers,
- )
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
- except Exception as e:
- logger.warning("Error when performing a v2 3pid lookup: %s", e)
- raise SynapseError(
- 500, "Unknown error occurred during identity server lookup"
- )
-
- # Check for a mapping from what we looked up to an MXID
- if "mappings" not in lookup_results or not isinstance(
- lookup_results["mappings"], dict
- ):
- logger.warning("No results from 3pid lookup")
- return None
-
- # Return the MXID if it's available, or None otherwise
- mxid = lookup_results["mappings"].get(lookup_value)
- return mxid
-
- async def ask_id_server_for_third_party_invite(
- self,
- requester: Requester,
- id_server: str,
- medium: str,
- address: str,
- room_id: str,
- inviter_user_id: str,
- room_alias: str,
- room_avatar_url: str,
- room_join_rules: str,
- room_name: str,
- room_type: Optional[str],
- inviter_display_name: str,
- inviter_avatar_url: str,
- id_access_token: str,
- ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
- """
- Asks an identity server for a third party invite.
-
- Args:
- requester
- id_server: hostname + optional port for the identity server.
- medium: The literal string "email".
- address: The third party address being invited.
- room_id: The ID of the room to which the user is invited.
- inviter_user_id: The user ID of the inviter.
- room_alias: An alias for the room, for cosmetic notifications.
- room_avatar_url: The URL of the room's avatar, for cosmetic
- notifications.
- room_join_rules: The join rules of the email (e.g. "public").
- room_name: The m.room.name of the room.
- room_type: The type of the room from its m.room.create event (e.g "m.space").
- inviter_display_name: The current display name of the
- inviter.
- inviter_avatar_url: The URL of the inviter's avatar.
- id_access_token: The access token to authenticate to the identity
- server with
-
- Returns:
- A tuple containing:
- token: The token which must be signed to prove authenticity.
- public_keys ([{"public_key": str, "key_validity_url": str}]):
- public_key is a base64-encoded ed25519 public key.
- fallback_public_key: One element from public_keys.
- display_name: A user-friendly name to represent the invited user.
- """
- invite_config = {
- "medium": medium,
- "address": address,
- "room_id": room_id,
- "room_alias": room_alias,
- "room_avatar_url": room_avatar_url,
- "room_join_rules": room_join_rules,
- "room_name": room_name,
- "sender": inviter_user_id,
- "sender_display_name": inviter_display_name,
- "sender_avatar_url": inviter_avatar_url,
- }
-
- if room_type is not None:
- invite_config["room_type"] = room_type
-
- # If a custom web client location is available, include it in the request.
- if self._web_client_location:
- invite_config["org.matrix.web_client_location"] = self._web_client_location
-
- # Add the identity service access token to the JSON body and use the v2
- # Identity Service endpoints
- data = None
-
- key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
- id_server_scheme,
- id_server,
- )
-
- url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
- try:
- data = await self._http_client.post_json_get_json(
- url,
- invite_config,
- {"Authorization": create_id_access_token_header(id_access_token)},
- )
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
-
- token = data["token"]
- public_keys = data.get("public_keys", [])
- if "public_key" in data:
- fallback_public_key = {
- "public_key": data["public_key"],
- "key_validity_url": key_validity_url,
- }
- else:
- fallback_public_key = public_keys[0]
-
- if not public_keys:
- public_keys.append(fallback_public_key)
- display_name = data["display_name"]
- return token, public_keys, fallback_public_key, display_name
-
-
-def create_id_access_token_header(id_access_token: str) -> List[str]:
- """Create an Authorization header for passing to SimpleHttpClient as the header value
- of an HTTP request.
-
- Args:
- id_access_token: An identity server access token.
-
- Returns:
- The ascii-encoded bearer token encased in a list.
- """
- # Prefix with Bearer
- bearer_token = "Bearer %s" % id_access_token
-
- # Encode headers to standard ascii
- bearer_token.encode("ascii")
-
- # Return as a list as that's how SimpleHttpClient takes header values
- return [bearer_token]
-
-
-class LookupAlgorithm:
- """
- Supported hashing algorithms when performing a 3PID lookup.
-
- SHA256 - Hashing an (address, medium, pepper) combo with sha256, then url-safe base64
- encoding
- NONE - Not performing any hashing. Simply sending an (address, medium) combo in plaintext
- """
-
- SHA256 = "sha256"
- NONE = "none"
diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py
index 5fa7a305ad..400f3a59aa 100644
--- a/synapse/handlers/jwt.py
+++ b/synapse/handlers/jwt.py
@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
@@ -36,11 +36,12 @@ class JwtHandler:
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
+ self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
- def validate_login(self, login_submission: JsonDict) -> str:
+ def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
"""
Authenticates the user for the /login API
@@ -49,7 +50,8 @@ class JwtHandler:
(including 'type' and other relevant fields)
Returns:
- The user ID that is logging in.
+ A tuple of (user_id, display_name) of the user that is logging in.
+ If the JWT does not contain a display name, the second element of the tuple will be None.
Raises:
LoginError if there was an authentication problem.
@@ -109,4 +111,10 @@ class JwtHandler:
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
- return UserID(user, self.hs.hostname).to_string()
+ default_display_name = None
+ if self.jwt_display_name_claim:
+ display_name_claim = claims.get(self.jwt_display_name_claim)
+ if display_name_claim is not None:
+ default_display_name = display_name_claim
+
+ return UserID(user, self.hs.hostname).to_string(), default_display_name
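For illustration, a minimal standalone sketch of the new claim handling (the `extract_login` helper and its example payload are hypothetical, not Synapse code): the subject claim is mandatory, while the display name claim is optional and only consulted when configured.

from typing import Mapping, Optional, Tuple

def extract_login(
    claims: Mapping[str, object],
    subject_claim: str,
    display_name_claim: Optional[str],
) -> Tuple[str, Optional[str]]:
    # The subject claim must resolve to a string localpart.
    user = claims.get(subject_claim)
    if not isinstance(user, str):
        raise ValueError("JWT is missing the subject claim")
    # The display name claim is optional; absence yields None.
    default_display_name: Optional[str] = None
    if display_name_claim:
        value = claims.get(display_name_claim)
        if value is not None:
            default_display_name = str(value)
    return user, default_display_name

# extract_login({"sub": "alice", "name": "Alice"}, "sub", "name") -> ("alice", "Alice")
# extract_login({"sub": "alice"}, "sub", None) -> ("alice", None)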
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5aa48230ec..cb6de02309 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -143,9 +143,9 @@ class MessageHandler:
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
- assert (
- membership_event_id is not None
- ), "check_user_in_room_or_world_readable returned invalid data"
+ assert membership_event_id is not None, (
+ "check_user_in_room_or_world_readable returned invalid data"
+ )
room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
@@ -196,7 +196,9 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view
members of this room.
"""
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
+
user_id = requester.user.to_string()
if at_token:
@@ -240,9 +242,9 @@ class MessageHandler:
room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
# If the membership is not JOIN, then the event ID should exist.
- assert (
- membership_event_id is not None
- ), "check_user_in_room_or_world_readable returned invalid data"
+ assert membership_event_id is not None, (
+ "check_user_in_room_or_world_readable returned invalid data"
+ )
room_state_events = (
await self._state_storage_controller.get_state_for_events(
[membership_event_id], state_filter=state_filter
@@ -493,6 +495,7 @@ class EventCreationHandler:
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self._worker_lock_handler = hs.get_worker_locks_handler()
+ self._policy_handler = hs.get_room_policy_handler()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -642,11 +645,33 @@ class EventCreationHandler:
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
- if event_dict["type"] == EventTypes.Message:
- requester_suspended = await self.store.get_user_suspended_status(
- requester.user.to_string()
- )
- if requester_suspended:
+ requester_suspended = await self.store.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ if requester_suspended:
+ # We want to allow suspended users to perform "corrective" actions
+ # asked of them by server admins, such as redacting their messages and
+ # leaving rooms.
+ if event_dict["type"] in ["m.room.redaction", "m.room.member"]:
+ if event_dict["type"] == "m.room.redaction":
+ event = await self.store.get_event(
+ event_dict["content"]["redacts"], allow_none=True
+ )
+ if event:
+ if event.sender != requester.user.to_string():
+ raise SynapseError(
+ 403,
+ "You can only redact your own events while account is suspended.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+ if event_dict["type"] == "m.room.member":
+ if event_dict["content"]["membership"] != "leave":
+ raise SynapseError(
+ 403,
+ "Changing membership while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+ else:
raise SynapseError(
403,
"Sending messages while account is suspended is not allowed.",
@@ -1084,6 +1109,18 @@ class EventCreationHandler:
event.sender,
)
+ policy_allowed = await self._policy_handler.is_event_allowed(event)
+ if not policy_allowed:
+ logger.warning(
+ "Event not allowed by policy server, rejecting %s",
+ event.event_id,
+ )
+ raise SynapseError(
+ 403,
+ "This message has been rejected as probable spam",
+ Codes.FORBIDDEN,
+ )
+
spam_check_result = (
await self._spam_checker_module_callbacks.check_event_for_spam(
event
@@ -1095,7 +1132,7 @@ class EventCreationHandler:
[code, dict] = spam_check_result
raise SynapseError(
403,
- "This message had been rejected as probable spam",
+ "This message has been rejected as probable spam",
code,
dict,
)
@@ -1225,10 +1262,9 @@ class EventCreationHandler:
)
if prev_event_ids is not None:
- assert (
- len(prev_event_ids) <= 10
- ), "Attempting to create an event with %i prev_events" % (
- len(prev_event_ids),
+ assert len(prev_event_ids) <= 10, (
+ "Attempting to create an event with %i prev_events"
+ % (len(prev_event_ids),)
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@@ -1243,12 +1279,14 @@ class EventCreationHandler:
# Allow an event to have empty list of prev_event_ids
# only if it has auth_event_ids.
or auth_event_ids
- ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
+ ), (
+ "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
+ )
else:
# we now ought to have some prev_events (unless it's a create event).
- assert (
- builder.type == EventTypes.Create or prev_event_ids
- ), "Attempting to create a non-m.room.create event with no prev_events"
+ assert builder.type == EventTypes.Create or prev_event_ids, (
+ "Attempting to create a non-m.room.create event with no prev_events"
+ )
if for_batch:
assert prev_event_ids is not None
@@ -1439,6 +1477,12 @@ class EventCreationHandler:
)
return prev_event
+ if not event.is_state() and event.type in [
+ EventTypes.Message,
+ EventTypes.Encrypted,
+ ]:
+ await self.store.set_room_participation(event.user_id, event.room_id)
+
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here are
# invite rejections and rescinded knocks that we have generated ourselves.
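The suspension gate above can be summarised as: suspended users may still perform "corrective" actions, namely redacting their own events and leaving rooms, while everything else is refused. A condensed sketch under those assumptions (a standalone helper, not Synapse's actual call signature):

from typing import Mapping, Optional

def check_suspended_sender(
    event_type: str,
    content: Mapping[str, object],
    sender: str,
    redacted_event_sender: Optional[str],
) -> None:
    # Raises if a suspended sender may not send this event.
    if event_type == "m.room.redaction":
        # Redactions are allowed, but only of the sender's own events.
        if redacted_event_sender is not None and redacted_event_sender != sender:
            raise PermissionError("can only redact your own events while suspended")
    elif event_type == "m.room.member":
        # The only permitted membership change is leaving.
        if content.get("membership") != "leave":
            raise PermissionError("membership changes are not allowed while suspended")
    else:
        raise PermissionError("sending messages is not allowed while suspended")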
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 22b59829fa..4b85282c1e 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Type,
+ TypedDict,
TypeVar,
Union,
)
@@ -52,7 +53,6 @@ from pymacaroons.exceptions import (
MacaroonInitException,
MacaroonInvalidSignatureException,
)
-from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
@@ -382,7 +382,12 @@ class OidcProvider:
self._macaroon_generaton = macaroon_generator
self._config = provider
- self._callback_url: str = hs.config.oidc.oidc_callback_url
+
+ self._callback_url: str
+ if provider.redirect_uri is not None:
+ self._callback_url = provider.redirect_uri
+ else:
+ self._callback_url = hs.config.oidc.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
@@ -462,6 +467,10 @@ class OidcProvider:
self._sso_handler.register_identity_provider(self)
+ self.passthrough_authorization_parameters = (
+ provider.passthrough_authorization_parameters
+ )
+
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
"""Verifies the provider metadata.
@@ -578,6 +587,24 @@ class OidcProvider:
)
@property
+ def _uses_access_token(self) -> bool:
+ """Return True if the `access_token` will be used during the login process.
+
+ This is useful to determine whether the access token returned by the
+ identity provider, and any related metadata (such as the `at_hash` field
+ in the ID token), should be validated.
+ """
+ # Currently, Synapse only uses the access_token to fetch user metadata
+ # from the userinfo endpoint. Therefore we only have a single criterion
+ # to check right now, but this may change in the future and this function
+ # should be updated if more usages are introduced.
+ #
+ # For example, if we start to use the access_token given to us by the
+ # IdP for more things, such as accessing Resource Server APIs.
+ return self._uses_userinfo
+
+ @property
def issuer(self) -> str:
"""The issuer identifying this provider."""
return self._config.issuer
@@ -640,6 +667,11 @@ class OidcProvider:
elif self._config.pkce_method == "never":
metadata.pop("code_challenge_methods_supported", None)
+ if self._config.id_token_signing_alg_values_supported:
+ metadata["id_token_signing_alg_values_supported"] = (
+ self._config.id_token_signing_alg_values_supported
+ )
+
self._validate_metadata(metadata)
return metadata
@@ -943,9 +975,16 @@ class OidcProvider:
"nonce": nonce,
"client_id": self._client_auth.client_id,
}
- if "access_token" in token:
+ if self._uses_access_token and "access_token" in token:
# If we got an `access_token`, there should be an `at_hash` claim
- # in the `id_token` that we can check against.
+ # in the `id_token` that we can check against. Setting this
+ # instructs authlib to check the value of `at_hash` in the
+ # ID token.
+ #
+ # We only need to verify the access token if we actually make
+ # use of it. Which currently only happens when we need to fetch
+ # the user's information from the userinfo_endpoint. Thus, this
+ # check is also gated on self._uses_userinfo.
claims_params["access_token"] = token["access_token"]
claims_options = {"iss": {"values": [metadata["issuer"]]}}
@@ -995,14 +1034,27 @@ class OidcProvider:
when everything is done (or None for UI Auth)
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
-
Returns:
The redirect URL to the authorization endpoint.
"""
state = generate_token()
- nonce = generate_token()
+
+ # Generate a nonce 32 characters long. When encoded with base64url later on,
+ # the nonce will be 43 characters when sent to the identity provider.
+ #
+ # While RFC7636 does not specify a minimum length for the `nonce`
+ # parameter, the TI-Messenger IDP_FD spec v1.7.3 does require it to be
+ # between 43 and 128 characters. This spec concerns using Matrix for
+ # communication in German healthcare.
+ #
+ # As increasing the length only strengthens security, we use this length
+ # to allow TI-Messenger deployments using Synapse to satisfy this
+ # external spec.
+ #
+ # See https://github.com/element-hq/synapse/pull/18109 for more context.
+ nonce = generate_token(length=32)
code_verifier = ""
if not client_redirect_url:
@@ -1054,6 +1106,13 @@ class OidcProvider:
)
)
+ # Add any configured passthrough authorization parameters to the request.
+ passthrough_authorization_parameters = self.passthrough_authorization_parameters
+ for parameter in passthrough_authorization_parameters:
+ parameter_value = parse_string(request, parameter)
+ if parameter_value:
+ additional_authorization_parameters.update({parameter: parameter_value})
+
authorization_endpoint = metadata.get("authorization_endpoint")
return prepare_grant_uri(
authorization_endpoint,
@@ -1716,17 +1775,12 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- emails: List[str] = []
- email = render_template_field(self._config.email_template)
- if email:
- emails.append(email)
-
picture = self._config.picture_template.render(user=userinfo).strip()
return UserAttributeDict(
localpart=localpart,
display_name=display_name,
- emails=emails,
+ emails=[], # 3PIDs are not supported
picture=picture,
confirm_localpart=self._config.confirm_localpart,
)
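The passthrough mechanism above copies selected query parameters from the client's redirect request into the authorization request sent to the IdP. A simplified sketch of that merge step (function and parameter names are illustrative):

from typing import Dict, Iterable, Mapping
from urllib.parse import urlencode

def build_authorization_query(
    base_params: Dict[str, str],
    request_args: Mapping[str, str],
    passthrough: Iterable[str],
) -> str:
    # Start from the standard OAuth parameters, then layer on any
    # whitelisted parameters that the client actually supplied.
    params = dict(base_params)
    for name in passthrough:
        value = request_args.get(name)
        if value:
            params[name] = value
    return urlencode(params)

# build_authorization_query({"client_id": "synapse"}, {"login_hint": "alice"}, ["login_hint"])
# -> "client_id=synapse&login_hint=alice"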
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6fd7afa280..365c9cabcb 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -507,15 +507,17 @@ class PaginationHandler:
# Initially fetch the events from the database. With any luck, we can return
# these without blocking on backfill (handled below).
- events, next_key = (
- await self.store.paginate_room_events_by_topological_ordering(
- room_id=room_id,
- from_key=from_token.room_key,
- to_key=to_room_key,
- direction=pagin_config.direction,
- limit=pagin_config.limit,
- event_filter=event_filter,
- )
+ (
+ events,
+ next_key,
+ limited,
+ ) = await self.store.paginate_room_events_by_topological_ordering(
+ room_id=room_id,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
+ limit=pagin_config.limit,
+ event_filter=event_filter,
)
if pagin_config.direction == Direction.BACKWARDS:
@@ -575,25 +577,31 @@ class PaginationHandler:
or missing_too_many_events
or not_enough_events_to_fill_response
):
- did_backfill = await self.hs.get_federation_handler().maybe_backfill(
+ # Historical note: there used to be a check here for whether the
+ # backfill was successful or not.
+ await self.hs.get_federation_handler().maybe_backfill(
room_id,
curr_topo,
limit=pagin_config.limit,
)
- # If we did backfill something, refetch the events from the database to
- # catch anything new that might have been added since we last fetched.
- if did_backfill:
- events, next_key = (
- await self.store.paginate_room_events_by_topological_ordering(
- room_id=room_id,
- from_key=from_token.room_key,
- to_key=to_room_key,
- direction=pagin_config.direction,
- limit=pagin_config.limit,
- event_filter=event_filter,
- )
- )
+ # Regardless of whether we backfilled, another worker or even a
+ # simultaneous request may have backfilled for us while we were held
+ # behind the linearizer. This should not add too much database load, as
+ # it will only be triggered when a backfill *might* have been needed.
+ (
+ events,
+ next_key,
+ limited,
+ ) = await self.store.paginate_room_events_by_topological_ordering(
+ room_id=room_id,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
+ limit=pagin_config.limit,
+ event_filter=event_filter,
+ )
else:
# Otherwise, we can backfill in the background for eventual
# consistency's sake but we don't need to block the client waiting
@@ -608,6 +616,15 @@ class PaginationHandler:
next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
+ # We might have hit some internal filtering first, for example rejected
+ # events. In that case, make sure we still return a pagination token.
+ if not events and limited:
+ return {
+ "chunk": [],
+ "start": await from_token.to_string(self.store),
+ "end": await next_token.to_string(self.store),
+ }
+
# if no events are returned from pagination, that implies
# we have reached the end of the available events.
# In that case we do not return end, to tell the client
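The new `limited` flag drives the response shape as follows: an empty chunk with an "end" token means "filtering removed everything in this window, keep paginating", whereas an empty chunk without "end" means the timeline is exhausted. A small sketch of that decision (response assembly only, with illustrative names):

from typing import Dict, List

def build_pagination_response(
    events: List[dict], start_token: str, end_token: str, limited: bool
) -> Dict[str, object]:
    if not events and not limited:
        # Genuinely at the end of the available events: omit "end" so the
        # client knows to stop paginating.
        return {"chunk": [], "start": start_token}
    # Either we have events, or internal filtering removed them while more
    # may still exist beyond this window: always return an "end" token.
    return {"chunk": events, "start": start_token, "end": end_token}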
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37ee625f71..390cafa8f6 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will
automatically be replaced with any information from currently available devices.
"""
+
import abc
import contextlib
import itertools
@@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by (user ID, device ID).
# Empty if _presence_enabled is false.
- self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
- {}
- )
+ self._user_device_to_num_current_syncs: Dict[
+ Tuple[str, Optional[str]], int
+ ] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
@@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
- {}
- )
+ self._user_device_to_num_current_syncs: Dict[
+ Tuple[str, Optional[str]], int
+ ] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
#
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6663d4b271..cdc388b4ab 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -22,6 +22,7 @@ import logging
import random
from typing import TYPE_CHECKING, List, Optional, Union
+from synapse.api.constants import ProfileFields
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,7 +32,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -42,6 +43,8 @@ logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
+# Field name length is specced at 255 bytes.
+MAX_CUSTOM_FIELD_LEN = 255
class ProfileHandler:
@@ -74,17 +77,42 @@ class ProfileHandler:
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
+ """
+ Get a user's profile as a JSON dictionary.
+
+ Args:
+ user_id: The user to fetch the profile of.
+ ignore_backoff: True to ignore backoff when fetching over federation.
+
+ Returns:
+ A JSON dictionary. For local queries this will include the displayname and avatar_url
+ fields, if set. For remote queries it may contain arbitrary information.
+ """
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user)
- if profileinfo.display_name is None and profileinfo.avatar_url is None:
+ extra_fields = {}
+ if self.hs.config.experimental.msc4133_enabled:
+ extra_fields = await self.store.get_profile_fields(target_user)
+
+ if (
+ profileinfo.display_name is None
+ and profileinfo.avatar_url is None
+ and not extra_fields
+ ):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
- return {
- "displayname": profileinfo.display_name,
- "avatar_url": profileinfo.avatar_url,
- }
+ # Do not include display name or avatar if unset.
+ ret = {}
+ if profileinfo.display_name is not None:
+ ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name
+ if profileinfo.avatar_url is not None:
+ ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url
+ if extra_fields:
+ ret.update(extra_fields)
+
+ return ret
else:
try:
result = await self.federation.make_query(
@@ -107,6 +135,15 @@ class ProfileHandler:
raise e.to_synapse_error()
async def get_displayname(self, target_user: UserID) -> Optional[str]:
+ """
+ Fetch a user's display name from their profile.
+
+ Args:
+ target_user: The user to fetch the display name of.
+
+ Returns:
+ The user's display name or None if unset.
+ """
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(target_user)
@@ -203,6 +240,15 @@ class ProfileHandler:
await self._update_join_states(requester, target_user)
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
+ """
+ Fetch a user's avatar URL from their profile.
+
+ Args:
+ target_user: The user to fetch the avatar URL of.
+
+ Returns:
+ The user's avatar URL or None if unset.
+ """
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(target_user)
@@ -322,9 +368,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
- media_info: Optional[Union[LocalMedia, RemoteMedia]] = (
- await self.store.get_local_media(media_id)
- )
+ media_info: Optional[
+ Union[LocalMedia, RemoteMedia]
+ ] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -370,6 +416,110 @@ class ProfileHandler:
return True
+ async def get_profile_field(
+ self, target_user: UserID, field_name: str
+ ) -> JsonValue:
+ """
+ Fetch a single field of a user's profile, from the database for local
+ users and over federation for remote users.
+
+ Args:
+ target_user: The user ID to fetch the profile for.
+ field_name: The name of the field to fetch.
+
+ Returns:
+ The value for the profile field or None if the field does not exist.
+ """
+ if self.hs.is_mine(target_user):
+ try:
+ field_value = await self.store.get_profile_field(
+ target_user, field_name
+ )
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
+ raise
+
+ return field_value
+ else:
+ try:
+ result = await self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={"user_id": target_user.to_string(), "field": field_name},
+ ignore_backoff=True,
+ )
+ except RequestSendFailed as e:
+ raise SynapseError(502, "Failed to fetch profile") from e
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+
+ return result.get(field_name)
+
+ async def set_profile_field(
+ self,
+ target_user: UserID,
+ requester: Requester,
+ field_name: str,
+ new_value: JsonValue,
+ by_admin: bool = False,
+ deactivation: bool = False,
+ ) -> None:
+ """Set a new profile field for a user.
+
+ Args:
+ target_user: the user whose profile is to be changed.
+ requester: The user attempting to make this change.
+ field_name: The name of the profile field to update.
+ new_value: The new field value for this user.
+ by_admin: Whether this change was made by an administrator.
+ deactivation: Whether this change was made while deactivating the user.
+ """
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ if not by_admin and target_user != requester.user:
+ raise AuthError(403, "Cannot set another user's profile")
+
+ await self.store.set_profile_field(target_user, field_name, new_value)
+
+ # Custom fields do not propagate into the user directory *or* rooms.
+ profile = await self.store.get_profileinfo(target_user)
+ await self._third_party_rules.on_profile_update(
+ target_user.to_string(), profile, by_admin, deactivation
+ )
+
+ async def delete_profile_field(
+ self,
+ target_user: UserID,
+ requester: Requester,
+ field_name: str,
+ by_admin: bool = False,
+ deactivation: bool = False,
+ ) -> None:
+ """Delete a field from a user's profile.
+
+ Args:
+ target_user: the user whose profile is to be changed.
+ requester: The user attempting to make this change.
+ field_name: The name of the profile field to remove.
+ by_admin: Whether this change was made by an administrator.
+ deactivation: Whether this change was made while deactivating the user.
+ """
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ if not by_admin and target_user != requester.user:
+ raise AuthError(403, "Cannot delete another user's profile")
+
+ await self.store.delete_profile_field(target_user, field_name)
+
+ # Custom fields do not propagate into the user directory *or* rooms.
+ profile = await self.store.get_profileinfo(target_user)
+ await self._third_party_rules.on_profile_update(
+ target_user.to_string(), profile, by_admin, deactivation
+ )
+
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
@@ -386,13 +536,24 @@ class ProfileHandler:
just_field = args.get("field", None)
- response = {}
+ response: JsonDict = {}
try:
- if just_field is None or just_field == "displayname":
+ if just_field is None or just_field == ProfileFields.DISPLAYNAME:
response["displayname"] = await self.store.get_profile_displayname(user)
- if just_field is None or just_field == "avatar_url":
+ if just_field is None or just_field == ProfileFields.AVATAR_URL:
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
+
+ if self.hs.config.experimental.msc4133_enabled:
+ if just_field is None:
+ response.update(await self.store.get_profile_fields(user))
+ elif just_field not in (
+ ProfileFields.DISPLAYNAME,
+ ProfileFields.AVATAR_URL,
+ ):
+ response[just_field] = await self.store.get_profile_field(
+ user, just_field
+ )
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
@@ -403,6 +564,12 @@ class ProfileHandler:
async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
+ """
+ Update the membership events of each room the user is joined to with the
+ new profile information.
+
+ Note that this stomps over any custom display name or avatar URL in member events.
+ """
if not self.hs.is_mine(target_user):
return
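The local branch of get_profile now builds its response incrementally: displayname and avatar_url appear only when set, MSC4133 custom fields are merged in when the experiment is enabled, and a completely empty profile is a 404. A standalone sketch of that assembly (error type and names are illustrative):

from typing import Dict, Optional

def build_profile_response(
    display_name: Optional[str],
    avatar_url: Optional[str],
    extra_fields: Dict[str, object],
) -> Dict[str, object]:
    # An entirely unset profile is reported as not found (a 404 upstream).
    if display_name is None and avatar_url is None and not extra_fields:
        raise LookupError("Profile was not found")
    ret: Dict[str, object] = {}
    if display_name is not None:
        ret["displayname"] = display_name
    if avatar_url is not None:
        ret["avatar_url"] = avatar_url
    # Custom MSC4133 fields are merged in last.
    ret.update(extra_fields)
    return ret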
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c200e29569..8dd687c455 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -23,10 +23,9 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, TypedDict
from prometheus_client import Counter
-from typing_extensions import TypedDict
from synapse import types
from synapse.api.constants import (
@@ -44,7 +43,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.appservice import ApplicationService
-from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -109,13 +107,13 @@ class RegistrationHandler:
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
- self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._account_validity_handler = hs.get_account_validity_handler()
self._user_consent_version = self.hs.config.consent.user_consent_version
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._server_name = hs.hostname
+ self._user_types_config = hs.config.user_types
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
@@ -160,7 +158,10 @@ class RegistrationHandler:
if not localpart:
raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME)
- if localpart[0] == "_":
+ if (
+ localpart[0] == "_"
+ and not self.hs.config.registration.allow_underscore_prefixed_localpart
+ ):
raise SynapseError(
400, "User ID may not begin with _", Codes.INVALID_USERNAME
)
@@ -304,6 +305,9 @@ class RegistrationHandler:
elif default_display_name is None:
default_display_name = localpart
+ if user_type is None:
+ user_type = self._user_types_config.default_user_type
+
await self.register_with_store(
user_id=user_id,
password_hash=password_hash,
@@ -380,19 +384,6 @@ class RegistrationHandler:
user_id,
)
- # Bind any specified emails to this account
- current_time = self.hs.get_clock().time_msec()
- for email in bind_emails:
- # generate threepid dict
- threepid_dict = {
- "medium": "email",
- "address": email,
- "validated_at": current_time,
- }
-
- # Bind email to new account
- await self._register_email_threepid(user_id, threepid_dict, None)
-
return user_id
async def _create_and_join_rooms(self, user_id: str) -> None:
@@ -630,7 +621,9 @@ class RegistrationHandler:
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart: str, as_token: str) -> str:
+ async def appservice_register(
+ self, user_localpart: str, as_token: str
+ ) -> Tuple[str, ApplicationService]:
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -653,7 +646,7 @@ class RegistrationHandler:
appservice_id=service_id,
create_profile_with_displayname=user.localpart,
)
- return user_id
+ return (user_id, service)
def check_user_id_not_appservice_exclusive(
self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
@@ -941,21 +934,6 @@ class RegistrationHandler:
)
return
- if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
- threepid = auth_result[LoginType.EMAIL_IDENTITY]
- # Necessary due to auth checks prior to the threepid being
- # written to the db
- if is_threepid_reserved(
- self.hs.config.server.mau_limits_reserved_threepids, threepid
- ):
- await self.store.upsert_monthly_active_user(user_id)
-
- await self._register_email_threepid(user_id, threepid, access_token)
-
- if auth_result and LoginType.MSISDN in auth_result:
- threepid = auth_result[LoginType.MSISDN]
- await self._register_msisdn_threepid(user_id, threepid)
-
if auth_result and LoginType.TERMS in auth_result:
# The terms type should only exist if consent is enabled.
assert self._user_consent_version is not None
@@ -971,86 +949,3 @@ class RegistrationHandler:
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
-
- async def _register_email_threepid(
- self, user_id: str, threepid: dict, token: Optional[str]
- ) -> None:
- """Add an email address as a 3pid identifier
-
- Also adds an email pusher for the email address, if configured in the
- HS config
-
- Must be called on master.
-
- Args:
- user_id: id of user
- threepid: m.login.email.identity auth response
- token: access_token for the user, or None if not logged in.
- """
- reqd = ("medium", "address", "validated_at")
- if any(x not in threepid for x in reqd):
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- return
-
- await self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
-
- # And we add an email pusher for them by default, but only
- # if email notifications are enabled (so people don't start
- # getting mail spam where they weren't before if email
- # notifs are set up on a homeserver)
- if (
- self.hs.config.email.email_enable_notifs
- and self.hs.config.email.email_notif_for_new_users
- and token
- ):
- # Pull the ID of the access token back out of the db
- # It would really make more sense for this to be passed
- # up when the access token is saved, but that's quite an
- # invasive change I'd rather do separately.
- user_tuple = await self.store.get_user_by_access_token(token)
- # The token better still exist.
- assert user_tuple
- device_id = user_tuple.device_id
-
- await self.pusher_pool.add_or_update_pusher(
- user_id=user_id,
- device_id=device_id,
- kind="email",
- app_id="m.email",
- app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
- lang=None,
- data={},
- )
-
- async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
- """Add a phone number as a 3pid identifier
-
- Must be called on master.
-
- Args:
- user_id: id of user
- threepid: m.login.msisdn auth response
- """
- try:
- assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
- except SynapseError as ex:
- if ex.errcode == Codes.MISSING_PARAM:
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- return None
- raise
-
- await self._auth_handler.add_threepid(
- user_id,
- threepid["medium"],
- threepid["address"],
- threepid["validated_at"],
- )
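Two of the registration changes above are easy to state as small pure functions: localpart validation now tolerates a leading underscore when the new config flag allows it, and an unspecified user type falls back to the configured default. A sketch under those assumptions (the config shapes here are illustrative):

from typing import Optional

def validate_localpart(localpart: str, allow_underscore_prefixed: bool) -> None:
    if not localpart:
        raise ValueError("User ID cannot be empty")
    # Underscore-prefixed localparts are reserved unless explicitly allowed.
    if localpart[0] == "_" and not allow_underscore_prefixed:
        raise ValueError("User ID may not begin with _")

def resolve_user_type(
    requested: Optional[str], default_user_type: Optional[str]
) -> Optional[str]:
    # Fall back to the configured default only when nothing was requested.
    return requested if requested is not None else default_user_type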
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index efe31e81f9..b1158ee77d 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -188,13 +188,13 @@ class RelationsHandler:
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
- return_value["original_event"] = (
- await self._event_serializer.serialize_event(
- event,
- now,
- bundle_aggregations=None,
- config=serialize_options,
- )
+ return_value[
+ "original_event"
+ ] = await self._event_serializer.serialize_event(
+ event,
+ now,
+ bundle_aggregations=None,
+ config=serialize_options,
)
if next_token:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2c6e672ede..1ccb6f7171 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,6 +20,7 @@
#
"""Contains functions for performing actions on rooms."""
+
import itertools
import logging
import math
@@ -467,17 +468,6 @@ class RoomCreationHandler:
"""
user_id = requester.user.to_string()
- spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
- user_id
- )
- if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
- raise SynapseError(
- 403,
- "You are not permitted to create rooms",
- errcode=spam_check[0],
- additional_fields=spam_check[1],
- )
-
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
@@ -584,6 +574,24 @@ class RoomCreationHandler:
if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
+ # We construct what the body of a call to /createRoom would look like for passing
+ # to the spam checker. We don't include a preset here, as we expect the
+ # initial state to contain everything we need.
+ spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
+ user_id,
+ {
+ "creation_content": creation_content,
+ "initial_state": list(initial_state.items()),
+ },
+ )
+ if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+ raise SynapseError(
+ 403,
+ "You are not permitted to create rooms",
+ errcode=spam_check[0],
+ additional_fields=spam_check[1],
+ )
+
await self._send_events_for_new_room(
requester,
new_room_id,
@@ -785,7 +793,7 @@ class RoomCreationHandler:
if not is_requester_admin:
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
- user_id
+ user_id, config
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
@@ -900,11 +908,9 @@ class RoomCreationHandler:
)
# Check whether this visibility value is blocked by a third party module
- allowed_by_third_party_rules = (
- await (
- self._third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
- )
+ allowed_by_third_party_rules = await (
+ self._third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
)
)
if not allowed_by_third_party_rules:
@@ -1754,7 +1760,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
)
events = list(room_events)
- events.extend(e for evs, _ in room_to_events.values() for e in evs)
+ events.extend(e for evs, _, _ in room_to_events.values() for e in evs)
# We know stream_ordering must be not None here, as its been
# persisted, but mypy doesn't know that
@@ -1807,7 +1813,7 @@ class RoomShutdownHandler:
] = None,
) -> Optional[ShutdownRoomResponse]:
"""
- Shuts down a room. Moves all local users and room aliases automatically
+ Shuts down a room. Moves all joined local users and room aliases automatically
to a new room if `new_room_user_id` is set. Otherwise local users only
leave the room without any information.
@@ -1950,16 +1956,17 @@ class RoomShutdownHandler:
# Join users to new room
if new_room_user_id:
- assert new_room_id is not None
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
+ if membership == Membership.JOIN:
+ assert new_room_id is not None
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
result["kicked_users"].append(user_id)
if update_result_fct:
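With the spam-check move above, user_may_create_room callbacks now receive a /createRoom-shaped body alongside the user ID, so modules can inspect the proposed room before it exists. A hypothetical callback shape (the policy and the NOT_SPAM stand-in for the real sentinel are invented for illustration):

from typing import Any, Mapping, Tuple, Union

NOT_SPAM = "NOT_SPAM"

async def user_may_create_room(
    user_id: str, room_config: Mapping[str, Any]
) -> Union[str, Tuple[str, dict]]:
    # Example policy: only admins may create spaces.
    creation_content = room_config.get("creation_content", {})
    if creation_content.get("type") == "m.space" and not user_id.startswith("@admin"):
        return ("M_FORBIDDEN", {"reason": "only admins may create spaces"})
    return NOT_SPAM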
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 51b9772329..a3a7326d94 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -53,6 +53,7 @@ from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.push import ReplicationCopyPusherRestServlet
from synapse.storage.databases.main.state_deltas import StateDelta
+from synapse.storage.invite_rule import InviteRule
from synapse.types import (
JsonDict,
Requester,
@@ -98,7 +99,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.federation_handler = hs.get_federation_handler()
self.directory_handler = hs.get_directory_handler()
- self.identity_handler = hs.get_identity_handler()
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -122,7 +122,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
hs.get_module_api_callbacks().third_party_event_rules
)
self._server_notices_mxid = self.config.servernotices.server_notices_mxid
- self._enable_lookup = hs.config.registration.enable_3pid_lookup
self.allow_per_room_profiles = self.config.server.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter(
@@ -158,6 +157,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_room,
+ ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
# Ratelimiter for invites, keyed by recipient (across all rooms, all
@@ -166,6 +166,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_user,
+ ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
# Ratelimiter for invites, keyed by issuer (across all rooms, all
@@ -174,6 +175,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_issuer,
+ ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
self._third_party_invite_limiter = Ratelimiter(
@@ -912,6 +914,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
additional_fields=block_invite_result[1],
)
+ # Check the invitee's invite filtering configuration and apply its rules. Server admins can bypass them.
+ if not is_requester_admin:
+ invite_config = await self.store.get_invite_config_for_user(target_id)
+ rule = invite_config.get_invite_rule(requester.user.to_string())
+ if rule == InviteRule.BLOCK:
+ logger.info(
+ f"Automatically rejecting invite from {requester.user} due to the invite filtering rules of {target_id}"
+ )
+ raise SynapseError(
+ 403,
+ "You are not permitted to invite this user.",
+ errcode=Codes.INVITE_BLOCKED,
+ )
+ # InviteRule.IGNORE is handled at the sync layer.
+
# An empty prev_events list is allowed as long as the auth_event_ids are present
if prev_event_ids is not None:
return await self._local_membership_update(
@@ -1190,6 +1207,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
origin_server_ts=origin_server_ts,
)
+ async def check_for_any_membership_in_room(
+ self, *, user_id: str, room_id: str
+ ) -> None:
+ """
+ Check if the user has any membership in the room and raise an error if not.
+
+ Args:
+ user_id: The user to check.
+ room_id: The room to check.
+
+ Raises:
+ AuthError if the user doesn't have any membership in the room.
+ """
+ result = await self.store.get_local_current_membership_for_user_in_room(
+ user_id=user_id, room_id=room_id
+ )
+
+ if result is None or result == (None, None):
+ raise AuthError(403, f"User {user_id} has no membership in room {room_id}")
+
async def _should_perform_remote_join(
self,
user_id: str,
@@ -1302,11 +1339,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If this is going to be a local join, additional information must
# be included in the event content in order to efficiently validate
# the event.
- content[EventContentFields.AUTHORISING_USER] = (
- await self.event_auth_handler.get_user_which_could_invite(
- room_id,
- state_before_join,
- )
+ content[
+ EventContentFields.AUTHORISING_USER
+ ] = await self.event_auth_handler.get_user_which_could_invite(
+ room_id,
+ state_before_join,
)
return False, []
@@ -1415,9 +1452,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if requester is not None:
sender = UserID.from_string(event.sender)
- assert (
- sender == requester.user
- ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
+ assert sender == requester.user, (
+ "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
+ )
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = types.create_requester(target_user)
@@ -1572,230 +1609,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return UserID.from_string(invite.sender)
return None
- async def do_3pid_invite(
- self,
- room_id: str,
- inviter: UserID,
- medium: str,
- address: str,
- id_server: str,
- requester: Requester,
- txn_id: Optional[str],
- id_access_token: str,
- prev_event_ids: Optional[List[str]] = None,
- depth: Optional[int] = None,
- ) -> Tuple[str, int]:
- """Invite a 3PID to a room.
-
- Args:
- room_id: The room to invite the 3PID to.
- inviter: The user sending the invite.
- medium: The 3PID's medium.
- address: The 3PID's address.
- id_server: The identity server to use.
- requester: The user making the request.
- txn_id: The transaction ID this is part of, or None if this is not
- part of a transaction.
- id_access_token: Identity server access token.
- depth: Override the depth used to order the event in the DAG.
- prev_event_ids: The event IDs to use as the prev events
- Should normally be set to None, which will cause the depth to be calculated
- based on the prev_events.
-
- Returns:
- Tuple of event ID and stream ordering position
-
- Raises:
- ShadowBanError if the requester has been shadow-banned.
- """
- if self.config.server.block_non_admin_invites:
- is_requester_admin = await self.auth.is_server_admin(requester)
- if not is_requester_admin:
- raise SynapseError(
- 403, "Invites have been disabled on this server", Codes.FORBIDDEN
- )
-
- if requester.shadow_banned:
- # We randomly sleep a bit just to annoy the requester.
- await self.clock.sleep(random.randint(1, 10))
- raise ShadowBanError()
-
- # We need to rate limit *before* we send out any 3PID invites, so we
- # can't just rely on the standard ratelimiting of events.
- await self._third_party_invite_limiter.ratelimit(requester)
-
- can_invite = await self._third_party_event_rules.check_threepid_can_be_invited(
- medium, address, room_id
- )
- if not can_invite:
- raise SynapseError(
- 403,
- "This third-party identifier can not be invited in this room",
- Codes.FORBIDDEN,
- )
-
- if not self._enable_lookup:
- raise SynapseError(
- 403, "Looking up third-party identifiers is denied from this server"
- )
-
- invitee = await self.identity_handler.lookup_3pid(
- id_server, medium, address, id_access_token
- )
-
- if invitee:
- # Note that update_membership with an action of "invite" can raise
- # a ShadowBanError, but this was done above already.
- # We don't check the invite against the spamchecker(s) here (through
- # user_may_invite) because we'll do it further down the line anyway (in
- # update_membership_locked).
- event_id, stream_id = await self.update_membership(
- requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
- )
- else:
- # Check if the spamchecker(s) allow this invite to go through.
- spam_check = (
- await self._spam_checker_module_callbacks.user_may_send_3pid_invite(
- inviter_userid=requester.user.to_string(),
- medium=medium,
- address=address,
- room_id=room_id,
- )
- )
- if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
- raise SynapseError(
- 403,
- "Cannot send threepid invite",
- errcode=spam_check[0],
- additional_fields=spam_check[1],
- )
-
- event, stream_id = await self._make_and_store_3pid_invite(
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter,
- txn_id=txn_id,
- id_access_token=id_access_token,
- prev_event_ids=prev_event_ids,
- depth=depth,
- )
- event_id = event.event_id
-
- return event_id, stream_id
-
- async def _make_and_store_3pid_invite(
- self,
- requester: Requester,
- id_server: str,
- medium: str,
- address: str,
- room_id: str,
- user: UserID,
- txn_id: Optional[str],
- id_access_token: str,
- prev_event_ids: Optional[List[str]] = None,
- depth: Optional[int] = None,
- ) -> Tuple[EventBase, int]:
- room_state = await self._storage_controllers.state.get_current_state(
- room_id,
- StateFilter.from_types(
- [
- (EventTypes.Member, user.to_string()),
- (EventTypes.CanonicalAlias, ""),
- (EventTypes.Name, ""),
- (EventTypes.Create, ""),
- (EventTypes.JoinRules, ""),
- (EventTypes.RoomAvatar, ""),
- ]
- ),
- )
-
- inviter_display_name = ""
- inviter_avatar_url = ""
- member_event = room_state.get((EventTypes.Member, user.to_string()))
- if member_event:
- inviter_display_name = member_event.content.get("displayname", "")
- inviter_avatar_url = member_event.content.get("avatar_url", "")
-
- # if user has no display name, default to their MXID
- if not inviter_display_name:
- inviter_display_name = user.to_string()
-
- canonical_room_alias = ""
- canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
- if canonical_alias_event:
- canonical_room_alias = canonical_alias_event.content.get("alias", "")
-
- room_name = ""
- room_name_event = room_state.get((EventTypes.Name, ""))
- if room_name_event:
- room_name = room_name_event.content.get("name", "")
-
- room_type = None
- room_create_event = room_state.get((EventTypes.Create, ""))
- if room_create_event:
- room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE)
-
- room_join_rules = ""
- join_rules_event = room_state.get((EventTypes.JoinRules, ""))
- if join_rules_event:
- room_join_rules = join_rules_event.content.get("join_rule", "")
-
- room_avatar_url = ""
- room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
- if room_avatar_event:
- room_avatar_url = room_avatar_event.content.get("url", "")
-
- (
- token,
- public_keys,
- fallback_public_key,
- display_name,
- ) = await self.identity_handler.ask_id_server_for_third_party_invite(
- requester=requester,
- id_server=id_server,
- medium=medium,
- address=address,
- room_id=room_id,
- inviter_user_id=user.to_string(),
- room_alias=canonical_room_alias,
- room_avatar_url=room_avatar_url,
- room_join_rules=room_join_rules,
- room_name=room_name,
- room_type=room_type,
- inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url,
- id_access_token=id_access_token,
- )
-
- (
- event,
- stream_id,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.ThirdPartyInvite,
- "content": {
- "display_name": display_name,
- "public_keys": public_keys,
- # For backwards compatibility:
- "key_validity_url": fallback_public_key["key_validity_url"],
- "public_key": fallback_public_key["public_key"],
- },
- "room_id": room_id,
- "sender": user.to_string(),
- "state_key": token,
- },
- ratelimit=False,
- txn_id=txn_id,
- prev_event_ids=prev_event_ids,
- depth=depth,
- )
- return event, stream_id
-
async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
"""Returns whether the homeserver is in the room based on its current state.
diff --git a/synapse/handlers/room_policy.py b/synapse/handlers/room_policy.py
new file mode 100644
index 0000000000..3a83c4d6ec
--- /dev/null
+++ b/synapse/handlers/room_policy.py
@@ -0,0 +1,96 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.events import EventBase
+from synapse.types.handlers.policy_server import RECOMMENDATION_OK
+from synapse.util.stringutils import parse_and_validate_server_name
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RoomPolicyHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._hs = hs
+ self._store = hs.get_datastores().main
+ self._storage_controllers = hs.get_storage_controllers()
+ self._event_auth_handler = hs.get_event_auth_handler()
+ self._federation_client = hs.get_federation_client()
+
+ async def is_event_allowed(self, event: EventBase) -> bool:
+ """Check if the given event is allowed in the room by the policy server.
+
+ Note: This will *always* return True if the room's policy server is Synapse
+ itself. This is because Synapse can't be a policy server (currently).
+
+ If no policy server is configured in the room, this returns True. Similarly, if
+ the policy server is invalid in any way (not joined, not a server, etc), this
+ returns True.
+
+ If a valid and contactable policy server is configured in the room, this returns
+ True if that server suggests the event is not spammy, and False otherwise.
+
+ Args:
+ event: The event to check. This should be a fully-formed PDU.
+
+ Returns:
+ bool: True if the event is allowed in the room, False otherwise.
+ """
+ policy_event = await self._storage_controllers.state.get_current_state_event(
+ event.room_id, "org.matrix.msc4284.policy", ""
+ )
+ if not policy_event:
+ return True # no policy server == default allow
+
+ policy_server = policy_event.content.get("via", "")
+ if policy_server is None or not isinstance(policy_server, str):
+ return True # no policy server == default allow
+
+ if policy_server == self._hs.hostname:
+ return True # Synapse itself can't be a policy server (currently)
+
+ try:
+ parse_and_validate_server_name(policy_server)
+ except ValueError:
+ return True # invalid policy server == default allow
+
+ is_in_room = await self._event_auth_handler.is_host_in_room(
+ event.room_id, policy_server
+ )
+ if not is_in_room:
+ return True # policy server not in room == default allow
+
+ # At this point, the server appears valid and is in the room, so ask it to check
+ # the event.
+ recommendation = await self._federation_client.get_pdu_policy_recommendation(
+ policy_server, event
+ )
+ if recommendation != RECOMMENDATION_OK:
+ logger.info(
+ "[POLICY] Policy server %s recommended not to allow event %s in room %s: %s",
+ policy_server,
+ event.event_id,
+ event.room_id,
+ recommendation,
+ )
+ return False
+
+ return True # default allow
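From the caller's side, the handler slots in just before persisting an event, as in the message.py hunk earlier: a negative recommendation becomes a 403. A usage sketch (the call-site shape and error type are illustrative):

async def maybe_reject_event(policy_handler, event) -> None:
    # Ask the room's policy server (if any) whether the event is acceptable.
    allowed = await policy_handler.is_event_allowed(event)
    if not allowed:
        # Mirrors the rejection raised during event creation.
        raise PermissionError("This message has been rejected as probable spam")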
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 720459f1e7..1c39cfed1b 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -183,8 +183,13 @@ class RoomSummaryHandler:
) -> JsonDict:
"""See docstring for SpaceSummaryHandler.get_room_hierarchy."""
- # First of all, check that the room is accessible.
- if not await self._is_local_room_accessible(requested_room_id, requester):
+ # If the room is available locally, quickly check that the user can access it.
+ local_room = await self._store.is_host_joined(
+ requested_room_id, self._server_name
+ )
+ if local_room and not await self._is_local_room_accessible(
+ requested_room_id, requester
+ ):
raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
@@ -192,6 +197,22 @@ class RoomSummaryHandler:
errcode=Codes.NOT_JOINED,
)
+ if not local_room:
+ room_hierarchy = await self._summarize_remote_room_hierarchy(
+ _RoomQueueEntry(requested_room_id, ()),
+ False,
+ )
+ root_room_entry = room_hierarchy[0]
+ if not root_room_entry or not await self._is_remote_room_accessible(
+ requester, requested_room_id, root_room_entry.room
+ ):
+ raise UnstableSpecAuthError(
+ 403,
+ "User %s not in room %s, and room previews are disabled"
+ % (requester, requested_room_id),
+ errcode=Codes.NOT_JOINED,
+ )
+
# If this is continuing a previous session, pull the persisted data.
if from_token:
try:
@@ -679,23 +700,55 @@ class RoomSummaryHandler:
"""
# The API doesn't return the room version so assume that a
# join rule of knock is valid.
+ join_rule = room.get("join_rule")
+ world_readable = room.get("world_readable")
+
+ logger.warning(
+ "[EMMA] Checking if room %s is accessible to %s: join_rule=%s, world_readable=%s",
+ room_id, requester, join_rule, world_readable
+ )
+
if (
- room.get("join_rule")
- in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED)
- or room.get("world_readable") is True
+ join_rule in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED)
+ or world_readable is True
):
return True
- elif not requester:
+ else:
+ logger.warning(
+ "[EMMA] Room %s is not accessible to %s: join_rule=%s, world_readable=%s, join_rule result=%s, world_readable result=%s",
+ room_id, requester, join_rule, world_readable,
+ join_rule in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED),
+ world_readable is True
+ )
+
+ if not requester:
+ logger.warning(
+ "[EMMA] No requester, so room %s is not accessible",
+ room_id
+ )
return False
+
# Check if the user is a member of any of the allowed rooms from the response.
allowed_rooms = room.get("allowed_room_ids")
+ logger.warning(
+ "[EMMA] Checking if %s is a member of the allowed rooms of %s: join_rule=%s, allowed_rooms=%s",
+ requester,
+ room_id,
+ join_rule,
+ allowed_rooms
+ )
if allowed_rooms and isinstance(allowed_rooms, list):
if await self._event_auth_handler.is_user_in_rooms(
allowed_rooms, requester
):
return True
+ logger.warning(
+ "[EMMA] Checking if room %s is accessible to %s via local state",
+ room_id,
+ requester
+ )
# Finally, check locally if we can access the room. The user might
# already be in the room (if it was a child room), or there might be a
# pending invite, etc.
@@ -863,6 +916,10 @@ class RoomSummaryHandler:
if not room_entry or not await self._is_remote_room_accessible(
requester, room_entry.room_id, room_entry.room
):
+ logger.warning(
+ "[EMMA] Room entry contents: %s",
+ room_entry.room if room_entry else None
+ )
raise NotFoundError("Room not found or is not accessible")
room = dict(room_entry.room)
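The hierarchy changes above amount to a two-path accessibility check: rooms we participate in locally use the local rules, and unknown rooms are resolved by fetching the remote hierarchy and vetting its root entry. A condensed sketch of that control flow (all method names here are illustrative, not Synapse APIs):

async def check_hierarchy_access(summary, requester: str, room_id: str) -> None:
    if await summary.is_room_known_locally(room_id):
        if not await summary.local_room_accessible(room_id, requester):
            raise PermissionError("room previews are disabled")
        return
    # Unknown locally: fetch the hierarchy over federation and check the root.
    entries = await summary.fetch_remote_hierarchy(room_id)
    root = entries[0] if entries else None
    if root is None or not await summary.remote_room_accessible(
        requester, room_id, root
    ):
        raise PermissionError("room previews are disabled")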
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
deleted file mode 100644
index 8ebd3d4ff9..0000000000
--- a/synapse/handlers/saml.py
+++ /dev/null
@@ -1,524 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import logging
-import re
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
-
-import attr
-import saml2
-import saml2.response
-from saml2.client import Saml2Client
-
-from synapse.api.errors import SynapseError
-from synapse.config import ConfigError
-from synapse.handlers.sso import MappingException, UserAttributes
-from synapse.http.servlet import parse_string
-from synapse.http.site import SynapseRequest
-from synapse.module_api import ModuleApi
-from synapse.types import (
- MXID_LOCALPART_ALLOWED_CHARACTERS,
- UserID,
- map_username_to_mxid_localpart,
-)
-from synapse.util.iterutils import chunk_seq
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-@attr.s(slots=True, auto_attribs=True)
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
-
- # time the session was created, in milliseconds
- creation_time: int
- # The user interactive authentication session ID associated with this SAML
- # session (or None if this SAML session is for an initial login).
- ui_auth_session_id: Optional[str] = None
-
-
-class SamlHandler:
- def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self.server_name = hs.hostname
- self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
- self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid
-
- self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime
- self._grandfathered_mxid_source_attribute = (
- hs.config.saml2.saml2_grandfathered_mxid_source_attribute
- )
- self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
-
- # plugin to do custom mapping from saml response to mxid
- self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class(
- hs.config.saml2.saml2_user_mapping_provider_config,
- ModuleApi(hs, hs.get_auth_handler()),
- )
-
- # identifier for the external_ids table
- self.idp_id = "saml"
-
- # user-facing name of this auth provider
- self.idp_name = hs.config.saml2.idp_name
-
- # MXC URI for icon for this auth provider
- self.idp_icon = hs.config.saml2.idp_icon
-
- # optional brand identifier for this auth provider
- self.idp_brand = hs.config.saml2.idp_brand
-
- # a map from saml session id to Saml2SessionData object
- self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
-
- self._sso_handler = hs.get_sso_handler()
- self._sso_handler.register_identity_provider(self)
-
- async def handle_redirect_request(
- self,
- request: SynapseRequest,
- client_redirect_url: Optional[bytes],
- ui_auth_session_id: Optional[str] = None,
- ) -> str:
- """Handle an incoming request to /login/sso/redirect
-
- Args:
- request: the incoming HTTP request
- client_redirect_url: the URL that we should redirect the
- client to after login (or None for UI Auth).
- ui_auth_session_id: The session ID of the ongoing UI Auth (or
- None if this is a login).
-
- Returns:
- URL to redirect to
- """
- if not client_redirect_url:
- # Some SAML identity providers (e.g. Google) require a
- # RelayState parameter on requests, so pass in a dummy redirect URL
- # (which will never get used).
- client_redirect_url = b"unused"
-
- reqid, info = self._saml_client.prepare_for_authenticate(
- entityid=self._saml_idp_entityid, relay_state=client_redirect_url
- )
-
- # Since SAML sessions timeout it is useful to log when they were created.
- logger.info("Initiating a new SAML session: %s" % (reqid,))
-
- now = self.clock.time_msec()
- self._outstanding_requests_dict[reqid] = Saml2SessionData(
- creation_time=now,
- ui_auth_session_id=ui_auth_session_id,
- )
-
- for key, value in info["headers"]:
- if key == "Location":
- return value
-
- # this shouldn't happen!
- raise Exception("prepare_for_authenticate didn't return a Location header")
-
- async def handle_saml_response(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_synapse/client/saml2/authn_response
-
- Args:
- request: the incoming request from the browser. We'll
- respond to it with a redirect.
-
- Returns:
- Completes once we have handled the request.
- """
- resp_bytes = parse_string(request, "SAMLResponse", required=True)
- relay_state = parse_string(request, "RelayState", required=True)
-
- # expire outstanding sessions before parse_authn_request_response checks
- # the dict.
- self.expire_sessions()
-
- try:
- saml2_auth = self._saml_client.parse_authn_request_response(
- resp_bytes,
- saml2.BINDING_HTTP_POST,
- outstanding=self._outstanding_requests_dict,
- )
- except saml2.response.UnsolicitedResponse as e:
- # the pysaml2 library helpfully logs an ERROR here, but neglects to log
- # the session ID. I don't really want to put the full text of the exception
- # in the (user-visible) exception message, so let's log the exception here
- # so we can track down the session IDs later.
- logger.warning(str(e))
- self._sso_handler.render_error(
- request, "unsolicited_response", "Unexpected SAML2 login."
- )
- return
- except Exception as e:
- self._sso_handler.render_error(
- request,
- "invalid_response",
- "Unable to parse SAML2 response: %s." % (e,),
- )
- return
-
- if saml2_auth.not_signed:
- self._sso_handler.render_error(
- request, "unsigned_respond", "SAML2 response was not signed."
- )
- return
-
- logger.debug("SAML2 response: %s", saml2_auth.origxml)
-
- await self._handle_authn_response(request, saml2_auth, relay_state)
-
- async def _handle_authn_response(
- self,
- request: SynapseRequest,
- saml2_auth: saml2.response.AuthnResponse,
- relay_state: str,
- ) -> None:
- """Handle an AuthnResponse, having parsed it from the request params
-
- Assumes that the signature on the response object has been checked. Maps
- the user onto an MXID, registering them if necessary, and returns a response
- to the browser.
-
- Args:
- request: the incoming request from the browser. We'll respond to it with an
- HTML page or a redirect
- saml2_auth: the parsed AuthnResponse object
-            relay_state: the RelayState query param, which encodes the URI to redirect
- back to
- """
-
- for assertion in saml2_auth.assertions:
- # kibana limits the length of a log field, whereas this is all rather
- # useful, so split it up.
- count = 0
- for part in chunk_seq(str(assertion), 10000):
- logger.info(
- "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
- )
- count += 1
-
- logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
-
- current_session = self._outstanding_requests_dict.pop(
- saml2_auth.in_response_to, None
- )
-
- # first check if we're doing a UIA
- if current_session and current_session.ui_auth_session_id:
- try:
- remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
- except MappingException as e:
- logger.exception("Failed to extract remote user id from SAML response")
- self._sso_handler.render_error(request, "mapping_error", str(e))
- return
-
- return await self._sso_handler.complete_sso_ui_auth_request(
- self.idp_id,
- remote_user_id,
- current_session.ui_auth_session_id,
- request,
- )
-
- # otherwise, we're handling a login request.
-
- # Ensure that the attributes of the logged in user meet the required
- # attributes.
- if not self._sso_handler.check_required_attributes(
- request, saml2_auth.ava, self._saml2_attribute_requirements
- ):
- return
-
- # Call the mapper to register/login the user
- try:
- await self._complete_saml_login(saml2_auth, request, relay_state)
- except MappingException as e:
- logger.exception("Could not map user")
- self._sso_handler.render_error(request, "mapping_error", str(e))
-
- async def _complete_saml_login(
- self,
- saml2_auth: saml2.response.AuthnResponse,
- request: SynapseRequest,
- client_redirect_url: str,
- ) -> None:
- """
- Given a SAML response, complete the login flow
-
- Retrieves the remote user ID, registers the user if necessary, and serves
- a redirect back to the client with a login-token.
-
- Args:
- saml2_auth: The parsed SAML2 response.
- request: The request to respond to
- client_redirect_url: The redirect URL passed in by the client.
-
- Raises:
- MappingException if there was a problem mapping the response to a user.
- RedirectException: some mapping providers may raise this if they need
- to redirect to an interstitial page.
- """
- remote_user_id = self._remote_id_from_saml_response(
- saml2_auth, client_redirect_url
- )
-
- async def saml_response_to_remapped_user_attributes(
- failures: int,
- ) -> UserAttributes:
- """
- Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
-
-            This is a backwards-compatibility shim for the SSO handler abstraction.
- """
- # Call the mapping provider.
- result = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, failures, client_redirect_url
- )
- # Remap some of the results.
- return UserAttributes(
- localpart=result.get("mxid_localpart"),
- display_name=result.get("displayname"),
- emails=result.get("emails", []),
- )
-
- async def grandfather_existing_users() -> Optional[str]:
- # backwards-compatibility hack: see if there is an existing user with a
- # suitable mapping from the uid
- if (
- self._grandfathered_mxid_source_attribute
- and self._grandfathered_mxid_source_attribute in saml2_auth.ava
- ):
- attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
- user_id = UserID(
- map_username_to_mxid_localpart(attrval), self.server_name
- ).to_string()
-
- logger.debug(
- "Looking for existing account based on mapped %s %s",
- self._grandfathered_mxid_source_attribute,
- user_id,
- )
-
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- if users:
- registered_user_id = list(users.keys())[0]
- logger.info("Grandfathering mapping to %s", registered_user_id)
- return registered_user_id
-
- return None
-
- await self._sso_handler.complete_sso_login_request(
- self.idp_id,
- remote_user_id,
- request,
- client_redirect_url,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
- )
-
- def _remote_id_from_saml_response(
- self,
- saml2_auth: saml2.response.AuthnResponse,
- client_redirect_url: Optional[str],
- ) -> str:
- """Extract the unique remote id from a SAML2 AuthnResponse
-
- Args:
- saml2_auth: The parsed SAML2 response.
- client_redirect_url: The redirect URL passed in by the client.
- Returns:
- remote user id
-
- Raises:
- MappingException if there was an error extracting the user id
- """
- # It's not obvious why we need to pass in the redirect URI to the mapping
- # provider, but we do :/
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
- saml2_auth, client_redirect_url
- )
-
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
- return remote_user_id
-
- def expire_sessions(self) -> None:
- expire_before = self.clock.time_msec() - self._saml2_session_lifetime
- to_expire = set()
- for reqid, data in self._outstanding_requests_dict.items():
- if data.creation_time < expire_before:
- to_expire.add(reqid)
- for reqid in to_expire:
- logger.debug("Expiring session id %s", reqid)
- del self._outstanding_requests_dict[reqid]
-
-
-DOT_REPLACE_PATTERN = re.compile(
- "[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS)),)
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
- """Replace any characters which are not allowed in Matrix IDs with a dot."""
- username = username.lower()
- username = DOT_REPLACE_PATTERN.sub(".", username)
-
- # regular mxids aren't allowed to start with an underscore either
- username = re.sub("^_", "", username)
- return username
-
-
-MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
- "hexencode": map_username_to_mxid_localpart,
- "dotreplace": dot_replace_for_mxid,
-}
-
-
-@attr.s(auto_attribs=True)
-class SamlConfig:
- mxid_source_attribute: str
- mxid_mapper: Callable[[str], str]
-
-
-class DefaultSamlMappingProvider:
- __version__ = "0.0.1"
-
- def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
- """The default SAML user mapping provider
-
- Args:
- parsed_config: Module configuration
- module_api: module api proxy
- """
- self._mxid_source_attribute = parsed_config.mxid_source_attribute
- self._mxid_mapper = parsed_config.mxid_mapper
-
- self._grandfathered_mxid_source_attribute = (
- module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute
- )
-
- def get_remote_user_id(
- self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
- ) -> str:
- """Extracts the remote user id from the SAML response"""
- try:
- return saml_response.ava["uid"][0]
- except KeyError:
- logger.warning("SAML2 response lacks a 'uid' attestation")
- raise MappingException("'uid' not in SAML2 response")
-
- def saml_response_to_user_attributes(
- self,
- saml_response: saml2.response.AuthnResponse,
- failures: int,
- client_redirect_url: str,
- ) -> dict:
- """Maps some text from a SAML response to attributes of a new user
-
- Args:
- saml_response: A SAML auth response object
-
- failures: How many times a call to this function with this
- saml_response has resulted in a failure
-
- client_redirect_url: where the client wants to redirect to
-
- Returns:
- A dict containing new user attributes. Possible keys:
- * mxid_localpart (str): Required. The localpart of the user's mxid
- * displayname (str): The displayname of the user
- * emails (list[str]): Any emails for the user
- """
- try:
- mxid_source = saml_response.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation",
- self._mxid_source_attribute,
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
-
- # Use the configured mapper for this mxid_source
- localpart = self._mxid_mapper(mxid_source)
-
- # Append suffix integer if last call to this function failed to produce
- # a usable mxid.
- localpart += str(failures) if failures else ""
-
- # Retrieve the display name from the saml response
- # If displayname is None, the mxid_localpart will be used instead
- displayname = saml_response.ava.get("displayName", [None])[0]
-
- # Retrieve any emails present in the saml response
- emails = saml_response.ava.get("email", [])
-
- return {
- "mxid_localpart": localpart,
- "displayname": displayname,
- "emails": emails,
- }
-
- @staticmethod
- def parse_config(config: dict) -> SamlConfig:
- """Parse the dict provided by the homeserver's config
- Args:
- config: A dictionary containing configuration options for this provider
- Returns:
- A custom config object for this module
- """
- # Parse config options and use defaults where necessary
- mxid_source_attribute = config.get("mxid_source_attribute", "uid")
- mapping_type = config.get("mxid_mapping", "hexencode")
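-        # For example, a homeserver config along these lines (a sketch; only
-        # these two keys are read here) selects the dot-replace mapper:
-        #   saml2_config:
-        #     user_mapping_provider:
-        #       config:
-        #         mxid_source_attribute: displayName
-        #         mxid_mapping: dotreplace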
-
- # Retrieve the associating mapping function
- try:
- mxid_mapper = MXID_MAPPER_MAP[mapping_type]
- except KeyError:
- raise ConfigError(
- "saml2_config.user_mapping_provider.config: '%s' is not a valid "
- "mxid_mapping value" % (mapping_type,)
- )
-
- return SamlConfig(mxid_source_attribute, mxid_mapper)
-
- @staticmethod
- def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
- """Returns the required attributes of a SAML
-
- Args:
- config: A SamlConfig object containing configuration params for this provider
-
- Returns:
- The first set equates to the saml auth response
- attributes that are required for the module to function, whereas the
- second set consists of those attributes which can be used if
- available, but are not necessary
- """
- return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index a7d52fa648..1a71135d5f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -423,9 +423,9 @@ class SearchHandler:
}
if search_result.room_groups and "room_id" in group_keys:
- rooms_cat_res.setdefault("groups", {})[
- "room_id"
- ] = search_result.room_groups
+ rooms_cat_res.setdefault("groups", {})["room_id"] = (
+ search_result.room_groups
+ )
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
deleted file mode 100644
index 70cdb0721c..0000000000
--- a/synapse/handlers/send_email.py
+++ /dev/null
@@ -1,230 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2021 The Matrix.org C.I.C. Foundation
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-import email.utils
-import logging
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
-from io import BytesIO
-from typing import TYPE_CHECKING, Any, Dict, Optional
-
-from pkg_resources import parse_version
-
-import twisted
-from twisted.internet.defer import Deferred
-from twisted.internet.endpoints import HostnameEndpoint
-from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory
-from twisted.internet.ssl import optionsForClientTLS
-from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
-from twisted.protocols.tls import TLSMemoryBIOFactory
-
-from synapse.logging.context import make_deferred_yieldable
-from synapse.types import ISynapseReactor
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-_is_old_twisted = parse_version(twisted.__version__) < parse_version("21")
-
-
-class _NoTLSESMTPSender(ESMTPSender):
- """Extend ESMTPSender to disable TLS
-
- Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable
- TLS, so we override its internal method which it uses to generate a context factory.
- """
-
- def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]:
- return None
-
-
-async def _sendmail(
- reactor: ISynapseReactor,
- smtphost: str,
- smtpport: int,
- from_addr: str,
- to_addr: str,
- msg_bytes: bytes,
- username: Optional[bytes] = None,
- password: Optional[bytes] = None,
- require_auth: bool = False,
- require_tls: bool = False,
- enable_tls: bool = True,
- force_tls: bool = False,
-) -> None:
- """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
-
- Params:
- reactor: reactor to use to make the outbound connection
- smtphost: hostname to connect to
- smtpport: port to connect to
- from_addr: "From" address for email
- to_addr: "To" address for email
- msg_bytes: Message content
- username: username to authenticate with, if auth is enabled
- password: password to give when authenticating
- require_auth: if auth is not offered, fail the request
-        require_tls: if TLS is not offered, fail the request
- enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
- the request will fail.
- force_tls: True to enable Implicit TLS.
- """
- msg = BytesIO(msg_bytes)
- d: "Deferred[object]" = Deferred()
-
- def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
- return ESMTPSenderFactory(
- username,
- password,
- from_addr,
- to_addr,
- msg,
- d,
- heloFallback=True,
- requireAuthentication=require_auth,
- requireTransportSecurity=require_tls,
- **kwargs,
- )
-
- factory: IProtocolFactory
- if _is_old_twisted:
- # before twisted 21.2, we have to override the ESMTPSender protocol to disable
- # TLS
- factory = build_sender_factory()
-
- if not enable_tls:
- factory.protocol = _NoTLSESMTPSender
- else:
- # for twisted 21.2 and later, there is a 'hostname' parameter which we should
- # set to enable TLS.
- factory = build_sender_factory(hostname=smtphost if enable_tls else None)
-
- if force_tls:
- factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory)
-
- endpoint = HostnameEndpoint(
- reactor, smtphost, smtpport, timeout=30, bindAddress=None
- )
-
- await make_deferred_yieldable(endpoint.connect(factory))
-
- await make_deferred_yieldable(d)
-
-
-class SendEmailHandler:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
-
- self._reactor = hs.get_reactor()
-
- self._from = hs.config.email.email_notif_from
- self._smtp_host = hs.config.email.email_smtp_host
- self._smtp_port = hs.config.email.email_smtp_port
-
- user = hs.config.email.email_smtp_user
- self._smtp_user = user.encode("utf-8") if user is not None else None
- passwd = hs.config.email.email_smtp_pass
- self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
- self._require_transport_security = hs.config.email.require_transport_security
- self._enable_tls = hs.config.email.enable_smtp_tls
- self._force_tls = hs.config.email.force_tls
-
- self._sendmail = _sendmail
-
- async def send_email(
- self,
- email_address: str,
- subject: str,
- app_name: str,
- html: str,
- text: str,
- additional_headers: Optional[Dict[str, str]] = None,
- ) -> None:
- """Send a multipart email with the given information.
-
- Args:
- email_address: The address to send the email to.
- subject: The email's subject.
- app_name: The app name to include in the From header.
- html: The HTML content to include in the email.
- text: The plain text content to include in the email.
- additional_headers: A map of additional headers to include.
- """
- try:
- from_string = self._from % {"app": app_name}
- except (KeyError, TypeError):
- from_string = self._from
-
- raw_from = email.utils.parseaddr(from_string)[1]
- raw_to = email.utils.parseaddr(email_address)[1]
-
- if raw_to == "":
- raise RuntimeError("Invalid 'to' address")
-
- html_part = MIMEText(html, "html", "utf-8")
- text_part = MIMEText(text, "plain", "utf-8")
-
- multipart_msg = MIMEMultipart("alternative")
- multipart_msg["Subject"] = subject
- multipart_msg["From"] = from_string
- multipart_msg["To"] = email_address
- multipart_msg["Date"] = email.utils.formatdate()
- multipart_msg["Message-ID"] = email.utils.make_msgid()
-
- # Discourage automatic responses to Synapse's emails.
- # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted"
- # header is present with any value other than "no". See
- # https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1
- multipart_msg["Auto-Submitted"] = "auto-generated"
- # Also include a Microsoft-Exchange specific header:
- # https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1
- # which suggests it can take the value "All" to "suppress all auto-replies",
- # or a comma separated list of auto-reply classes to suppress.
- # The following stack overflow question has a little more context:
- # https://stackoverflow.com/a/25324691/5252017
- # https://stackoverflow.com/a/61646381/5252017
- multipart_msg["X-Auto-Response-Suppress"] = "All"
-
- if additional_headers:
- for header, value in additional_headers.items():
- multipart_msg[header] = value
-
- multipart_msg.attach(text_part)
- multipart_msg.attach(html_part)
-
- logger.info("Sending email to %s" % email_address)
-
- await self._sendmail(
- self._reactor,
- self._smtp_host,
- self._smtp_port,
- raw_from,
- raw_to,
- multipart_msg.as_string().encode("utf8"),
- username=self._smtp_user,
- password=self._smtp_pass,
- require_auth=self._smtp_user is not None,
- require_tls=self._require_transport_security,
- enable_tls=self._enable_tls,
- force_tls=self._force_tls,
- )
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 29cc03d71d..94301add9e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -36,10 +36,17 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
- # This can only be instantiated on the main process.
- device_handler = hs.get_device_handler()
- assert isinstance(device_handler, DeviceHandler)
- self._device_handler = device_handler
+
+ # We don't need the device handler if password changing is disabled.
+ # This allows us to instantiate the SetPasswordHandler on the workers
+ # that have admin APIs for MAS
+ if self._auth_handler.can_change_password():
+ # This can only be instantiated on the main process.
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler: Optional[DeviceHandler] = device_handler
+ else:
+ self._device_handler = None
async def set_password(
self,
@@ -51,6 +58,9 @@ class SetPasswordHandler:
if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
+ # We should have this available only if password changing is enabled.
+ assert self._device_handler is not None
+
try:
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
deleted file mode 100644
index 18a96843be..0000000000
--- a/synapse/handlers/sliding_sync.py
+++ /dev/null
@@ -1,3158 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import enum
-import logging
-from enum import Enum
-from itertools import chain
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Final,
- List,
- Literal,
- Mapping,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
-)
-
-import attr
-from immutabledict import immutabledict
-from typing_extensions import assert_never
-
-from synapse.api.constants import (
- AccountDataTypes,
- Direction,
- EventContentFields,
- EventTypes,
- Membership,
-)
-from synapse.api.errors import SlidingSyncUnknownPosition
-from synapse.events import EventBase, StrippedStateEvent
-from synapse.events.utils import parse_stripped_state_event, strip_event
-from synapse.handlers.relations import BundledAggregations
-from synapse.logging.opentracing import (
- SynapseTags,
- log_kv,
- set_tag,
- start_active_span,
- tag_args,
- trace,
-)
-from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
-from synapse.storage.databases.main.state import (
- ROOM_UNKNOWN_SENTINEL,
- Sentinel as StateSentinel,
-)
-from synapse.storage.databases.main.stream import (
- CurrentStateDeltaMembership,
- PaginateFunction,
-)
-from synapse.storage.roommember import MemberSummary
-from synapse.types import (
- DeviceListUpdates,
- JsonDict,
- JsonMapping,
- MultiWriterStreamToken,
- MutableStateMap,
- PersistedEventPosition,
- Requester,
- RoomStreamToken,
- SlidingSyncStreamToken,
- StateMap,
- StrCollection,
- StreamKeyType,
- StreamToken,
- UserID,
-)
-from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
-from synapse.types.state import StateFilter
-from synapse.util.async_helpers import concurrently_execute
-from synapse.visibility import filter_events_for_client
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class Sentinel(enum.Enum):
- # defining a sentinel in this way allows mypy to correctly handle the
- # type of a dictionary lookup and subsequent type narrowing.
- UNSET_SENTINEL = object()
-
-
-# The event types that clients should consider as new activity.
-DEFAULT_BUMP_EVENT_TYPES = {
- EventTypes.Create,
- EventTypes.Message,
- EventTypes.Encrypted,
- EventTypes.Sticker,
- EventTypes.CallInvite,
- EventTypes.PollStart,
- EventTypes.LiveLocationShareStart,
-}
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _RoomMembershipForUser:
- """
- Attributes:
- room_id: The room ID of the membership event
- event_id: The event ID of the membership event
- event_pos: The stream position of the membership event
- membership: The membership state of the user in the room
- sender: The person who sent the membership event
- newly_joined: Whether the user newly joined the room during the given token
- range and is still joined to the room at the end of this range.
-        newly_left: Whether the user newly left (or was kicked from) the room during
-            the given token range and is still "leave" at the end of this range.
- is_dm: Whether this user considers this room as a direct-message (DM) room
- """
-
- room_id: str
- # Optional because state resets can affect room membership without a corresponding event.
- event_id: Optional[str]
- # Even during a state reset which removes the user from the room, we expect this to
- # be set because `current_state_delta_stream` will note the position that the reset
- # happened.
- event_pos: PersistedEventPosition
- # Even during a state reset which removes the user from the room, we expect this to
-    # be set to `LEAVE` because we can make that assumption based on the situation (see
- # `get_current_state_delta_membership_changes_for_user(...)`)
- membership: str
- # Optional because state resets can affect room membership without a corresponding event.
- sender: Optional[str]
- newly_joined: bool
- newly_left: bool
- is_dm: bool
-
- def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser":
- return attr.evolve(self, **kwds)
-
-
-def filter_membership_for_sync(
- *, user_id: str, room_membership_for_user: _RoomMembershipForUser
-) -> bool:
- """
- Returns True if the membership event should be included in the sync response,
- otherwise False.
-
- Attributes:
- user_id: The user ID that the membership applies to
- room_membership_for_user: Membership information for the user in the room
- """
-
- membership = room_membership_for_user.membership
- sender = room_membership_for_user.sender
- newly_left = room_membership_for_user.newly_left
-
-    # We want to allow everything except rooms the user has left, unless `newly_left`,
- # because we want everything that's *still* relevant to the user. We include
- # `newly_left` rooms because the last event that the user should see is their own
- # leave event.
- #
- # A leave != kick. This logic includes kicks (leave events where the sender is not
- # the same user).
- #
- # When `sender=None`, it means that a state reset happened that removed the user
- # from the room without a corresponding leave event. We can just remove the rooms
- # since they are no longer relevant to the user but will still appear if they are
- # `newly_left`.
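-    # For example (with newly_left=False in each case):
-    #   membership="join"                    -> True  (still relevant)
-    #   membership="leave", sender=user_id   -> False (the user left themselves)
-    #   membership="leave", sender=<other>   -> True  (a kick)
-    #   membership="leave", sender=None      -> False (state reset)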
- return (
- # Anything except leave events
- membership != Membership.LEAVE
- # Unless...
- or newly_left
- # Allow kicks
- or (membership == Membership.LEAVE and sender not in (user_id, None))
- )
-
-
-# We can't freeze this class because we want to update it in place with the
-# de-duplicated data.
-@attr.s(slots=True, auto_attribs=True)
-class RoomSyncConfig:
- """
- Holds the config for what data we should fetch for a room in the sync response.
-
- Attributes:
- timeline_limit: The maximum number of events to return in the timeline.
-
- required_state_map: Map from state event type to state_keys requested for the
- room. The values are close to `StateKey` but actually use a syntax where you
- can provide `*` wildcard and `$LAZY` for lazy-loading room members.
- """
-
- timeline_limit: int
- required_state_map: Dict[str, Set[str]]
-
- @classmethod
- def from_room_config(
- cls,
- room_params: SlidingSyncConfig.CommonRoomParameters,
- ) -> "RoomSyncConfig":
- """
- Create a `RoomSyncConfig` from a `SlidingSyncList`/`RoomSubscription` config.
-
- Args:
- room_params: `SlidingSyncConfig.SlidingSyncList` or `SlidingSyncConfig.RoomSubscription`
- """
- required_state_map: Dict[str, Set[str]] = {}
- for (
- state_type,
- state_key,
- ) in room_params.required_state:
- # If we already have a wildcard for this specific `state_key`, we don't need
- # to add it since the wildcard already covers it.
- if state_key in required_state_map.get(StateValues.WILDCARD, set()):
- continue
-
- # If we already have a wildcard `state_key` for this `state_type`, we don't need
- # to add anything else
- if StateValues.WILDCARD in required_state_map.get(state_type, set()):
- continue
-
- # If we're getting wildcards for the `state_type` and `state_key`, that's
- # all that matters so get rid of any other entries
- if state_type == StateValues.WILDCARD and state_key == StateValues.WILDCARD:
- required_state_map = {StateValues.WILDCARD: {StateValues.WILDCARD}}
- # We can break, since we don't need to add anything else
- break
-
- # If we're getting a wildcard for the `state_type`, get rid of any other
- # entries with the same `state_key`, since the wildcard will cover it already.
- elif state_type == StateValues.WILDCARD:
- # Get rid of any entries that match the `state_key`
- #
- # Make a copy so we don't run into an error: `dictionary changed size
- # during iteration`, when we remove items
- for (
- existing_state_type,
- existing_state_key_set,
- ) in list(required_state_map.items()):
- # Make a copy so we don't run into an error: `Set changed size during
- # iteration`, when we filter out and remove items
- for existing_state_key in existing_state_key_set.copy():
- if existing_state_key == state_key:
- existing_state_key_set.remove(state_key)
-
-                        # If we've left the `set()` empty, remove it from the map
- if existing_state_key_set == set():
- required_state_map.pop(existing_state_type, None)
-
- # If we're getting a wildcard `state_key`, get rid of any other state_keys
- # for this `state_type` since the wildcard will cover it already.
- if state_key == StateValues.WILDCARD:
- required_state_map[state_type] = {state_key}
- # Otherwise, just add it to the set
- else:
- if required_state_map.get(state_type) is None:
- required_state_map[state_type] = {state_key}
- else:
- required_state_map[state_type].add(state_key)
-
- return cls(
- timeline_limit=room_params.timeline_limit,
- required_state_map=required_state_map,
- )
-
- def deep_copy(self) -> "RoomSyncConfig":
- required_state_map: Dict[str, Set[str]] = {
- state_type: state_key_set.copy()
- for state_type, state_key_set in self.required_state_map.items()
- }
-
- return RoomSyncConfig(
- timeline_limit=self.timeline_limit,
- required_state_map=required_state_map,
- )
-
- def combine_room_sync_config(
- self, other_room_sync_config: "RoomSyncConfig"
- ) -> None:
- """
- Combine this `RoomSyncConfig` with another `RoomSyncConfig` and take the
- superset union of the two.
- """
- # Take the highest timeline limit
- if self.timeline_limit < other_room_sync_config.timeline_limit:
- self.timeline_limit = other_room_sync_config.timeline_limit
-
- # Union the required state
- for (
- state_type,
- state_key_set,
- ) in other_room_sync_config.required_state_map.items():
- # If we already have a wildcard for everything, we don't need to add
- # anything else
- if StateValues.WILDCARD in self.required_state_map.get(
- StateValues.WILDCARD, set()
- ):
- break
-
- # If we already have a wildcard `state_key` for this `state_type`, we don't need
- # to add anything else
- if StateValues.WILDCARD in self.required_state_map.get(state_type, set()):
- continue
-
- # If we're getting wildcards for the `state_type` and `state_key`, that's
- # all that matters so get rid of any other entries
- if (
- state_type == StateValues.WILDCARD
- and StateValues.WILDCARD in state_key_set
- ):
- self.required_state_map = {state_type: {StateValues.WILDCARD}}
- # We can break, since we don't need to add anything else
- break
-
- for state_key in state_key_set:
- # If we already have a wildcard for this specific `state_key`, we don't need
- # to add it since the wildcard already covers it.
- if state_key in self.required_state_map.get(
- StateValues.WILDCARD, set()
- ):
- continue
-
- # If we're getting a wildcard for the `state_type`, get rid of any other
- # entries with the same `state_key`, since the wildcard will cover it already.
- if state_type == StateValues.WILDCARD:
- # Get rid of any entries that match the `state_key`
- #
- # Make a copy so we don't run into an error: `dictionary changed size
- # during iteration`, when we remove items
- for existing_state_type, existing_state_key_set in list(
- self.required_state_map.items()
- ):
- # Make a copy so we don't run into an error: `Set changed size during
- # iteration`, when we filter out and remove items
- for existing_state_key in existing_state_key_set.copy():
- if existing_state_key == state_key:
- existing_state_key_set.remove(state_key)
-
-                            # If we've left the `set()` empty, remove it from the map
- if existing_state_key_set == set():
- self.required_state_map.pop(existing_state_type, None)
-
- # If we're getting a wildcard `state_key`, get rid of any other state_keys
- # for this `state_type` since the wildcard will cover it already.
- if state_key == StateValues.WILDCARD:
- self.required_state_map[state_type] = {state_key}
- break
- # Otherwise, just add it to the set
- else:
- if self.required_state_map.get(state_type) is None:
- self.required_state_map[state_type] = {state_key}
- else:
- self.required_state_map[state_type].add(state_key)
-
-
-class StateValues:
- """
- Understood values of the (type, state_key) tuple in `required_state`.
- """
-
- # Include all state events of the given type
- WILDCARD: Final = "*"
- # Lazy-load room membership events (include room membership events for any event
- # `sender` in the timeline). We only give special meaning to this value when it's a
- # `state_key`.
- LAZY: Final = "$LAZY"
-    # Substitute with the requester's user ID. Typically used by clients to get
- # the user's membership.
- ME: Final = "$ME"
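-    # e.g. a required_state entry of ("m.room.member", "$ME") requests just the
-    # requesting user's own membership event in each room.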
-
-
-class SlidingSyncHandler:
- def __init__(self, hs: "HomeServer"):
- self.clock = hs.get_clock()
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
- self.auth_blocking = hs.get_auth_blocking()
- self.notifier = hs.get_notifier()
- self.event_sources = hs.get_event_sources()
- self.relations_handler = hs.get_relations_handler()
- self.device_handler = hs.get_device_handler()
- self.push_rules_handler = hs.get_push_rules_handler()
- self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
-
- self.connection_store = SlidingSyncConnectionStore()
-
- async def wait_for_sync_for_user(
- self,
- requester: Requester,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken] = None,
- timeout_ms: int = 0,
- ) -> SlidingSyncResult:
- """
- Get the sync for a client if we have new data for it now. Otherwise
- wait for new data to arrive on the server. If the timeout expires, then
- return an empty sync result.
-
- Args:
- requester: The user making the request
- sync_config: Sync configuration
- from_token: The point in the stream to sync from. Token of the end of the
- previous batch. May be `None` if this is the initial sync request.
- timeout_ms: The time in milliseconds to wait for new data to arrive. If 0,
-            we will respond immediately, but there might not be any new data, so we
-            just return an empty response.
- """
- # If the user is not part of the mau group, then check that limits have
- # not been exceeded (if not part of the group by this point, almost certain
- # auth_blocking will occur)
- await self.auth_blocking.check_auth_blocking(requester=requester)
-
- # If we're working with a user-provided token, we need to make sure to wait for
- # this worker to catch up with the token so we don't skip past any incoming
- # events or future events if the user is nefariously, manually modifying the
- # token.
- if from_token is not None:
- # We need to make sure this worker has caught up with the token. If
- # this returns false, it means we timed out waiting, and we should
- # just return an empty response.
- before_wait_ts = self.clock.time_msec()
- if not await self.notifier.wait_for_stream_token(from_token.stream_token):
- logger.warning(
- "Timed out waiting for worker to catch up. Returning empty response"
- )
- return SlidingSyncResult.empty(from_token)
-
- # If we've spent significant time waiting to catch up, take it off
- # the timeout.
- after_wait_ts = self.clock.time_msec()
- if after_wait_ts - before_wait_ts > 1_000:
- timeout_ms -= after_wait_ts - before_wait_ts
- timeout_ms = max(timeout_ms, 0)
-
- # We're going to respond immediately if the timeout is 0 or if this is an
- # initial sync (without a `from_token`) so we can avoid calling
- # `notifier.wait_for_events()`.
- if timeout_ms == 0 or from_token is None:
- now_token = self.event_sources.get_current_token()
- result = await self.current_sync_for_user(
- sync_config,
- from_token=from_token,
- to_token=now_token,
- )
- else:
- # Otherwise, we wait for something to happen and report it to the user.
- async def current_sync_callback(
- before_token: StreamToken, after_token: StreamToken
- ) -> SlidingSyncResult:
- return await self.current_sync_for_user(
- sync_config,
- from_token=from_token,
- to_token=after_token,
- )
-
- result = await self.notifier.wait_for_events(
- sync_config.user.to_string(),
- timeout_ms,
- current_sync_callback,
- from_token=from_token.stream_token,
- )
-
- return result
-
- @trace
- async def current_sync_for_user(
- self,
- sync_config: SlidingSyncConfig,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken] = None,
- ) -> SlidingSyncResult:
- """
- Generates the response body of a Sliding Sync result, represented as a
- `SlidingSyncResult`.
-
- We fetch data according to the token range (> `from_token` and <= `to_token`).
-
- Args:
- sync_config: Sync configuration
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from. Token of the end of the
- previous batch. May be `None` if this is the initial sync request.
- """
- user_id = sync_config.user.to_string()
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- # We no longer support AS users using /sync directly.
- # See https://github.com/matrix-org/matrix-doc/issues/1144
- raise NotImplementedError()
-
- if from_token:
- # Check that we recognize the connection position, if not tell the
- # clients that they need to start again.
- #
- # If we don't do this and the client asks for the full range of
- # rooms, we end up sending down all rooms and their state from
- # scratch (which can be very slow). By expiring the connection we
- # allow the client a chance to do an initial request with a smaller
- # range of rooms to get them some results sooner but will end up
- # taking the same amount of time (more with round-trips and
- # re-processing) in the end to get everything again.
- if not await self.connection_store.is_valid_token(
- sync_config, from_token.connection_position
- ):
- raise SlidingSyncUnknownPosition()
-
- await self.connection_store.mark_token_seen(
- sync_config=sync_config,
- from_token=from_token,
- )
-
- # Get all of the room IDs that the user should be able to see in the sync
- # response
- has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
- has_room_subscriptions = (
- sync_config.room_subscriptions is not None
- and len(sync_config.room_subscriptions) > 0
- )
- if has_lists or has_room_subscriptions:
- room_membership_for_user_map = (
- await self.get_room_membership_for_user_at_to_token(
- user=sync_config.user,
- to_token=to_token,
- from_token=from_token.stream_token if from_token else None,
- )
- )
-
- # Assemble sliding window lists
- lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
- # Keep track of the rooms that we can display and need to fetch more info about
- relevant_room_map: Dict[str, RoomSyncConfig] = {}
- # The set of room IDs of all rooms that could appear in any list. These
- # include rooms that are outside the list ranges.
- all_rooms: Set[str] = set()
- if has_lists and sync_config.lists is not None:
- with start_active_span("assemble_sliding_window_lists"):
- sync_room_map = await self.filter_rooms_relevant_for_sync(
- user=sync_config.user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
-
- for list_key, list_config in sync_config.lists.items():
- # Apply filters
- filtered_sync_room_map = sync_room_map
- if list_config.filters is not None:
- filtered_sync_room_map = await self.filter_rooms(
- sync_config.user,
- sync_room_map,
- list_config.filters,
- to_token,
- )
-
- # Find which rooms are partially stated and may need to be filtered out
- # depending on the `required_state` requested (see below).
- partial_state_room_map = (
- await self.store.is_partial_state_room_batched(
- filtered_sync_room_map.keys()
- )
- )
-
- # Since creating the `RoomSyncConfig` takes some work, let's just do it
- # once and make a copy whenever we need it.
- room_sync_config = RoomSyncConfig.from_room_config(list_config)
- membership_state_keys = room_sync_config.required_state_map.get(
- EventTypes.Member
- )
- # Also see `StateFilter.must_await_full_state(...)` for comparison
- lazy_loading = (
- membership_state_keys is not None
- and StateValues.LAZY in membership_state_keys
- )
-
- if not lazy_loading:
- # Exclude partially-stated rooms unless the `required_state`
- # only has `["m.room.member", "$LAZY"]` for membership
- # (lazy-loading room members).
- filtered_sync_room_map = {
- room_id: room
- for room_id, room in filtered_sync_room_map.items()
- if not partial_state_room_map.get(room_id)
- }
-
- all_rooms.update(filtered_sync_room_map)
-
- # Sort the list
- sorted_room_info = await self.sort_rooms(
- filtered_sync_room_map, to_token
- )
-
- ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
- if list_config.ranges:
- for range in list_config.ranges:
- room_ids_in_list: List[str] = []
-
- # We're going to loop through the sorted list of rooms starting
- # at the range start index and keep adding rooms until we fill
- # up the range or run out of rooms.
- #
- # Both sides of range are inclusive so we `+ 1`
- max_num_rooms = range[1] - range[0] + 1
- for room_membership in sorted_room_info[range[0] :]:
- room_id = room_membership.room_id
-
- if len(room_ids_in_list) >= max_num_rooms:
- break
-
- # Take the superset of the `RoomSyncConfig` for each room.
- #
- # Update our `relevant_room_map` with the room we're going
- # to display and need to fetch more info about.
- existing_room_sync_config = relevant_room_map.get(
- room_id
- )
- if existing_room_sync_config is not None:
- existing_room_sync_config.combine_room_sync_config(
- room_sync_config
- )
- else:
- # Make a copy so if we modify it later, it doesn't
- # affect all references.
- relevant_room_map[room_id] = (
- room_sync_config.deep_copy()
- )
-
- room_ids_in_list.append(room_id)
-
- ops.append(
- SlidingSyncResult.SlidingWindowList.Operation(
- op=OperationType.SYNC,
- range=range,
- room_ids=room_ids_in_list,
- )
- )
-
- lists[list_key] = SlidingSyncResult.SlidingWindowList(
- count=len(sorted_room_info),
- ops=ops,
- )
-
- # Handle room subscriptions
- if has_room_subscriptions and sync_config.room_subscriptions is not None:
- with start_active_span("assemble_room_subscriptions"):
- for (
- room_id,
- room_subscription,
- ) in sync_config.room_subscriptions.items():
- room_membership_for_user_at_to_token = (
- await self.check_room_subscription_allowed_for_user(
- room_id=room_id,
- room_membership_for_user_map=room_membership_for_user_map,
- to_token=to_token,
- )
- )
-
- # Skip this room if the user isn't allowed to see it
- if not room_membership_for_user_at_to_token:
- continue
-
- all_rooms.add(room_id)
-
- room_membership_for_user_map[room_id] = (
- room_membership_for_user_at_to_token
- )
-
- # Take the superset of the `RoomSyncConfig` for each room.
- #
- # Update our `relevant_room_map` with the room we're going to display
- # and need to fetch more info about.
- room_sync_config = RoomSyncConfig.from_room_config(
- room_subscription
- )
- existing_room_sync_config = relevant_room_map.get(room_id)
- if existing_room_sync_config is not None:
- existing_room_sync_config.combine_room_sync_config(
- room_sync_config
- )
- else:
- relevant_room_map[room_id] = room_sync_config
-
- # Fetch room data
- rooms: Dict[str, SlidingSyncResult.RoomResult] = {}
-
- # Filter out rooms that haven't received updates and we've sent down
- # previously.
- # Keep track of the rooms that we're going to display and need to fetch more info about
- relevant_rooms_to_send_map = relevant_room_map
- with start_active_span("filter_relevant_rooms_to_send"):
- if from_token:
- rooms_should_send = set()
-
- # First we check if there are rooms that match a list/room
- # subscription and have updates we need to send (i.e. either because
- # we haven't sent the room down, or we have but there are missing
- # updates).
- for room_id in relevant_room_map:
- status = await self.connection_store.have_sent_room(
- sync_config,
- from_token.connection_position,
- room_id,
- )
- if (
- # The room was never sent down before so the client needs to know
- # about it regardless of any updates.
- status.status == HaveSentRoomFlag.NEVER
- # `PREVIOUSLY` literally means the "room was sent down before *AND*
- # there are updates we haven't sent down" so we already know this
- # room has updates.
- or status.status == HaveSentRoomFlag.PREVIOUSLY
- ):
- rooms_should_send.add(room_id)
- elif status.status == HaveSentRoomFlag.LIVE:
- # We know that we've sent all updates up until `from_token`,
- # so we just need to check if there have been updates since
- # then.
- pass
- else:
- assert_never(status.status)
-
- # We only need to check for new events since any state changes
- # will also come down as new events.
- rooms_that_have_updates = self.store.get_rooms_that_might_have_updates(
- relevant_room_map.keys(), from_token.stream_token.room_key
- )
- rooms_should_send.update(rooms_that_have_updates)
- relevant_rooms_to_send_map = {
- room_id: room_sync_config
- for room_id, room_sync_config in relevant_room_map.items()
- if room_id in rooms_should_send
- }
-
- @trace
- @tag_args
- async def handle_room(room_id: str) -> None:
- room_sync_result = await self.get_room_sync_data(
- sync_config=sync_config,
- room_id=room_id,
- room_sync_config=relevant_rooms_to_send_map[room_id],
- room_membership_for_user_at_to_token=room_membership_for_user_map[
- room_id
- ],
- from_token=from_token,
- to_token=to_token,
- )
-
- # Filter out empty room results during incremental sync
- if room_sync_result or not from_token:
- rooms[room_id] = room_sync_result
-
- if relevant_rooms_to_send_map:
- with start_active_span("sliding_sync.generate_room_entries"):
- await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10)
-
- extensions = await self.get_extensions_response(
- sync_config=sync_config,
- actual_lists=lists,
- # We're purposely using `relevant_room_map` instead of
- # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could
- # send regardless of whether they have an event update or not. The
- # extensions care about more than just normal events in the rooms (like
- # account data, read receipts, typing indicators, to-device messages, etc).
- actual_room_ids=set(relevant_room_map.keys()),
- actual_room_response_map=rooms,
- from_token=from_token,
- to_token=to_token,
- )
-
- if has_lists or has_room_subscriptions:
- # We now calculate if any rooms outside the range have had updates,
- # which we are not sending down.
- #
- # We *must* record rooms that have had updates, but it is also fine
- # to record rooms as having updates even if there might not actually
- # be anything new for the user (e.g. due to event filters, events
- # having happened after the user left, etc).
- unsent_room_ids = []
- if from_token:
- # The set of rooms that the client (may) care about, but aren't
- # in any list range (or subscribed to).
- missing_rooms = all_rooms - relevant_room_map.keys()
-
- # We now just go and try fetching any events in the above rooms
- # to see if anything has happened since the `from_token`.
- #
- # TODO: Replace this with something faster. When we land the
- # sliding sync tables that record the most recent event
- # positions we can use that.
- missing_event_map_by_room = (
- await self.store.get_room_events_stream_for_rooms(
- room_ids=missing_rooms,
- from_key=to_token.room_key,
- to_key=from_token.stream_token.room_key,
- limit=1,
- )
- )
- unsent_room_ids = list(missing_event_map_by_room)
-
- connection_position = await self.connection_store.record_rooms(
- sync_config=sync_config,
- from_token=from_token,
- sent_room_ids=relevant_rooms_to_send_map.keys(),
- unsent_room_ids=unsent_room_ids,
- )
- elif from_token:
- connection_position = from_token.connection_position
- else:
- # Initial sync without a `from_token` starts at `0`
- connection_position = 0
-
- sliding_sync_result = SlidingSyncResult(
- next_pos=SlidingSyncStreamToken(to_token, connection_position),
- lists=lists,
- rooms=rooms,
- extensions=extensions,
- )
-
- # Make it easy to find traces for syncs that aren't empty
- set_tag(SynapseTags.RESULT_PREFIX + "result", bool(sliding_sync_result))
- set_tag(SynapseTags.FUNC_ARG_PREFIX + "sync_config.user", user_id)
-
- return sliding_sync_result
-
- @trace
- async def get_room_membership_for_user_at_to_token(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Fetch room IDs that the user has had membership in (the full room list including
- long-lost left rooms that will be filtered, sorted, and sliced).
-
- We're looking for rooms where the user has had any sort of membership in the
- token range (> `from_token` and <= `to_token`)
-
- In order for bans/kicks to not show up, you need to `/forget` those rooms. This
- doesn't modify the event itself though and only adds the `forgotten` flag to the
- `room_memberships` table in Synapse. There isn't a way to tell when a room was
- forgotten at the moment so we can't factor it into the token range.
-
- Args:
- user: User to fetch rooms for
- to_token: The token to fetch rooms up to.
- from_token: The point in the stream to sync from.
-
- Returns:
- A dictionary of room IDs that the user has had membership in along with
- membership information in that room at the time of `to_token`.
- """
- user_id = user.to_string()
-
-        # First grab a current snapshot of rooms for the user
- # (also handles forgotten rooms)
- room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
- user_id=user_id,
- # We want to fetch any kind of membership (joined and left rooms) in order
- # to get the `event_pos` of the latest room membership event for the
- # user.
- membership_list=Membership.LIST,
- excluded_rooms=self.rooms_to_exclude_globally,
- )
-
-        # If the user has never joined any rooms before, we can just return an empty dict
- if not room_for_user_list:
- return {}
-
- # Our working list of rooms that can show up in the sync response
- sync_room_id_set = {
- # Note: The `room_for_user` we're assigning here will need to be fixed up
- # (below) because they are potentially from the current snapshot time
-            # instead of from the time of the `to_token`.
- room_for_user.room_id: _RoomMembershipForUser(
- room_id=room_for_user.room_id,
- event_id=room_for_user.event_id,
- event_pos=room_for_user.event_pos,
- membership=room_for_user.membership,
- sender=room_for_user.sender,
- # We will update these fields below to be accurate
- newly_joined=False,
- newly_left=False,
- is_dm=False,
- )
- for room_for_user in room_for_user_list
- }
-
- # Get the `RoomStreamToken` that represents the spot we queried up to when we got
- # our membership snapshot from `get_rooms_for_local_user_where_membership_is()`.
- #
- # First, we need to get the max stream_ordering of each event persister instance
- # that we queried events from.
- instance_to_max_stream_ordering_map: Dict[str, int] = {}
- for room_for_user in room_for_user_list:
- instance_name = room_for_user.event_pos.instance_name
- stream_ordering = room_for_user.event_pos.stream
-
- current_instance_max_stream_ordering = (
- instance_to_max_stream_ordering_map.get(instance_name)
- )
- if (
- current_instance_max_stream_ordering is None
- or stream_ordering > current_instance_max_stream_ordering
- ):
- instance_to_max_stream_ordering_map[instance_name] = stream_ordering
-
- # Then assemble the `RoomStreamToken`
- min_stream_pos = min(instance_to_max_stream_ordering_map.values())
- membership_snapshot_token = RoomStreamToken(
- # Minimum position in the `instance_map`
- stream=min_stream_pos,
- instance_map=immutabledict(
- {
- instance_name: stream_pos
- for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
- if stream_pos > min_stream_pos
- }
- ),
- )
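-
- # For example (illustrative values): if the snapshot rows came from two event
- # persisters at `{"persister1": 100, "persister2": 150}`, then
- # `min_stream_pos = 100` and the token is
- # `RoomStreamToken(stream=100, instance_map={"persister2": 150})`, since only
- # positions above the minimum need an `instance_map` entry.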
-
- # Since we fetched the users room list at some point in time after the from/to
- # tokens, we need to revert/rewind some membership changes to match the point in
- # time of the `to_token`. In particular, we need to make these fixups:
- #
- # - 1a) Remove rooms that the user joined after the `to_token`
- # - 1b) Add back rooms that the user left after the `to_token`
- # - 1c) Update room membership events to the point in time of the `to_token`
- # - 2) Figure out which rooms are `newly_left` rooms (> `from_token` and <= `to_token`)
- # - 3) Figure out which rooms are `newly_joined` (> `from_token` and <= `to_token`)
- # - 4) Figure out which rooms are DM's
-
- # 1) Fetch membership changes that fall in the range from `to_token` up to
- # `membership_snapshot_token`
- #
- # If our `to_token` is already the same or ahead of the latest room membership
- # for the user, we don't need to do any "1)" fix-ups and can just straight-up
- # use the room list from the snapshot as a base (nothing has changed)
- current_state_delta_membership_changes_after_to_token = []
- if not membership_snapshot_token.is_before_or_eq(to_token.room_key):
- current_state_delta_membership_changes_after_to_token = (
- await self.store.get_current_state_delta_membership_changes_for_user(
- user_id,
- from_key=to_token.room_key,
- to_key=membership_snapshot_token,
- excluded_room_ids=self.rooms_to_exclude_globally,
- )
- )
-
- # 1) Assemble a list of the first membership event after the `to_token` so we can
- # step backward to the previous membership that would apply to the from/to
- # range.
- first_membership_change_by_room_id_after_to_token: Dict[
- str, CurrentStateDeltaMembership
- ] = {}
- for membership_change in current_state_delta_membership_changes_after_to_token:
- # Only set if we haven't already set it
- first_membership_change_by_room_id_after_to_token.setdefault(
- membership_change.room_id, membership_change
- )
-
- # 1) Fixup
- #
- # Since we fetched a snapshot of the users room list at some point in time after
- # the from/to tokens, we need to revert/rewind some membership changes to match
- # the point in time of the `to_token`.
- for (
- room_id,
- first_membership_change_after_to_token,
- ) in first_membership_change_by_room_id_after_to_token.items():
- # 1a) Remove rooms that the user joined after the `to_token`
- if first_membership_change_after_to_token.prev_event_id is None:
- sync_room_id_set.pop(room_id, None)
- # 1b) 1c) From the first membership event after the `to_token`, step backward to the
- # previous membership that would apply to the from/to range.
- else:
- # We don't expect these fields to be `None` if we have a `prev_event_id`
- # but we're being defensive since it's possible that the prev event was
- # culled from the database.
- if (
- first_membership_change_after_to_token.prev_event_pos is not None
- and first_membership_change_after_to_token.prev_membership
- is not None
- ):
- sync_room_id_set[room_id] = _RoomMembershipForUser(
- room_id=room_id,
- event_id=first_membership_change_after_to_token.prev_event_id,
- event_pos=first_membership_change_after_to_token.prev_event_pos,
- membership=first_membership_change_after_to_token.prev_membership,
- sender=first_membership_change_after_to_token.prev_sender,
- # We will update these fields below to be accurate
- newly_joined=False,
- newly_left=False,
- is_dm=False,
- )
- else:
- # If we can't find the previous membership event, we shouldn't
- # include the room in the sync response since we can't determine the
- # exact membership state and shouldn't rely on the current snapshot.
- sync_room_id_set.pop(room_id, None)
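-
- # For example (illustrative sequences): if the user was joined at the
- # `to_token` but left afterwards, the snapshot shows `leave`; stepping back
- # from the first post-`to_token` change (the leave) restores the `join`.
- # Conversely, a room first joined after the `to_token` has no
- # `prev_event_id` and is dropped from `sync_room_id_set` entirely.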
-
- # 2) Fetch membership changes that fall in the range from `from_token` up to `to_token`
- current_state_delta_membership_changes_in_from_to_range = []
- if from_token:
- current_state_delta_membership_changes_in_from_to_range = (
- await self.store.get_current_state_delta_membership_changes_for_user(
- user_id,
- from_key=from_token.room_key,
- to_key=to_token.room_key,
- excluded_room_ids=self.rooms_to_exclude_globally,
- )
- )
-
- # 2) Assemble a list of the last membership events in some given ranges. Someone
- # could have left and joined multiple times during the given range but we only
- # care about the end result, so we grab the last one.
- last_membership_change_by_room_id_in_from_to_range: Dict[
- str, CurrentStateDeltaMembership
- ] = {}
- # We also want to assemble a list of the first membership events during the token
- # range so we can step backward to the previous membership that applied just
- # before the token range, to see if we have `newly_joined` the room.
- first_membership_change_by_room_id_in_from_to_range: Dict[
- str, CurrentStateDeltaMembership
- ] = {}
- # Keep track of whether the room has a non-join event in the token range so we
- # can later tell if it was a `newly_joined` room. If the last membership event in the
- # token range is a join and there is also some non-join in the range, we know
- # they `newly_joined`.
- has_non_join_event_by_room_id_in_from_to_range: Dict[str, bool] = {}
- for (
- membership_change
- ) in current_state_delta_membership_changes_in_from_to_range:
- room_id = membership_change.room_id
-
- last_membership_change_by_room_id_in_from_to_range[room_id] = (
- membership_change
- )
- # Only set if we haven't already set it
- first_membership_change_by_room_id_in_from_to_range.setdefault(
- room_id, membership_change
- )
-
- if membership_change.membership != Membership.JOIN:
- has_non_join_event_by_room_id_in_from_to_range[room_id] = True
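-
- # For example (illustrative sequence): membership changes of
- # join -> leave -> join within the range give `first = join`, `last = join`,
- # and `has_non_join = True` (because of the intermediate leave), which is
- # enough to classify the room as `newly_joined` below.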
-
- # 2) Fixup
- #
- # 3) We also want to assemble a list of possibly newly joined rooms. Someone
- # could have left and joined multiple times during the given range but we only
- # care about whether they are joined at the end of the token range so we are
- # working with the last membership event in the token range.
- possibly_newly_joined_room_ids = set()
- for (
- last_membership_change_in_from_to_range
- ) in last_membership_change_by_room_id_in_from_to_range.values():
- room_id = last_membership_change_in_from_to_range.room_id
-
- # 3)
- if last_membership_change_in_from_to_range.membership == Membership.JOIN:
- possibly_newly_joined_room_ids.add(room_id)
-
- # 2) Figure out newly_left rooms (> `from_token` and <= `to_token`).
- if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
- # 2) Mark this room as `newly_left`
-
- # If we're seeing a membership change here, we should expect to already
- # have it in our snapshot, but if a state reset happened, it wouldn't have
- # shown up in our snapshot and would instead appear as a change here.
- existing_sync_entry = sync_room_id_set.get(room_id)
- if existing_sync_entry is not None:
- # Normal expected case
- sync_room_id_set[room_id] = existing_sync_entry.copy_and_replace(
- newly_left=True
- )
- else:
- # State reset!
- logger.warning(
- "State reset detected for room_id %s with %s who is no longer in the room",
- room_id,
- user_id,
- )
- # Even though a state reset happened which removed the person from
- # the room, we still add it to the list so the user knows they left the
- # room. Downstream code can check for a state reset by looking for
- # `event_id=None and membership is not None`.
- sync_room_id_set[room_id] = _RoomMembershipForUser(
- room_id=room_id,
- event_id=last_membership_change_in_from_to_range.event_id,
- event_pos=last_membership_change_in_from_to_range.event_pos,
- membership=last_membership_change_in_from_to_range.membership,
- sender=last_membership_change_in_from_to_range.sender,
- newly_joined=False,
- newly_left=True,
- is_dm=False,
- )
-
- # 3) Figure out `newly_joined`
- for room_id in possibly_newly_joined_room_ids:
- has_non_join_in_from_to_range = (
- has_non_join_event_by_room_id_in_from_to_range.get(room_id, False)
- )
- # If the last membership event in the token range is a join and there is
- # also some non-join in the range, we know they `newly_joined`.
- if has_non_join_in_from_to_range:
- # We found a `newly_joined` room (we left and joined within the token range)
- sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace(
- newly_joined=True
- )
- else:
- prev_event_id = first_membership_change_by_room_id_in_from_to_range[
- room_id
- ].prev_event_id
- prev_membership = first_membership_change_by_room_id_in_from_to_range[
- room_id
- ].prev_membership
-
- if prev_event_id is None:
- # We found a `newly_joined` room (we are joining the room for the
- # first time within the token range)
- sync_room_id_set[room_id] = sync_room_id_set[
- room_id
- ].copy_and_replace(newly_joined=True)
- # As a last resort, we need to step back to the previous membership event
- # just before the token range to see if we're joined then or not.
- elif prev_membership != Membership.JOIN:
- # We found a `newly_joined` room (we left before the token range
- # and joined within the token range)
- sync_room_id_set[room_id] = sync_room_id_set[
- room_id
- ].copy_and_replace(newly_joined=True)
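-
- # Summarizing the cases above (illustrative): `prev_event_id is None` means
- # the first-ever join happened inside the range (`newly_joined`);
- # `prev_membership` of `leave` means they rejoined within the range
- # (`newly_joined`); `prev_membership` of `join` with no non-join event in
- # the range means they were already joined (not `newly_joined`).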
-
- # 4) Figure out which rooms the user considers to be direct-message (DM) rooms
- #
- # We're using global account data (`m.direct`) instead of checking for
- # `is_direct` on membership events because that property only appears for
- # the invitee membership event (doesn't show up for the inviter).
- #
- # We're unable to take `to_token` into account for global account data since
- # we only keep track of the latest account data for the user.
- dm_map = await self.store.get_global_account_data_by_type_for_user(
- user_id, AccountDataTypes.DIRECT
- )
-
- # Flatten out the map. Account data is set by the client so it needs to be
- # scrutinized.
- dm_room_id_set = set()
- if isinstance(dm_map, dict):
- for room_ids in dm_map.values():
- # Account data should be a list of room IDs. Ignore anything else
- if isinstance(room_ids, list):
- for room_id in room_ids:
- if isinstance(room_id, str):
- dm_room_id_set.add(room_id)
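-
- # For example, an `m.direct` blob shaped like
- # `{"@alice:example.com": ["!abc:example.com"]}` yields
- # `dm_room_id_set = {"!abc:example.com"}`; malformed values (a non-list of
- # rooms, a non-string room ID) are silently skipped.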
-
- # 4) Fixup
- for room_id in sync_room_id_set:
- sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace(
- is_dm=room_id in dm_room_id_set
- )
-
- return sync_room_id_set
-
- @trace
- async def filter_rooms_relevant_for_sync(
- self,
- user: UserID,
- room_membership_for_user_map: Dict[str, _RoomMembershipForUser],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Filter room IDs that should/can be listed for this user in the sync response (the
- full room list that will be further filtered, sorted, and sliced).
-
- We're looking for rooms where the user has the following state in the token
- range (> `from_token` and <= `to_token`):
-
- - `invite`, `join`, `knock`, `ban` membership events
- - Kicks (`leave` membership events where `sender` is different from the
- `user_id`/`state_key`)
- - `newly_left` (rooms that were left during the given token range)
- - In order for bans/kicks to not show up in sync, you need to `/forget` those
- rooms. This doesn't modify the event itself though and only adds the
- `forgotten` flag to the `room_memberships` table in Synapse. There is currently
- no way to tell when a room was forgotten, so we can't factor it into the
- from/to range.
-
- Args:
- user: User that is syncing
- room_membership_for_user_map: Room membership for the user
-
- Returns:
- A dictionary of room IDs that should be listed in the sync response along
- with membership information in that room at the time of `to_token`.
- """
- user_id = user.to_string()
-
- # Filter rooms to only what we're interested in syncing with
- filtered_sync_room_map = {
- room_id: room_membership_for_user
- for room_id, room_membership_for_user in room_membership_for_user_map.items()
- if filter_membership_for_sync(
- user_id=user_id,
- room_membership_for_user=room_membership_for_user,
- )
- }
-
- return filtered_sync_room_map
-
- async def check_room_subscription_allowed_for_user(
- self,
- room_id: str,
- room_membership_for_user_map: Dict[str, _RoomMembershipForUser],
- to_token: StreamToken,
- ) -> Optional[_RoomMembershipForUser]:
- """
- Check whether the user is allowed to see the room based on whether they have
- ever had membership in the room or if the room is `world_readable`.
-
- Similar to `check_user_in_room_or_world_readable(...)`
-
- Args:
- room_id: Room to check
- room_membership_for_user_map: Room membership for the user at the time of
- the `to_token` (<= `to_token`).
- to_token: The token to fetch rooms up to.
-
- Returns:
- The room membership for the user if they are allowed to subscribe to the
- room, else `None`.
- """
-
- # We can first check if they are already allowed to see the room based
- # on our previous work to assemble the `room_membership_for_user_map`.
- #
- # If they have had any membership in the room over time (up to the `to_token`),
- # let them subscribe and see what they can.
- existing_membership_for_user = room_membership_for_user_map.get(room_id)
- if existing_membership_for_user is not None:
- return existing_membership_for_user
-
- # TODO: Handle `world_readable` rooms
- return None
-
- # If the room is `world_readable`, it doesn't matter whether they can join,
- # everyone can see the room.
- # not_in_room_membership_for_user = _RoomMembershipForUser(
- # room_id=room_id,
- # event_id=None,
- # event_pos=None,
- # membership=None,
- # sender=None,
- # newly_joined=False,
- # newly_left=False,
- # is_dm=False,
- # )
- # room_state = await self.get_current_state_at(
- # room_id=room_id,
- # room_membership_for_user_at_to_token=not_in_room_membership_for_user,
- # state_filter=StateFilter.from_types(
- # [(EventTypes.RoomHistoryVisibility, "")]
- # ),
- # to_token=to_token,
- # )
-
- # visibility_event = room_state.get((EventTypes.RoomHistoryVisibility, ""))
- # if (
- # visibility_event is not None
- # and visibility_event.content.get("history_visibility")
- # == HistoryVisibility.WORLD_READABLE
- # ):
- # return not_in_room_membership_for_user
-
- # return None
-
- @trace
- async def _bulk_get_stripped_state_for_rooms_from_sync_room_map(
- self,
- room_ids: StrCollection,
- sync_room_map: Dict[str, _RoomMembershipForUser],
- ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]:
- """
- Fetch stripped state for a list of room IDs. Stripped state is only
- applicable to invite/knock rooms. Other rooms will have `None` as their
- stripped state.
-
- For invite rooms, we pull from `unsigned.invite_room_state`.
- For knock rooms, we pull from `unsigned.knock_room_state`.
-
- Args:
- room_ids: Room IDs to fetch stripped state for
- sync_room_map: Dictionary of room IDs along with membership
- information in the room at the time of `to_token`.
-
- Returns:
- Mapping from room_id to mapping of (type, state_key) to stripped state
- event.
- """
- room_id_to_stripped_state_map: Dict[
- str, Optional[StateMap[StrippedStateEvent]]
- ] = {}
-
- # Fetch what we haven't before
- room_ids_to_fetch = [
- room_id
- for room_id in room_ids
- if room_id not in room_id_to_stripped_state_map
- ]
-
- # Gather a list of event IDs we can grab stripped state from
- invite_or_knock_event_ids: List[str] = []
- for room_id in room_ids_to_fetch:
- if sync_room_map[room_id].membership in (
- Membership.INVITE,
- Membership.KNOCK,
- ):
- event_id = sync_room_map[room_id].event_id
- # If this is an invite/knock then there should be an event_id
- assert event_id is not None
- invite_or_knock_event_ids.append(event_id)
- else:
- room_id_to_stripped_state_map[room_id] = None
-
- invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids)
- for invite_or_knock_event in invite_or_knock_events.values():
- room_id = invite_or_knock_event.room_id
- membership = invite_or_knock_event.membership
-
- raw_stripped_state_events = None
- if membership == Membership.INVITE:
- invite_room_state = invite_or_knock_event.unsigned.get(
- "invite_room_state"
- )
- raw_stripped_state_events = invite_room_state
- elif membership == Membership.KNOCK:
- knock_room_state = invite_or_knock_event.unsigned.get(
- "knock_room_state"
- )
- raw_stripped_state_events = knock_room_state
- else:
- raise AssertionError(
- f"Unexpected membership {membership} (this is a problem with Synapse itself)"
- )
-
- stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None
- # Scrutinize unsigned things. `raw_stripped_state_events` should be a list
- # of stripped events
- if raw_stripped_state_events is not None:
- stripped_state_map = {}
- if isinstance(raw_stripped_state_events, list):
- for raw_stripped_event in raw_stripped_state_events:
- stripped_state_event = parse_stripped_state_event(
- raw_stripped_event
- )
- if stripped_state_event is not None:
- stripped_state_map[
- (
- stripped_state_event.type,
- stripped_state_event.state_key,
- )
- ] = stripped_state_event
-
- room_id_to_stripped_state_map[room_id] = stripped_state_map
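-
- # For example, a raw stripped event from `unsigned.invite_room_state`
- # looks roughly like `{"type": "m.room.create", "state_key": "",
- # "sender": "@user:example.com", "content": {...}}` and is keyed here as
- # `("m.room.create", "")`.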
-
- return room_id_to_stripped_state_map
-
- @trace
- async def _bulk_get_partial_current_state_content_for_rooms(
- self,
- content_type: Literal[
- # `content.type` from `EventTypes.Create`
- "room_type",
- # `content.algorithm` from `EventTypes.RoomEncryption`
- "room_encryption",
- ],
- room_ids: Set[str],
- sync_room_map: Dict[str, _RoomMembershipForUser],
- to_token: StreamToken,
- room_id_to_stripped_state_map: Dict[
- str, Optional[StateMap[StrippedStateEvent]]
- ],
- ) -> Mapping[str, Union[Optional[str], StateSentinel]]:
- """
- Get the given state event content for a list of rooms. First we check the
- current state of the room, then fall back to stripped state if available, then
- to historical state.
-
- Args:
- content_type: Which content to grab
- room_ids: Room IDs to fetch the given content field for.
- sync_room_map: Dictionary of room IDs along with membership
- information in the room at the time of `to_token`.
- to_token: We filter based on the state of the room at this token
- room_id_to_stripped_state_map: This does not need to be filled in before
- calling this function. Mapping from room_id to mapping of (type, state_key)
- to stripped state event. Modified in place when we fetch new rooms so we can
- save work next time this function is called.
-
- Returns:
- A mapping from room ID to the state event content if the room has
- the given state event (event_type, ""), otherwise `None`. Rooms unknown to
- this server will return `ROOM_UNKNOWN_SENTINEL`.
- """
- room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {}
-
- # As a bulk shortcut, use the current state if the server is participating in the
- # room (meaning we have current state). Ideally, for leave/ban rooms, we would
- # want the state at the time of the membership instead of current state to not
- # leak anything but we consider the create/encryption stripped state events to
- # not be a secret given they are often set at the start of the room and they are
- # normally handed out on invite/knock.
- #
- # Be mindful to only use this for non-sensitive details. For example, even
- # though the room name/avatar/topic are also stripped state, it seems a lot
- # more sensitive to leak their current state values.
- #
- # Since this function is cached, we need to make a mutable copy via
- # `dict(...)`.
- event_type = ""
- event_content_field = ""
- if content_type == "room_type":
- event_type = EventTypes.Create
- event_content_field = EventContentFields.ROOM_TYPE
- room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids))
- elif content_type == "room_encryption":
- event_type = EventTypes.RoomEncryption
- event_content_field = EventContentFields.ENCRYPTION_ALGORITHM
- room_id_to_content = dict(
- await self.store.bulk_get_room_encryption(room_ids)
- )
- else:
- assert_never(content_type)
-
- room_ids_with_results = [
- room_id
- for room_id, content_field in room_id_to_content.items()
- if content_field is not ROOM_UNKNOWN_SENTINEL
- ]
-
- # We might not have current room state for remote invite/knocks if we are
- # the first person on our server to see the room. The best we can do is look
- # in the optional stripped state from the invite/knock event.
- room_ids_without_results = room_ids.difference(
- chain(
- room_ids_with_results,
- [
- room_id
- for room_id, stripped_state_map in room_id_to_stripped_state_map.items()
- if stripped_state_map is not None
- ],
- )
- )
- room_id_to_stripped_state_map.update(
- await self._bulk_get_stripped_state_for_rooms_from_sync_room_map(
- room_ids_without_results, sync_room_map
- )
- )
-
- # Update our `room_id_to_content` map based on the stripped state
- # (applies to invite/knock rooms)
- room_ids_without_stripped_state: Set[str] = set()
- for room_id in room_ids_without_results:
- stripped_state_map = room_id_to_stripped_state_map.get(
- room_id, Sentinel.UNSET_SENTINEL
- )
- assert stripped_state_map is not Sentinel.UNSET_SENTINEL, (
- f"Stripped state left unset for room {room_id}. "
- + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` "
- + "with that room_id. (this is a problem with Synapse itself)"
- )
-
- # If there is some stripped state, we assume the remote server passed *all*
- # of the potential stripped state events for the room.
- if stripped_state_map is not None:
- create_stripped_event = stripped_state_map.get((EventTypes.Create, ""))
- stripped_event = stripped_state_map.get((event_type, ""))
- # Sanity check that we at-least have the create event
- if create_stripped_event is not None:
- if stripped_event is not None:
- room_id_to_content[room_id] = stripped_event.content.get(
- event_content_field
- )
- else:
- # Didn't see the state event we're looking for in the stripped
- # state so we can assume relevant content field is `None`.
- room_id_to_content[room_id] = None
- else:
- room_ids_without_stripped_state.add(room_id)
-
- # Last resort, we might not have current room state for rooms that the
- # server has left (no one local is in the room) but we can look at the
- # historical state.
- #
- # Update our `room_id_to_content` map based on the state at the time of
- # the membership event.
- for room_id in room_ids_without_stripped_state:
- # TODO: It would be nice to look this up in a bulk way (to avoid the N+1 queries)
- #
- # TODO: `get_state_at(...)` doesn't take into account the "current state".
- room_state = await self.storage_controllers.state.get_state_at(
- room_id=room_id,
- stream_position=to_token.copy_and_replace(
- StreamKeyType.ROOM,
- sync_room_map[room_id].event_pos.to_room_stream_token(),
- ),
- state_filter=StateFilter.from_types(
- [
- (EventTypes.Create, ""),
- (event_type, ""),
- ]
- ),
- # Partially-stated rooms should have all state events except for
- # remote membership events so we don't need to wait at all because
- # we only want the create event and some non-member event.
- await_full_state=False,
- )
- # We can use the create event as a canary to tell whether the server has
- # seen the room before
- create_event = room_state.get((EventTypes.Create, ""))
- state_event = room_state.get((event_type, ""))
-
- if create_event is None:
- # Skip for unknown rooms
- continue
-
- if state_event is not None:
- room_id_to_content[room_id] = state_event.content.get(
- event_content_field
- )
- else:
- # Didn't see the state event we're looking for in the historical
- # state so we can assume the relevant content field is `None`.
- room_id_to_content[room_id] = None
-
- return room_id_to_content
-
- @trace
- async def filter_rooms(
- self,
- user: UserID,
- sync_room_map: Dict[str, _RoomMembershipForUser],
- filters: SlidingSyncConfig.SlidingSyncList.Filters,
- to_token: StreamToken,
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Filter rooms based on the sync request.
-
- Args:
- user: User to filter rooms for
- sync_room_map: Dictionary of room IDs along with membership
- information in the room at the time of `to_token`.
- filters: Filters to apply
- to_token: We filter based on the state of the room at this token
-
- Returns:
- A filtered dictionary of room IDs along with membership information in the
- room at the time of `to_token`.
- """
- room_id_to_stripped_state_map: Dict[
- str, Optional[StateMap[StrippedStateEvent]]
- ] = {}
-
- filtered_room_id_set = set(sync_room_map.keys())
-
- # Filter for Direct-Message (DM) rooms
- if filters.is_dm is not None:
- with start_active_span("filters.is_dm"):
- if filters.is_dm:
- # Only DM rooms please
- filtered_room_id_set = {
- room_id
- for room_id in filtered_room_id_set
- if sync_room_map[room_id].is_dm
- }
- else:
- # Only non-DM rooms please
- filtered_room_id_set = {
- room_id
- for room_id in filtered_room_id_set
- if not sync_room_map[room_id].is_dm
- }
-
- if filters.spaces is not None:
- with start_active_span("filters.spaces"):
- raise NotImplementedError()
-
- # Filter for encrypted rooms
- if filters.is_encrypted is not None:
- with start_active_span("filters.is_encrypted"):
- room_id_to_encryption = (
- await self._bulk_get_partial_current_state_content_for_rooms(
- content_type="room_encryption",
- room_ids=filtered_room_id_set,
- to_token=to_token,
- sync_room_map=sync_room_map,
- room_id_to_stripped_state_map=room_id_to_stripped_state_map,
- )
- )
-
- # Make a copy so we don't run into an error (`Set changed size during
- # iteration`) when we filter out and remove items
- for room_id in filtered_room_id_set.copy():
- encryption = room_id_to_encryption.get(
- room_id, ROOM_UNKNOWN_SENTINEL
- )
-
- # Just remove rooms if we can't determine their encryption status
- if encryption is ROOM_UNKNOWN_SENTINEL:
- filtered_room_id_set.remove(room_id)
- continue
-
- # If we're looking for encrypted rooms, filter out rooms that are not
- # encrypted and vice versa
- is_encrypted = encryption is not None
- if (filters.is_encrypted and not is_encrypted) or (
- not filters.is_encrypted and is_encrypted
- ):
- filtered_room_id_set.remove(room_id)
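-
- # For example (illustrative values): an `m.room.encryption` event with
- # `content.algorithm = "m.megolm.v1.aes-sha2"` gives `is_encrypted=True`,
- # while a room without the event has `encryption = None`.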
-
- # Filter for rooms that the user has been invited to
- if filters.is_invite is not None:
- with start_active_span("filters.is_invite"):
- # Make a copy so we don't run into an error (`Set changed size during
- # iteration`) when we filter out and remove items
- for room_id in filtered_room_id_set.copy():
- room_for_user = sync_room_map[room_id]
- # If we're looking for invite rooms, filter out rooms that the user is
- # not invited to and vice versa
- if (
- filters.is_invite
- and room_for_user.membership != Membership.INVITE
- ) or (
- not filters.is_invite
- and room_for_user.membership == Membership.INVITE
- ):
- filtered_room_id_set.remove(room_id)
-
- # Filter by room type (space vs room, etc). A room must match one of the types
- # provided in the list. `None` is a valid type for rooms which do not have a
- # room type.
- if filters.room_types is not None or filters.not_room_types is not None:
- with start_active_span("filters.room_types"):
- room_id_to_type = (
- await self._bulk_get_partial_current_state_content_for_rooms(
- content_type="room_type",
- room_ids=filtered_room_id_set,
- to_token=to_token,
- sync_room_map=sync_room_map,
- room_id_to_stripped_state_map=room_id_to_stripped_state_map,
- )
- )
-
- # Make a copy so we don't run into an error (`Set changed size during
- # iteration`) when we filter out and remove items
- for room_id in filtered_room_id_set.copy():
- room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL)
-
- # Just remove rooms if we can't determine their type
- if room_type is ROOM_UNKNOWN_SENTINEL:
- filtered_room_id_set.remove(room_id)
- continue
-
- if (
- filters.room_types is not None
- and room_type not in filters.room_types
- ):
- filtered_room_id_set.remove(room_id)
-
- if (
- filters.not_room_types is not None
- and room_type in filters.not_room_types
- ):
- filtered_room_id_set.remove(room_id)
-
- if filters.room_name_like is not None:
- with start_active_span("filters.room_name_like"):
- # TODO: The room name is a bit more sensitive to leak than the
- # create/encryption event. Maybe we should consider a better way to fetch
- # historical state before implementing this.
- #
- # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms(
- # content_type="room_name",
- # room_ids=filtered_room_id_set,
- # to_token=to_token,
- # sync_room_map=sync_room_map,
- # room_id_to_stripped_state_map=room_id_to_stripped_state_map,
- # )
- raise NotImplementedError()
-
- if filters.tags is not None or filters.not_tags is not None:
- with start_active_span("filters.tags"):
- raise NotImplementedError()
-
- # Assemble a new sync room map but only with the `filtered_room_id_set`
- return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
-
- @trace
- async def sort_rooms(
- self,
- sync_room_map: Dict[str, _RoomMembershipForUser],
- to_token: StreamToken,
- ) -> List[_RoomMembershipForUser]:
- """
- Sort by `stream_ordering` of the last event that the user should see in the
- room. `stream_ordering` is unique so we get a stable sort.
-
- Args:
- sync_room_map: Dictionary of room IDs to sort along with membership
- information in the room at the time of `to_token`.
- to_token: We sort based on the events in the room at this token (<= `to_token`)
-
- Returns:
- A sorted list of room IDs by `stream_ordering` along with membership information.
- """
-
- # Assemble a map of room ID to the `stream_ordering` of the last activity that the
- # user should see in the room (<= `to_token`)
- last_activity_in_room_map: Dict[str, int] = {}
-
- for room_id, room_for_user in sync_room_map.items():
- if room_for_user.membership != Membership.JOIN:
- # If the user has left/been invited/knocked/been banned from a
- # room, they shouldn't see anything past that point.
- #
- # FIXME: It's possible that people should see beyond this point
- # in invited/knocked cases if for example the room has
- # `invite`/`world_readable` history visibility, see
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
- last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
-
- # For fully-joined rooms, we find the latest activity at/before the
- # `to_token`.
- joined_room_positions = (
- await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
- [
- room_id
- for room_id, room_for_user in sync_room_map.items()
- if room_for_user.membership == Membership.JOIN
- ],
- to_token.room_key,
- )
- )
-
- last_activity_in_room_map.update(joined_room_positions)
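-
- # For example (illustrative values): with
- # `last_activity_in_room_map = {"!a": 10, "!b": 25}`, the sort below
- # returns the membership info for `!b` before `!a` (most recent first).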
-
- return sorted(
- sync_room_map.values(),
- # Sort by the last activity (stream_ordering) in the room
- key=lambda room_info: last_activity_in_room_map[room_info.room_id],
- # We want descending order
- reverse=True,
- )
-
- @trace
- async def get_current_state_ids_at(
- self,
- room_id: str,
- room_membership_for_user_at_to_token: _RoomMembershipForUser,
- state_filter: StateFilter,
- to_token: StreamToken,
- ) -> StateMap[str]:
- """
- Get current state IDs for the user in the room according to their membership. This
- will be the current state at the time of their LEAVE/BAN, otherwise it will be the
- current state <= to_token.
-
- Args:
- room_id: The room ID to fetch data for
- room_membership_for_user_at_to_token: Membership information for the user
- in the room at the time of `to_token`.
- to_token: The point in the stream to sync up to.
- """
- state_ids: StateMap[str]
- # People shouldn't see past their leave/ban event
- if room_membership_for_user_at_to_token.membership in (
- Membership.LEAVE,
- Membership.BAN,
- ):
- # TODO: `get_state_ids_at(...)` doesn't take into account the "current
- # state". Maybe we need to use
- # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the
- # current state at the time."
- state_ids = await self.storage_controllers.state.get_state_ids_at(
- room_id,
- stream_position=to_token.copy_and_replace(
- StreamKeyType.ROOM,
- room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
- ),
- state_filter=state_filter,
- # Partially-stated rooms should have all state events except for
- # remote membership events. Since we've already excluded
- # partially-stated rooms unless `required_state` only has
- # `["m.room.member", "$LAZY"]` for membership, we should be able to
- # retrieve everything requested. When we're lazy-loading, if there
- # are some remote senders in the timeline, we should also have their
- # membership event because we had to auth that timeline event. Plus
- # we don't want to block the whole sync waiting for this one room.
- await_full_state=False,
- )
- # Otherwise, we can get the latest current state in the room
- else:
- state_ids = await self.storage_controllers.state.get_current_state_ids(
- room_id,
- state_filter,
- # Partially-stated rooms should have all state events except for
- # remote membership events. Since we've already excluded
- # partially-stated rooms unless `required_state` only has
- # `["m.room.member", "$LAZY"]` for membership, we should be able to
- # retrieve everything requested. When we're lazy-loading, if there
- # are some remote senders in the timeline, we should also have their
- # membership event because we had to auth that timeline event. Plus
- # we don't want to block the whole sync waiting for this one room.
- await_full_state=False,
- )
- # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
-
- return state_ids
-
- @trace
- async def get_current_state_at(
- self,
- room_id: str,
- room_membership_for_user_at_to_token: _RoomMembershipForUser,
- state_filter: StateFilter,
- to_token: StreamToken,
- ) -> StateMap[EventBase]:
- """
- Get current state for the user in the room according to their membership. This
- will be the current state at the time of their LEAVE/BAN, otherwise it will be the
- current state <= to_token.
-
- Args:
- room_id: The room ID to fetch data for
- room_membership_for_user_at_to_token: Membership information for the user
- in the room at the time of `to_token`.
- to_token: The point in the stream to sync up to.
- """
- state_ids = await self.get_current_state_ids_at(
- room_id=room_id,
- room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
- state_filter=state_filter,
- to_token=to_token,
- )
-
- event_map = await self.store.get_events(list(state_ids.values()))
-
- state_map = {}
- for key, event_id in state_ids.items():
- event = event_map.get(event_id)
- if event:
- state_map[key] = event
-
- return state_map
-
- async def get_room_sync_data(
- self,
- sync_config: SlidingSyncConfig,
- room_id: str,
- room_sync_config: RoomSyncConfig,
- room_membership_for_user_at_to_token: _RoomMembershipForUser,
- from_token: Optional[SlidingSyncStreamToken],
- to_token: StreamToken,
- ) -> SlidingSyncResult.RoomResult:
- """
- Fetch room data for the sync response.
-
- We fetch data according to the token range (> `from_token` and <= `to_token`).
-
- Args:
- sync_config: Sync configuration, including the user to fetch data for
- room_id: The room ID to fetch data for
- room_sync_config: Config for what data we should fetch for a room in the
- sync response.
- room_membership_for_user_at_to_token: Membership information for the user
- in the room at the time of `to_token`.
- from_token: The point in the stream to sync from.
- to_token: The point in the stream to sync up to.
- """
- user = sync_config.user
-
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "membership",
- room_membership_for_user_at_to_token.membership,
- )
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "timeline_limit",
- room_sync_config.timeline_limit,
- )
-
- # Determine whether we should limit the timeline to the token range.
- #
- # We should return historical messages (before token range) in the
- # following cases because we want clients to be able to show a basic
- # screen of information:
- #
- # - Initial sync (because no `from_token` to limit us anyway)
- # - When users `newly_joined`
- # - For an incremental sync where we haven't sent it down this
- # connection before
- #
- # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
- from_bound = None
- initial = True
- if from_token and not room_membership_for_user_at_to_token.newly_joined:
- room_status = await self.connection_store.have_sent_room(
- sync_config=sync_config,
- connection_token=from_token.connection_position,
- room_id=room_id,
- )
- if room_status.status == HaveSentRoomFlag.LIVE:
- from_bound = from_token.stream_token.room_key
- initial = False
- elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
- assert room_status.last_token is not None
- from_bound = room_status.last_token
- initial = False
- elif room_status.status == HaveSentRoomFlag.NEVER:
- from_bound = None
- initial = True
- else:
- assert_never(room_status.status)
-
- log_kv({"sliding_sync.room_status": room_status})
-
- log_kv({"sliding_sync.from_bound": from_bound, "sliding_sync.initial": initial})
-
- # Assemble the list of timeline events
- #
- # FIXME: It would be nice to make the `rooms` response more uniform regardless of
- # membership. Currently, we have to make all of these optional because
- # `invite`/`knock` rooms only have `stripped_state`. See
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
- timeline_events: List[EventBase] = []
- bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
- limited: Optional[bool] = None
- prev_batch_token: Optional[StreamToken] = None
- num_live: Optional[int] = None
- if (
- room_sync_config.timeline_limit > 0
- # No timeline for invite/knock rooms (just `stripped_state`)
- and room_membership_for_user_at_to_token.membership
- not in (Membership.INVITE, Membership.KNOCK)
- ):
- limited = False
- # We want to start off using the `to_token` (vs `from_token`) because we look
- # backwards from the `to_token` up to the `timeline_limit` and we might not
- # reach the `from_token` before we hit the limit. We will update the room stream
- # position once we've fetched the events to point to the earliest event fetched.
- prev_batch_token = to_token
-
- # We're going to paginate backwards from the `to_token`
- to_bound = to_token.room_key
- # People shouldn't see past their leave/ban event
- if room_membership_for_user_at_to_token.membership in (
- Membership.LEAVE,
- Membership.BAN,
- ):
- to_bound = (
- room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
- )
-
- # For initial `/sync` (and other historical scenarios mentioned above), we
- # want to view a historical section of the timeline, so we fetch events by
- # `topological_ordering` (best representation of the room DAG as others were
- # seeing it at the time). This also aligns with the order that `/messages`
- # returns events in.
- #
- # For incremental `/sync`, we want to get all updates for rooms since
- # the last `/sync` (regardless if those updates arrived late or happened
- # a while ago in the past), so we fetch events by `stream_ordering` (in the
- # order they were received by the server).
- #
- # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
- #
- # FIXME: Using workaround for mypy,
- # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and
- # https://github.com/python/mypy/issues/17479
- paginate_room_events_by_topological_ordering: PaginateFunction = (
- self.store.paginate_room_events_by_topological_ordering
- )
- paginate_room_events_by_stream_ordering: PaginateFunction = (
- self.store.paginate_room_events_by_stream_ordering
- )
- pagination_method: PaginateFunction = (
- # Use `topological_ordering` for historical events
- paginate_room_events_by_topological_ordering
- if from_bound is None
- # Use `stream_ordering` for updates
- else paginate_room_events_by_stream_ordering
- )
- timeline_events, new_room_key = await pagination_method(
- room_id=room_id,
- # The bounds are reversed so we can paginate backwards
- # (from newer to older events) starting at to_bound.
- # This ensures we fill the `limit` with the newest events first.
- from_key=to_bound,
- to_key=from_bound,
- direction=Direction.BACKWARDS,
- # We add one so we can determine if there are enough events to saturate
- # the limit or not (see `limited`)
- limit=room_sync_config.timeline_limit + 1,
- )
-
- # We want to return the events in ascending order (the last event is the
- # most recent).
- timeline_events.reverse()
-
- # Determine our `limited` status based on the timeline. We do this before
- # filtering the events so we can accurately determine if there is more to
- # paginate even if we filter out some/all events.
- if len(timeline_events) > room_sync_config.timeline_limit:
- limited = True
- # Get rid of that extra "+ 1" event because we only used it to determine
- # if we hit the limit or not
- timeline_events = timeline_events[-room_sync_config.timeline_limit :]
- assert timeline_events[0].internal_metadata.stream_ordering
- new_room_key = RoomStreamToken(
- stream=timeline_events[0].internal_metadata.stream_ordering - 1
- )
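-
- # For example (illustrative values): with `timeline_limit=10` and 11
- # events fetched, we keep the 10 newest and point `new_room_key` just
- # before the oldest kept event so the next backwards page resumes there.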
-
- # Make sure we don't expose any events that the client shouldn't see
- timeline_events = await filter_events_for_client(
- self.storage_controllers,
- user.to_string(),
- timeline_events,
- is_peeking=room_membership_for_user_at_to_token.membership
- != Membership.JOIN,
- filter_send_to_client=True,
- )
- # TODO: Filter out `EventTypes.CallInvite` in public rooms,
- # see https://github.com/element-hq/synapse/issues/17359
-
- # TODO: Handle timeline gaps (`get_timeline_gaps()`)
-
- # Determine how many "live" events we have (events within the given token range).
- #
- # This is mostly useful to determine whether a given @mention event should
- # make a noise or not. Clients cannot rely solely on the absence of
- # `initial: true` to determine live events because a room outside the
- # sliding window that bumps into the window due to an @mention will have
- # `initial: true` yet contain a single live event (with potentially other
- # old events in the timeline).
- num_live = 0
- if from_token is not None:
- for timeline_event in reversed(timeline_events):
- # These fields should be present for all persisted events
- assert timeline_event.internal_metadata.stream_ordering is not None
- assert timeline_event.internal_metadata.instance_name is not None
-
- persisted_position = PersistedEventPosition(
- instance_name=timeline_event.internal_metadata.instance_name,
- stream=timeline_event.internal_metadata.stream_ordering,
- )
- if persisted_position.persisted_after(
- from_token.stream_token.room_key
- ):
- num_live += 1
- else:
- # Since we're iterating over the timeline events in
- # reverse-chronological order, we can break once we hit an event
- # that's not live. In the future, we could potentially optimize
- # this more with a binary search (bisect).
- break
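-
- # For example (illustrative values): if `from_token` points at stream
- # position 100 and the timeline holds events at positions 98, 99, 101,
- # and 102, then `num_live == 2` (only 101 and 102 are after the token).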
-
- # If the timeline is `limited=True`, the client does not have all events
- # necessary to calculate aggregations themselves.
- if limited:
- bundled_aggregations = (
- await self.relations_handler.get_bundled_aggregations(
- timeline_events, user.to_string()
- )
- )
-
- # Update the `prev_batch_token` to point to the position that allows us to
- # keep paginating backwards from the oldest event we return in the timeline.
- prev_batch_token = prev_batch_token.copy_and_replace(
- StreamKeyType.ROOM, new_room_key
- )
-
- # Figure out any stripped state events for invite/knocks. This allows the
- # potential joiner to identify the room.
- stripped_state: List[JsonDict] = []
- if room_membership_for_user_at_to_token.membership in (
- Membership.INVITE,
- Membership.KNOCK,
- ):
- # This should never happen. If someone is invited to or has knocked on a
- # room, then there should be an event for it.
- assert room_membership_for_user_at_to_token.event_id is not None
-
- invite_or_knock_event = await self.store.get_event(
- room_membership_for_user_at_to_token.event_id
- )
-
- stripped_state = []
- if invite_or_knock_event.membership == Membership.INVITE:
- stripped_state.extend(
- invite_or_knock_event.unsigned.get("invite_room_state", [])
- )
- elif invite_or_knock_event.membership == Membership.KNOCK:
- stripped_state.extend(
- invite_or_knock_event.unsigned.get("knock_room_state", [])
- )
-
- stripped_state.append(strip_event(invite_or_knock_event))
-
- # TODO: Handle state resets. For example, if we see
- # `room_membership_for_user_at_to_token.event_id=None and
- # room_membership_for_user_at_to_token.membership is not None`, we should
- # indicate to the client that a state reset happened. Perhaps we should indicate
- # this by setting `initial: True` and empty `required_state`.
-
- # Check whether the room has a name set
- name_state_ids = await self.get_current_state_ids_at(
- room_id=room_id,
- room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
- state_filter=StateFilter.from_types([(EventTypes.Name, "")]),
- to_token=to_token,
- )
- name_event_id = name_state_ids.get((EventTypes.Name, ""))
-
- room_membership_summary: Mapping[str, MemberSummary]
- empty_membership_summary = MemberSummary([], 0)
- if room_membership_for_user_at_to_token.membership in (
- Membership.LEAVE,
- Membership.BAN,
- ):
- # TODO: Figure out how to get the membership summary for left/banned rooms
- room_membership_summary = {}
- else:
- room_membership_summary = await self.store.get_room_summary(room_id)
- # TODO: Reverse/rewind back to the `to_token`
-
- # `heroes` are required if the room name is not set.
- #
- # Note: When you're the first one on your server to be invited to a new room
- # over federation, we only have access to some stripped state in
- # `event.unsigned.invite_room_state` which currently doesn't include `heroes`,
- # see https://github.com/matrix-org/matrix-spec/issues/380. This means that
- # clients won't be able to calculate the room name when necessary. This is just
- # a pitfall we have to deal with until that spec issue is resolved.
- hero_user_ids: List[str] = []
- # TODO: Should we also check for `EventTypes.CanonicalAlias`
- # (`m.room.canonical_alias`) as a fallback for the room name? see
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153
- if name_event_id is None:
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me=user.to_string()
- )
-
- # Fetch the `required_state` for the room
- #
- # No `required_state` for invite/knock rooms (just `stripped_state`)
- #
- # FIXME: It would be nice to make the `rooms` response more uniform regardless
- # of membership. Currently, we have to make this optional because
- # `invite`/`knock` rooms only have `stripped_state`. See
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
- #
- # Calculate the `StateFilter` based on the `required_state` for the room
- required_state_filter = StateFilter.none()
- if room_membership_for_user_at_to_token.membership not in (
- Membership.INVITE,
- Membership.KNOCK,
- ):
- # If we have a double wildcard ("*", "*") in the `required_state`, we need
- # to fetch all state for the room
- #
- # Note: MSC3575 describes different behavior to how we're handling things
- # here but since it's not wrong to return more state than requested
- # (`required_state` is just the minimum requested), it doesn't matter if we
- # include more than the client wanted. This complexity is also under scrutiny,
- # see
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050
- #
- # > One unique exception is when you request all state events via ["*", "*"]. When used,
- # > all state events are returned by default, and additional entries FILTER OUT the returned set
- # > of state events. These additional entries cannot use '*' themselves.
- # > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member
- # > event _except_ for @alice:example.com, and include every other state event.
- # > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not
- # > required as it would have been returned anyway.
- # >
- # > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575)
- if StateValues.WILDCARD in room_sync_config.required_state_map.get(
- StateValues.WILDCARD, set()
- ):
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard",
- True,
- )
- required_state_filter = StateFilter.all()
- # TODO: `StateFilter` currently doesn't support wildcard event types. We're
- # currently working around this by returning all state to the client but it
- # would be nice to fetch less from the database and return just what the
- # client wanted.
- elif (
- room_sync_config.required_state_map.get(StateValues.WILDCARD)
- is not None
- ):
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard_event_type",
- True,
- )
- required_state_filter = StateFilter.all()
- else:
- required_state_types: List[Tuple[str, Optional[str]]] = []
- for (
- state_type,
- state_key_set,
- ) in room_sync_config.required_state_map.items():
- num_wild_state_keys = 0
- lazy_load_room_members = False
- num_others = 0
- for state_key in state_key_set:
- if state_key == StateValues.WILDCARD:
- num_wild_state_keys += 1
- # `None` is a wildcard in the `StateFilter`
- required_state_types.append((state_type, None))
- # We need to fetch all relevant people when we're lazy-loading membership
- elif (
- state_type == EventTypes.Member
- and state_key == StateValues.LAZY
- ):
- lazy_load_room_members = True
- # Everyone in the timeline is relevant
- timeline_membership: Set[str] = set()
- if timeline_events is not None:
- for timeline_event in timeline_events:
- timeline_membership.add(timeline_event.sender)
-
- for user_id in timeline_membership:
- required_state_types.append(
- (EventTypes.Member, user_id)
- )
-
- # FIXME: We probably also care about invite, ban, kick, targets, etc
- # but the spec only mentions "senders".
- elif state_key == StateValues.ME:
- num_others += 1
- required_state_types.append((state_type, user.to_string()))
- else:
- num_others += 1
- required_state_types.append((state_type, state_key))
-
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX
- + "required_state_wildcard_state_key_count",
- num_wild_state_keys,
- )
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "required_state_lazy",
- lazy_load_room_members,
- )
- set_tag(
- SynapseTags.FUNC_ARG_PREFIX + "required_state_other_count",
- num_others,
- )
-
- required_state_filter = StateFilter.from_types(required_state_types)
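-
- # For example (illustrative request): a `required_state` of
- # `[["m.room.topic", ""], ["m.room.member", "$ME"]]` becomes
- # `[("m.room.topic", ""), ("m.room.member", "@user:example.com")]`, while
- # `["m.room.member", "$LAZY"]` expands to one entry per sender seen in
- # the timeline.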
-
- # We need this base set of info for the response so let's just fetch it along
- # with the `required_state` for the room
- meta_room_state = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + [
- (EventTypes.Member, hero_user_id) for hero_user_id in hero_user_ids
- ]
- state_filter = StateFilter.all()
- if required_state_filter != StateFilter.all():
- state_filter = StateFilter(
- types=StateFilter.from_types(
- chain(meta_room_state, required_state_filter.to_types())
- ).types,
- include_others=required_state_filter.include_others,
- )
-
- # We can return all of the state that was requested if this was the first
- # time we've sent the room down this connection.
- room_state: StateMap[EventBase] = {}
- if initial:
- room_state = await self.get_current_state_at(
- room_id=room_id,
- room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
- state_filter=state_filter,
- to_token=to_token,
- )
- else:
- assert from_bound is not None
-
- # TODO: Limit the number of state events we're about to send down
- # for the room; if it's too many, we should change this to
- # `initial=True`?
- deltas = await self.store.get_current_state_deltas_for_room(
- room_id=room_id,
- from_token=from_bound,
- to_token=to_token.room_key,
- )
- # TODO: Filter room state before fetching events
- # TODO: Handle state resets where event_id is None
- events = await self.store.get_events(
- [d.event_id for d in deltas if d.event_id]
- )
- room_state = {(s.type, s.state_key): s for s in events.values()}
-
- required_room_state: StateMap[EventBase] = {}
- if required_state_filter != StateFilter.none():
- required_room_state = required_state_filter.filter_state(room_state)
-
- # Find the room name and avatar from the state
- room_name: Optional[str] = None
- # TODO: Should we also check for `EventTypes.CanonicalAlias`
- # (`m.room.canonical_alias`) as a fallback for the room name? see
- # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153
- name_event = room_state.get((EventTypes.Name, ""))
- if name_event is not None:
- room_name = name_event.content.get("name")
-
- room_avatar: Optional[str] = None
- avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
- if avatar_event is not None:
- room_avatar = avatar_event.content.get("url")
-
- # Assemble heroes: extract the info from the state we just fetched
- heroes: List[SlidingSyncResult.RoomResult.StrippedHero] = []
- for hero_user_id in hero_user_ids:
- member_event = room_state.get((EventTypes.Member, hero_user_id))
- if member_event is not None:
- heroes.append(
- SlidingSyncResult.RoomResult.StrippedHero(
- user_id=hero_user_id,
- display_name=member_event.content.get("displayname"),
- avatar_url=member_event.content.get("avatar_url"),
- )
- )
-
- # Figure out the last bump event in the room
- last_bump_event_result = (
- await self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES
- )
- )
-
- # By default, just choose the membership event position
- bump_stamp = room_membership_for_user_at_to_token.event_pos.stream
- # But if we found a bump event, use that instead
- if last_bump_event_result is not None:
- _, new_bump_event_pos = last_bump_event_result
-
- # If we've just joined a remote room, then the last bump event may
- # have been backfilled (and so have a negative stream ordering).
- # These negative stream orderings can't sensibly be compared, so
- # instead we use the membership event position.
- if new_bump_event_pos.stream > 0:
- bump_stamp = new_bump_event_pos.stream
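-
- # For example (illustrative scenario): a freshly-joined remote room may only
- # have backfilled history (negative stream orderings), in which case the
- # membership event position chosen above is the only sensible `bump_stamp`.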
-
- set_tag(SynapseTags.RESULT_PREFIX + "initial", initial)
-
- return SlidingSyncResult.RoomResult(
- name=room_name,
- avatar=room_avatar,
- heroes=heroes,
- is_dm=room_membership_for_user_at_to_token.is_dm,
- initial=initial,
- required_state=list(required_room_state.values()),
- timeline_events=timeline_events,
- bundled_aggregations=bundled_aggregations,
- stripped_state=stripped_state,
- prev_batch=prev_batch_token,
- limited=limited,
- num_live=num_live,
- bump_stamp=bump_stamp,
- joined_count=room_membership_summary.get(
- Membership.JOIN, empty_membership_summary
- ).count,
- invited_count=room_membership_summary.get(
- Membership.INVITE, empty_membership_summary
- ).count,
- # TODO: These are just dummy values. We could potentially just remove these
- # since notifications can only really be done correctly on the client anyway
- # (encrypted rooms).
- notification_count=0,
- highlight_count=0,
- )
-
- @trace
- async def get_extensions_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> SlidingSyncResult.Extensions:
- """Handle extension requests.
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
- actual_room_ids: The actual room IDs in the Sliding Sync response.
- actual_room_response_map: A map of room ID to room results in the
- Sliding Sync response.
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
-
- if sync_config.extensions is None:
- return SlidingSyncResult.Extensions()
-
- to_device_response = None
- if sync_config.extensions.to_device is not None:
- to_device_response = await self.get_to_device_extension_response(
- sync_config=sync_config,
- to_device_request=sync_config.extensions.to_device,
- to_token=to_token,
- )
-
- e2ee_response = None
- if sync_config.extensions.e2ee is not None:
- e2ee_response = await self.get_e2ee_extension_response(
- sync_config=sync_config,
- e2ee_request=sync_config.extensions.e2ee,
- to_token=to_token,
- from_token=from_token,
- )
-
- account_data_response = None
- if sync_config.extensions.account_data is not None:
- account_data_response = await self.get_account_data_extension_response(
- sync_config=sync_config,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- account_data_request=sync_config.extensions.account_data,
- to_token=to_token,
- from_token=from_token,
- )
-
- receipts_response = None
- if sync_config.extensions.receipts is not None:
- receipts_response = await self.get_receipts_extension_response(
- sync_config=sync_config,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- actual_room_response_map=actual_room_response_map,
- receipts_request=sync_config.extensions.receipts,
- to_token=to_token,
- from_token=from_token,
- )
-
- typing_response = None
- if sync_config.extensions.typing is not None:
- typing_response = await self.get_typing_extension_response(
- sync_config=sync_config,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- actual_room_response_map=actual_room_response_map,
- typing_request=sync_config.extensions.typing,
- to_token=to_token,
- from_token=from_token,
- )
-
- return SlidingSyncResult.Extensions(
- to_device=to_device_response,
- e2ee=e2ee_response,
- account_data=account_data_response,
- receipts=receipts_response,
- typing=typing_response,
- )
-
- def find_relevant_room_ids_for_extension(
- self,
- requested_lists: Optional[List[str]],
- requested_room_ids: Optional[List[str]],
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- ) -> Set[str]:
- """
- Handle the reserved `lists`/`rooms` keys for extensions. Extensions should only
- return results for rooms in the Sliding Sync response. This matches up the
- requested rooms/lists with the actual lists/rooms in the Sliding Sync response.
-
- {"lists": []} // Do not process any lists.
- {"lists": ["rooms", "dms"]} // Process only a subset of lists.
- {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.)
-
- {"rooms": []} // Do not process any specific rooms.
- {"rooms": ["!a:b", "!c:d"]} // Process only a subset of room subscriptions.
- {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.)
-
- Args:
- requested_lists: The `lists` from the extension request.
- requested_room_ids: The `rooms` from the extension request.
- actual_lists: The actual lists from the Sliding Sync response.
- actual_room_ids: The actual room subscriptions from the Sliding Sync request.
- """
-
-        # We only want to include results for rooms that are already in the Sliding
-        # Sync response AND that were requested in the extension request.
- relevant_room_ids: Set[str] = set()
-
-        # See which rooms from the room subscriptions we should process
- if requested_room_ids is not None:
- for room_id in requested_room_ids:
- # A wildcard means we process all rooms from the room subscriptions
- if room_id == "*":
- relevant_room_ids.update(actual_room_ids)
- break
-
- if room_id in actual_room_ids:
- relevant_room_ids.add(room_id)
-
-        # See which rooms from the sliding window lists we should process
- if requested_lists is not None:
- for list_key in requested_lists:
-                # Declare the type up front because we reuse the variable name in multiple places
- actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
-
- # A wildcard means we process rooms from all lists
- if list_key == "*":
- for actual_list in actual_lists.values():
- # We only expect a single SYNC operation for any list
- assert len(actual_list.ops) == 1
- sync_op = actual_list.ops[0]
- assert sync_op.op == OperationType.SYNC
-
- relevant_room_ids.update(sync_op.room_ids)
-
- break
-
- actual_list = actual_lists.get(list_key)
- if actual_list is not None:
- # We only expect a single SYNC operation for any list
- assert len(actual_list.ops) == 1
- sync_op = actual_list.ops[0]
- assert sync_op.op == OperationType.SYNC
-
- relevant_room_ids.update(sync_op.room_ids)
-
- return relevant_room_ids
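For illustration, a minimal standalone sketch of the wildcard matching described in the docstring above, using plain sets in place of the real `SlidingSyncResult.SlidingWindowList` objects (all names here are hypothetical stand-ins):

from typing import Dict, List, Optional, Set

def resolve_relevant_rooms(
    requested_lists: Optional[List[str]],
    requested_room_ids: Optional[List[str]],
    actual_lists: Dict[str, Set[str]],  # list key -> room IDs in that list
    actual_room_ids: Set[str],
) -> Set[str]:
    relevant: Set[str] = set()
    if requested_room_ids is not None:
        for room_id in requested_room_ids:
            if room_id == "*":
                # Wildcard: take every room subscription in the response
                relevant.update(actual_room_ids)
                break
            if room_id in actual_room_ids:
                relevant.add(room_id)
    if requested_lists is not None:
        for list_key in requested_lists:
            if list_key == "*":
                # Wildcard: take the rooms from every list in the response
                for rooms in actual_lists.values():
                    relevant.update(rooms)
                break
            relevant.update(actual_lists.get(list_key, set()))
    return relevant

# {"lists": ["*"], "rooms": ["!c:d"]} against one list and two subscriptions
assert resolve_relevant_rooms(
    ["*"], ["!c:d"], {"rooms": {"!a:b"}}, {"!c:d", "!e:f"}
) == {"!a:b", "!c:d"}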
-
- @trace
- async def get_to_device_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
- to_token: StreamToken,
- ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
- """Handle to-device extension (MSC3885)
-
- Args:
- sync_config: Sync configuration
- to_device_request: The to-device extension from the request
- to_token: The point in the stream to sync up to.
- """
- user_id = sync_config.user.to_string()
- device_id = sync_config.requester.device_id
-
- # Skip if the extension is not enabled
- if not to_device_request.enabled:
- return None
-
- # Check that this request has a valid device ID (not all requests have
- # to belong to a device, and so device_id is None)
- if device_id is None:
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=f"{to_token.to_device_key}",
- events=[],
- )
-
- since_stream_id = 0
- if to_device_request.since is not None:
- # We've already validated this is an int.
- since_stream_id = int(to_device_request.since)
-
- if to_token.to_device_key < since_stream_id:
- # The since token is ahead of our current token, so we return an
- # empty response.
- logger.warning(
- "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
- since_stream_id,
- to_token.to_device_key,
- )
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=to_device_request.since,
- events=[],
- )
-
- # Delete everything before the given since token, as we know the
- # device must have received them.
- deleted = await self.store.delete_messages_for_device(
- user_id=user_id,
- device_id=device_id,
- up_to_stream_id=since_stream_id,
- )
-
- logger.debug(
- "Deleted %d to-device messages up to %d for %s",
- deleted,
- since_stream_id,
- user_id,
- )
-
- messages, stream_id = await self.store.get_messages_for_device(
- user_id=user_id,
- device_id=device_id,
- from_stream_id=since_stream_id,
- to_stream_id=to_token.to_device_key,
- limit=min(to_device_request.limit, 100), # Limit to at most 100 events
- )
-
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=f"{stream_id}",
- events=messages,
- )
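A rough sketch of the client-side contract implied by the since-token handling above (hedged; the field names follow the extension shape used here, while `send_request` is a hypothetical transport helper): the client echoes the previous `next_batch` as `since`, which both advances the window and lets the server delete everything up to that point.

from typing import Callable, Dict, Optional

def sync_to_device_once(
    send_request: Callable[[Dict], Dict], since: Optional[str]
) -> str:
    """One iteration of a hypothetical client loop for the to-device extension."""
    body: Dict = {"extensions": {"to_device": {"enabled": True, "limit": 100}}}
    if since is not None:
        # Echoing the previous `next_batch` as `since` both advances the
        # window and acknowledges (allowing deletion of) everything before it.
        body["extensions"]["to_device"]["since"] = since
    response = send_request(body)  # assumed to POST to the sliding sync endpoint
    to_device = response["extensions"]["to_device"]
    for _event in to_device["events"]:
        pass  # process each to-device message here
    return to_device["next_batch"]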
-
- @trace
- async def get_e2ee_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
- """Handle E2EE device extension (MSC3884)
-
- Args:
- sync_config: Sync configuration
- e2ee_request: The e2ee extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- user_id = sync_config.user.to_string()
- device_id = sync_config.requester.device_id
-
- # Skip if the extension is not enabled
- if not e2ee_request.enabled:
- return None
-
- device_list_updates: Optional[DeviceListUpdates] = None
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- device_list_updates = await self.device_handler.get_user_ids_changed(
- user_id=user_id,
- from_token=from_token.stream_token,
- )
-
- device_one_time_keys_count: Mapping[str, int] = {}
- device_unused_fallback_key_types: Sequence[str] = []
- if device_id:
- # TODO: We should have a way to let clients differentiate between the states of:
- # * no change in OTK count since the provided since token
- # * the server has zero OTKs left for this device
- # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
- device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
- user_id, device_id
- )
- device_unused_fallback_key_types = (
- await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
- )
-
- return SlidingSyncResult.Extensions.E2eeExtension(
- device_list_updates=device_list_updates,
- device_one_time_keys_count=device_one_time_keys_count,
- device_unused_fallback_key_types=device_unused_fallback_key_types,
- )
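To make the TODO above concrete, a hedged client-side sketch of the ambiguity: with only the returned counts, a device can't distinguish "zero one-time keys left" from "no change reported", so a defensive client might treat a missing algorithm as zero (the heuristic below is an assumption, not spec-mandated behaviour):

# Two responses a client might receive for the e2ee extension:
no_keys_left = {"device_one_time_keys_count": {"signed_curve25519": 0}}
nothing_reported = {"device_one_time_keys_count": {}}

def should_upload_otks(response: dict, threshold: int = 10) -> bool:
    # Defensive heuristic: treat a missing algorithm as zero keys, erring
    # on the side of uploading more one-time keys.
    count = response["device_one_time_keys_count"].get("signed_curve25519", 0)
    return count < threshold

assert should_upload_otks(no_keys_left)
assert should_upload_otks(nothing_reported)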
-
- @trace
- async def get_account_data_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
- """Handle Account Data extension (MSC3959)
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
-            actual_room_ids: The actual room IDs in the Sliding Sync response.
- account_data_request: The account_data extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- user_id = sync_config.user.to_string()
-
- # Skip if the extension is not enabled
- if not account_data_request.enabled:
- return None
-
- global_account_data_map: Mapping[str, JsonMapping] = {}
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- global_account_data_map = (
- await self.store.get_updated_global_account_data_for_user(
- user_id, from_token.stream_token.account_data_key
- )
- )
-
- have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
- user_id, from_token.stream_token.push_rules_key
- )
- if have_push_rules_changed:
- global_account_data_map = dict(global_account_data_map)
- # TODO: This should take into account the `from_token` and `to_token`
- global_account_data_map[AccountDataTypes.PUSH_RULES] = (
- await self.push_rules_handler.push_rules_for_user(sync_config.user)
- )
- else:
- # TODO: This should take into account the `to_token`
- all_global_account_data = await self.store.get_global_account_data_for_user(
- user_id
- )
-
- global_account_data_map = dict(all_global_account_data)
- # TODO: This should take into account the `to_token`
- global_account_data_map[AccountDataTypes.PUSH_RULES] = (
- await self.push_rules_handler.push_rules_for_user(sync_config.user)
- )
-
- # Fetch room account data
- account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=account_data_request.lists,
- requested_room_ids=account_data_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
- if len(relevant_room_ids) > 0:
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- account_data_by_room_map = (
- await self.store.get_updated_room_account_data_for_user(
- user_id, from_token.stream_token.account_data_key
- )
- )
- else:
- # TODO: This should take into account the `to_token`
- account_data_by_room_map = (
- await self.store.get_room_account_data_for_user(user_id)
- )
-
- # Filter down to the relevant rooms
- account_data_by_room_map = {
- room_id: account_data_map
- for room_id, account_data_map in account_data_by_room_map.items()
- if room_id in relevant_room_ids
- }
-
- return SlidingSyncResult.Extensions.AccountDataExtension(
- global_account_data_map=global_account_data_map,
- account_data_by_room_map=account_data_by_room_map,
- )
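A condensed sketch of the branching above (hedged; `changed` and `everything` are stand-ins for the datastore results): on an incremental sync we start from the *updated* account data and overlay push rules only when they changed, while on an initial sync we take everything and always include push rules.

from typing import Dict, Mapping

def build_global_account_data(
    changed: Mapping[str, dict],     # account data updated since `from_token`
    everything: Mapping[str, dict],  # all account data (initial sync)
    push_rules: dict,
    push_rules_changed: bool,
    incremental: bool,
) -> Dict[str, dict]:
    if incremental:
        result = dict(changed)
        if push_rules_changed:
            result["m.push_rules"] = push_rules
    else:
        result = dict(everything)
        result["m.push_rules"] = push_rules
    return result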
-
- async def get_receipts_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]:
- """Handle Receipts extension (MSC3960)
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
-            actual_room_ids: The actual room IDs in the Sliding Sync response.
-            actual_room_response_map: A map of room ID to room results in the
-                Sliding Sync response.
-            receipts_request: The receipts extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- # Skip if the extension is not enabled
- if not receipts_request.enabled:
- return None
-
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=receipts_request.lists,
- requested_room_ids=receipts_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
-
- room_id_to_receipt_map: Dict[str, JsonMapping] = {}
- if len(relevant_room_ids) > 0:
- # TODO: Take connection tracking into account so that when a room comes back
- # into range we can send the receipts that were missed.
- receipt_source = self.event_sources.sources.receipt
- receipts, _ = await receipt_source.get_new_events(
- user=sync_config.user,
- from_key=(
- from_token.stream_token.receipt_key
- if from_token
- else MultiWriterStreamToken(stream=0)
- ),
- to_key=to_token.receipt_key,
- # This is a dummy value and isn't used in the function
- limit=0,
- room_ids=relevant_room_ids,
- is_guest=False,
- )
-
- for receipt in receipts:
- # These fields should exist for every receipt
- room_id = receipt["room_id"]
- type = receipt["type"]
- content = receipt["content"]
-
-                # For `initial: True` rooms, we only want to include receipts for events
- # in the timeline.
- room_result = actual_room_response_map.get(room_id)
- if room_result is not None:
- if room_result.initial:
-                        # TODO: In the future, it would be good to fetch fewer receipts
- # out of the database in the first place but we would need to
- # add a new `event_id` index to `receipts_linearized`.
- relevant_event_ids = [
- event.event_id for event in room_result.timeline_events
- ]
-
- assert isinstance(content, dict)
- content = {
- event_id: content_value
- for event_id, content_value in content.items()
- if event_id in relevant_event_ids
- }
-
- room_id_to_receipt_map[room_id] = {"type": type, "content": content}
-
- return SlidingSyncResult.Extensions.ReceiptsExtension(
- room_id_to_receipt_map=room_id_to_receipt_map,
- )
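The per-receipt filtering above can be distilled into a small sketch (hedged; the content shape follows the `m.receipt` EDU format): for an `initial: True` room, only receipt entries that target events present in the returned timeline are kept.

from typing import Dict, Set

def filter_receipt_content(content: Dict[str, dict], timeline_event_ids: Set[str]) -> Dict[str, dict]:
    # Keep only receipt entries that point at events in the room's timeline.
    return {
        event_id: per_event
        for event_id, per_event in content.items()
        if event_id in timeline_event_ids
    }

receipt_content = {
    "$in_timeline": {"m.read": {"@alice:example.org": {"ts": 1}}},
    "$too_old": {"m.read": {"@bob:example.org": {"ts": 2}}},
}
assert filter_receipt_content(receipt_content, {"$in_timeline"}) == {
    "$in_timeline": {"m.read": {"@alice:example.org": {"ts": 1}}}
}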
-
- async def get_typing_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- typing_request: SlidingSyncConfig.Extensions.TypingExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]:
- """Handle Typing Notification extension (MSC3961)
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
-            actual_room_ids: The actual room IDs in the Sliding Sync response.
-            actual_room_response_map: A map of room ID to room results in the
-                Sliding Sync response.
-            typing_request: The typing extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- # Skip if the extension is not enabled
- if not typing_request.enabled:
- return None
-
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=typing_request.lists,
- requested_room_ids=typing_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
-
- room_id_to_typing_map: Dict[str, JsonMapping] = {}
- if len(relevant_room_ids) > 0:
-            # Note: We don't need to take connection tracking into account for typing
-            # notifications because the client will get anything that's still relevant
-            # and hasn't timed out when the room comes back into range. We consider any
-            # gap where the room fell out of range to be long enough for typing
-            # notifications to have timed out (it's not worth the ~30 seconds of data
-            # we may have missed).
- typing_source = self.event_sources.sources.typing
- typing_notifications, _ = await typing_source.get_new_events(
- user=sync_config.user,
- from_key=(from_token.stream_token.typing_key if from_token else 0),
- to_key=to_token.typing_key,
- # This is a dummy value and isn't used in the function
- limit=0,
- room_ids=relevant_room_ids,
- is_guest=False,
- )
-
- for typing_notification in typing_notifications:
- # These fields should exist for every typing notification
- room_id = typing_notification["room_id"]
- type = typing_notification["type"]
- content = typing_notification["content"]
-
- room_id_to_typing_map[room_id] = {"type": type, "content": content}
-
- return SlidingSyncResult.Extensions.TypingExtension(
- room_id_to_typing_map=room_id_to_typing_map,
- )
-
-
-class HaveSentRoomFlag(Enum):
- """Flag for whether we have sent the room down a sliding sync connection.
-
- The valid state changes here are:
- NEVER -> LIVE
- LIVE -> PREVIOUSLY
- PREVIOUSLY -> LIVE
- """
-
- # The room has never been sent down (or we have forgotten we have sent it
- # down).
- NEVER = 1
-
- # We have previously sent the room down, but there are updates that we
- # haven't sent down.
- PREVIOUSLY = 2
-
- # We have sent the room down and the client has received all updates.
- LIVE = 3
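A minimal standalone sketch of the state machine in the docstring above (the transition table is copied from the comment; the checking helper and the local copy of the enum are illustrative only):

from enum import Enum

class HaveSentRoomFlag(Enum):  # standalone copy for the sketch
    NEVER = 1
    PREVIOUSLY = 2
    LIVE = 3

_VALID_TRANSITIONS = {
    (HaveSentRoomFlag.NEVER, HaveSentRoomFlag.LIVE),
    (HaveSentRoomFlag.LIVE, HaveSentRoomFlag.PREVIOUSLY),
    (HaveSentRoomFlag.PREVIOUSLY, HaveSentRoomFlag.LIVE),
}

def check_transition(old: HaveSentRoomFlag, new: HaveSentRoomFlag) -> None:
    if (old, new) not in _VALID_TRANSITIONS:
        raise ValueError(f"invalid transition {old} -> {new}")

check_transition(HaveSentRoomFlag.NEVER, HaveSentRoomFlag.LIVE)  # ok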
-
-
-@attr.s(auto_attribs=True, slots=True, frozen=True)
-class HaveSentRoom:
- """Whether we have sent the room down a sliding sync connection.
-
- Attributes:
- status: Flag of if we have or haven't sent down the room
- last_token: If the flag is `PREVIOUSLY` then this is non-null and
- contains the last stream token of the last updates we sent down
- the room, i.e. we still need to send everything since then to the
- client.
- """
-
- status: HaveSentRoomFlag
- last_token: Optional[RoomStreamToken]
-
- @staticmethod
- def previously(last_token: RoomStreamToken) -> "HaveSentRoom":
- """Constructor for `PREVIOUSLY` flag."""
- return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
-
-
-HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None)
-HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None)
-
-
-@attr.s(auto_attribs=True)
-class SlidingSyncConnectionStore:
- """In-memory store of per-connection state, including what rooms we have
- previously sent down a sliding sync connection.
-
- Note: This is NOT safe to run in a worker setup because connection positions will
- point to different sets of rooms on different workers. e.g. for the same connection,
- a connection position of 5 might have totally different states on worker A and
- worker B.
-
- One complication that we need to deal with here is needing to handle requests being
- resent, i.e. if we sent down a room in a response that the client received, we must
- consider the room *not* sent when we get the request again.
-
- This is handled by using an integer "token", which is returned to the client
- as part of the sync token. For each connection we store a mapping from
- tokens to the room states, and create a new entry when we send down new
- rooms.
-
- Note that for any given sliding sync connection we will only store a maximum
- of two different tokens: the previous token from the request and a new token
- sent in the response. When we receive a request with a given token, we then
- clear out all other entries with a different token.
-
- Attributes:
- _connections: Mapping from `(user_id, conn_id)` to mapping of `token`
- to mapping of room ID to `HaveSentRoom`.
- """
-
- # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom`
- _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = (
- attr.Factory(dict)
- )
-
- async def is_valid_token(
- self, sync_config: SlidingSyncConfig, connection_token: int
- ) -> bool:
- """Return whether the connection token is valid/recognized"""
- if connection_token == 0:
- return True
-
- conn_key = self._get_connection_key(sync_config)
- return connection_token in self._connections.get(conn_key, {})
-
- async def have_sent_room(
- self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str
- ) -> HaveSentRoom:
- """For the given user_id/conn_id/token, return whether we have
- previously sent the room down
- """
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.setdefault(conn_key, {})
- room_status = sync_statuses.get(connection_token, {}).get(
- room_id, HAVE_SENT_ROOM_NEVER
- )
-
- return room_status
-
- @trace
- async def record_rooms(
- self,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken],
- *,
- sent_room_ids: StrCollection,
- unsent_room_ids: StrCollection,
- ) -> int:
- """Record which rooms we have/haven't sent down in a new response
-
-        Args:
- sync_config
- from_token: The since token from the request, if any
- sent_room_ids: The set of room IDs that we have sent down as
- part of this request (only needs to be ones we didn't
-                previously send down).
- unsent_room_ids: The set of room IDs that have had updates
- since the `from_token`, but which were not included in
- this request
- """
- prev_connection_token = 0
- if from_token is not None:
- prev_connection_token = from_token.connection_position
-
- # If there are no changes then this is a noop.
- if not sent_room_ids and not unsent_room_ids:
- return prev_connection_token
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.setdefault(conn_key, {})
-
- # Generate a new token, removing any existing entries in that token
- # (which can happen if requests get resent).
- new_store_token = prev_connection_token + 1
- sync_statuses.pop(new_store_token, None)
-
- # Copy over and update the room mappings.
- new_room_statuses = dict(sync_statuses.get(prev_connection_token, {}))
-
- # Whether we have updated the `new_room_statuses`, if we don't by the
- # end we can treat this as a noop.
- have_updated = False
- for room_id in sent_room_ids:
- new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE
- have_updated = True
-
- # Whether we add/update the entries for unsent rooms depends on the
- # existing entry:
- # - LIVE: We have previously sent down everything up to
-        #   `last_room_token`, so we update the entry to be `PREVIOUSLY` with
- # `last_room_token`.
- # - PREVIOUSLY: We have previously sent down everything up to *a*
- # given token, so we don't need to update the entry.
- # - NEVER: We have never previously sent down the room, and we haven't
- # sent anything down this time either so we leave it as NEVER.
-
- # Work out the new state for unsent rooms that were `LIVE`.
- if from_token:
- new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key)
- else:
- new_unsent_state = HAVE_SENT_ROOM_NEVER
-
- for room_id in unsent_room_ids:
- prev_state = new_room_statuses.get(room_id)
- if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE:
- new_room_statuses[room_id] = new_unsent_state
- have_updated = True
-
- if not have_updated:
- return prev_connection_token
-
- sync_statuses[new_store_token] = new_room_statuses
-
- return new_store_token
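Worked through on toy data (a hedged walk-through with plain dicts and strings, not the real data structures): starting from connection position 5 whose map is `{'!a': LIVE}`, a response that sends `!b` down but skips an updated `!a` produces position 6, while position 5 stays intact until the client echoes 6 back.

prev_token = 5
statuses = {5: {"!a": "LIVE"}}

new_token = prev_token + 1
new_statuses = dict(statuses[prev_token])
new_statuses["!b"] = "LIVE"           # sent down in this response
if new_statuses.get("!a") == "LIVE":  # updated since `from_token` but not sent
    new_statuses["!a"] = "PREVIOUSLY(<from_token.room_key>)"
statuses[new_token] = new_statuses

# Both 5 and 6 are now stored; 5 is dropped once the client echoes 6 back.
assert statuses[6] == {"!a": "PREVIOUSLY(<from_token.room_key>)", "!b": "LIVE"}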
-
- @trace
- async def mark_token_seen(
- self,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> None:
- """We have received a request with the given token, so we can clear out
- any other tokens associated with the connection.
-
- If there is no from token then we have started afresh, and so we delete
- all tokens associated with the device.
- """
-        # Clear out any tokens for the connection that don't match the one
-        # from the request.
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.pop(conn_key, {})
- if from_token is None:
- return
-
- sync_statuses = {
- connection_token: room_statuses
- for connection_token, room_statuses in sync_statuses.items()
- if connection_token == from_token.connection_position
- }
- if sync_statuses:
- self._connections[conn_key] = sync_statuses
-
- @staticmethod
- def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
- """Return a unique identifier for this connection.
-
- The first part is simply the user ID.
-
- The second part is generally a combination of device ID and conn_id.
- However, both these two are optional (e.g. puppet access tokens don't
- have device IDs), so this handles those edge cases.
-
- We use this over the raw `conn_id` to avoid clashes between different
- clients that use the same `conn_id`. Imagine a user uses a web client
- that uses `conn_id: main_sync_loop` and an Android client that also has
- a `conn_id: main_sync_loop`.
- """
-
- user_id = sync_config.user.to_string()
-
- # Only one sliding sync connection is allowed per given conn_id (empty
- # or not).
- conn_id = sync_config.conn_id or ""
-
- if sync_config.requester.device_id:
- return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
-
- if sync_config.requester.access_token_id:
- # If we don't have a device, then the access token ID should be a
- # stable ID.
- return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
-
-        # If we have neither, then it's likely an AS or some weird token. Either
- # way we can just fail here.
- raise Exception("Cannot use sliding sync with access token type")
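For illustration, a self-contained sketch of the keying scheme above (the function is a stand-in; values are made up). The "D/"/"A/" prefixes keep device IDs and access token IDs from colliding in the second tuple element.

from typing import Optional, Tuple

def connection_key(
    user_id: str,
    conn_id: Optional[str],
    device_id: Optional[str],
    access_token_id: Optional[int],
) -> Tuple[str, str]:
    """Standalone sketch of the connection keying scheme (illustrative only)."""
    conn_id = conn_id or ""
    if device_id:
        return (user_id, f"D/{device_id}/{conn_id}")
    if access_token_id:
        # Without a device, fall back to the (stable) access token ID.
        return (user_id, f"A/{access_token_id}/{conn_id}")
    raise Exception("Cannot use sliding sync with access token type")

assert connection_key("@alice:example.org", "main", "DEV1", None) == (
    "@alice:example.org",
    "D/DEV1/main",
)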
diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py
new file mode 100644
index 0000000000..cb56eb53fc
--- /dev/null
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -0,0 +1,1691 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import itertools
+import logging
+from itertools import chain
+from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple
+
+from prometheus_client import Histogram
+from typing_extensions import assert_never
+
+from synapse.api.constants import Direction, EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.utils import strip_event
+from synapse.handlers.relations import BundledAggregations
+from synapse.handlers.sliding_sync.extensions import SlidingSyncExtensionHandler
+from synapse.handlers.sliding_sync.room_lists import (
+ RoomsForUserType,
+ SlidingSyncRoomLists,
+)
+from synapse.handlers.sliding_sync.store import SlidingSyncConnectionStore
+from synapse.logging.opentracing import (
+ SynapseTags,
+ log_kv,
+ set_tag,
+ start_active_span,
+ tag_args,
+ trace,
+)
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.databases.main.state_deltas import StateDelta
+from synapse.storage.databases.main.stream import PaginateFunction
+from synapse.storage.roommember import (
+ MemberSummary,
+)
+from synapse.types import (
+ JsonDict,
+ MutableStateMap,
+ PersistedEventPosition,
+ Requester,
+ RoomStreamToken,
+ SlidingSyncStreamToken,
+ StateMap,
+ StrCollection,
+ StreamKeyType,
+ StreamToken,
+)
+from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
+from synapse.types.handlers.sliding_sync import (
+ HaveSentRoomFlag,
+ MutablePerConnectionState,
+ PerConnectionState,
+ RoomSyncConfig,
+ SlidingSyncConfig,
+ SlidingSyncResult,
+ StateValues,
+)
+from synapse.types.state import StateFilter
+from synapse.util.async_helpers import concurrently_execute
+from synapse.visibility import filter_events_for_client
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+sync_processing_time = Histogram(
+ "synapse_sliding_sync_processing_time",
+ "Time taken to generate a sliding sync response, ignoring wait times.",
+ ["initial"],
+)
+
+# Limit the number of state_keys we should remember sending down the connection for each
+# (room_id, user_id). We don't want to store and pull out too much data in the database.
+#
+# 100 is an arbitrary but small-ish number. The idea is that we probably won't send down
+# too many redundant member state events (that the client already knows about) for a
+# given ongoing conversation if we keep 100 around. Most rooms don't have 100 members
+# anyway and it takes a while to cycle through 100 members.
+MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER = 100
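To make the cap concrete, a hedged sketch of how such a bound might be applied when deciding which previously-sent state keys to remember (the truncation policy here is an assumption, not necessarily what the real storage layer does):

from typing import List

MAX_TO_REMEMBER = 100  # mirrors the constant above

def bound_remembered_state_keys(state_keys: List[str]) -> List[str]:
    # Keep only the most recently sent state_keys, dropping the oldest
    # once we exceed the cap.
    if len(state_keys) <= MAX_TO_REMEMBER:
        return state_keys
    return state_keys[-MAX_TO_REMEMBER:]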
+
+
+class SlidingSyncHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+ self.auth_blocking = hs.get_auth_blocking()
+ self.notifier = hs.get_notifier()
+ self.event_sources = hs.get_event_sources()
+ self.relations_handler = hs.get_relations_handler()
+ self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ self.is_mine_id = hs.is_mine_id
+
+ self.connection_store = SlidingSyncConnectionStore(self.store)
+ self.extensions = SlidingSyncExtensionHandler(hs)
+ self.room_lists = SlidingSyncRoomLists(hs)
+
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken] = None,
+ timeout_ms: int = 0,
+ ) -> SlidingSyncResult:
+ """
+ Get the sync for a client if we have new data for it now. Otherwise
+ wait for new data to arrive on the server. If the timeout expires, then
+ return an empty sync result.
+
+ Args:
+ requester: The user making the request
+ sync_config: Sync configuration
+ from_token: The point in the stream to sync from. Token of the end of the
+ previous batch. May be `None` if this is the initial sync request.
+ timeout_ms: The time in milliseconds to wait for new data to arrive. If 0,
+                we will respond immediately, but there might not be any new data, so we
+                just return an empty response.
+ """
+        # If the user is not part of the MAU group, then check that limits have
+        # not been exceeded (if they're not part of the group by this point,
+        # auth blocking will almost certainly occur).
+ await self.auth_blocking.check_auth_blocking(requester=requester)
+
+        # If we're working with a user-provided token, we need to make sure this
+        # worker has caught up with the token so we don't skip past any incoming
+        # events or future events if the user is nefariously (or manually)
+        # modifying the token.
+ if from_token is not None:
+ # We need to make sure this worker has caught up with the token. If
+ # this returns false, it means we timed out waiting, and we should
+ # just return an empty response.
+ before_wait_ts = self.clock.time_msec()
+ if not await self.notifier.wait_for_stream_token(from_token.stream_token):
+ logger.warning(
+ "Timed out waiting for worker to catch up. Returning empty response"
+ )
+ return SlidingSyncResult.empty(from_token)
+
+ # If we've spent significant time waiting to catch up, take it off
+ # the timeout.
+ after_wait_ts = self.clock.time_msec()
+ if after_wait_ts - before_wait_ts > 1_000:
+ timeout_ms -= after_wait_ts - before_wait_ts
+ timeout_ms = max(timeout_ms, 0)
+
+ # We're going to respond immediately if the timeout is 0 or if this is an
+ # initial sync (without a `from_token`) so we can avoid calling
+ # `notifier.wait_for_events()`.
+ if timeout_ms == 0 or from_token is None:
+ now_token = self.event_sources.get_current_token()
+ result = await self.current_sync_for_user(
+ sync_config,
+ from_token=from_token,
+ to_token=now_token,
+ )
+ else:
+ # Otherwise, we wait for something to happen and report it to the user.
+ async def current_sync_callback(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> SlidingSyncResult:
+ return await self.current_sync_for_user(
+ sync_config,
+ from_token=from_token,
+ to_token=after_token,
+ )
+
+ result = await self.notifier.wait_for_events(
+ sync_config.user.to_string(),
+ timeout_ms,
+ current_sync_callback,
+ from_token=from_token.stream_token,
+ )
+
+ return result
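The timeout adjustment above reduces to a small pure function (a sketch of the arithmetic only): any significant time (over one second) spent waiting for the worker to catch up is shaved off the client's long-poll budget, clamped at zero.

def remaining_timeout_ms(timeout_ms: int, waited_ms: int) -> int:
    # Only shave the wait off if it was significant (> 1s), and never go
    # below zero.
    if waited_ms > 1_000:
        timeout_ms -= waited_ms
    return max(timeout_ms, 0)

assert remaining_timeout_ms(30_000, 2_500) == 27_500
assert remaining_timeout_ms(1_500, 2_500) == 0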
+
+ @trace
+ async def current_sync_for_user(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken] = None,
+ ) -> SlidingSyncResult:
+ """
+ Generates the response body of a Sliding Sync result, represented as a
+ `SlidingSyncResult`.
+
+ We fetch data according to the token range (> `from_token` and <= `to_token`).
+
+ Args:
+ sync_config: Sync configuration
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from. Token of the end of the
+ previous batch. May be `None` if this is the initial sync request.
+ """
+ start_time_s = self.clock.time()
+
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+ # Get the per-connection state (if any).
+ #
+ # Raises an exception if there is a `connection_position` that we don't
+ # recognize. If we don't do this and the client asks for the full range
+ # of rooms, we end up sending down all rooms and their state from
+ # scratch (which can be very slow). By expiring the connection we allow
+ # the client a chance to do an initial request with a smaller range of
+        # rooms to get them some results sooner, even though it will end up taking
+        # a similar amount of time (more, with the extra round-trips and
+        # re-processing) to get everything again.
+ previous_connection_state = (
+ await self.connection_store.get_and_clear_connection_positions(
+ sync_config, from_token
+ )
+ )
+
+ # Get all of the room IDs that the user should be able to see in the sync
+ # response
+ has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
+ has_room_subscriptions = (
+ sync_config.room_subscriptions is not None
+ and len(sync_config.room_subscriptions) > 0
+ )
+
+ interested_rooms = await self.room_lists.compute_interested_rooms(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ from_token=from_token.stream_token if from_token else None,
+ to_token=to_token,
+ )
+
+ lists = interested_rooms.lists
+ relevant_room_map = interested_rooms.relevant_room_map
+ all_rooms = interested_rooms.all_rooms
+ room_membership_for_user_map = interested_rooms.room_membership_for_user_map
+ relevant_rooms_to_send_map = interested_rooms.relevant_rooms_to_send_map
+
+ # Fetch room data
+ rooms: Dict[str, SlidingSyncResult.RoomResult] = {}
+
+ new_connection_state = previous_connection_state.get_mutable()
+
+ @trace
+ @tag_args
+ async def handle_room(room_id: str) -> None:
+ room_sync_result = await self.get_room_sync_data(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
+ room_id=room_id,
+ room_sync_config=relevant_rooms_to_send_map[room_id],
+ room_membership_for_user_at_to_token=room_membership_for_user_map[
+ room_id
+ ],
+ from_token=from_token,
+ to_token=to_token,
+ newly_joined=room_id in interested_rooms.newly_joined_rooms,
+ newly_left=room_id in interested_rooms.newly_left_rooms,
+ is_dm=room_id in interested_rooms.dm_room_ids,
+ )
+
+ # Filter out empty room results during incremental sync
+ if room_sync_result or not from_token:
+ rooms[room_id] = room_sync_result
+
+ if relevant_rooms_to_send_map:
+ with start_active_span("sliding_sync.generate_room_entries"):
+ await concurrently_execute(handle_room, relevant_rooms_to_send_map, 20)
+
+ extensions = await self.extensions.get_extensions_response(
+ sync_config=sync_config,
+ actual_lists=lists,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
+ # We're purposely using `relevant_room_map` instead of
+ # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could
+ # send regardless of whether they have an event update or not. The
+ # extensions care about more than just normal events in the rooms (like
+ # account data, read receipts, typing indicators, to-device messages, etc).
+ actual_room_ids=set(relevant_room_map.keys()),
+ actual_room_response_map=rooms,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ if has_lists or has_room_subscriptions:
+ # We now calculate if any rooms outside the range have had updates,
+ # which we are not sending down.
+ #
+ # We *must* record rooms that have had updates, but it is also fine
+ # to record rooms as having updates even if there might not actually
+ # be anything new for the user (e.g. due to event filters, events
+ # having happened after the user left, etc).
+ if from_token:
+ # The set of rooms that the client (may) care about, but aren't
+ # in any list range (or subscribed to).
+ missing_rooms = all_rooms - relevant_room_map.keys()
+
+ # We now just go and try fetching any events in the above rooms
+ # to see if anything has happened since the `from_token`.
+ #
+ # TODO: Replace this with something faster. When we land the
+ # sliding sync tables that record the most recent event
+ # positions we can use that.
+ unsent_room_ids: StrCollection
+ if await self.store.have_finished_sliding_sync_background_jobs():
+ unsent_room_ids = await (
+ self.store.get_rooms_that_have_updates_since_sliding_sync_table(
+ room_ids=missing_rooms,
+ from_key=from_token.stream_token.room_key,
+ )
+ )
+ else:
+ missing_event_map_by_room = (
+ await self.store.get_room_events_stream_for_rooms(
+ room_ids=missing_rooms,
+ from_key=to_token.room_key,
+ to_key=from_token.stream_token.room_key,
+ limit=1,
+ )
+ )
+ unsent_room_ids = list(missing_event_map_by_room)
+
+ new_connection_state.rooms.record_unsent_rooms(
+ unsent_room_ids, from_token.stream_token.room_key
+ )
+
+ new_connection_state.rooms.record_sent_rooms(
+ relevant_rooms_to_send_map.keys()
+ )
+
+ connection_position = await self.connection_store.record_new_state(
+ sync_config=sync_config,
+ from_token=from_token,
+ new_connection_state=new_connection_state,
+ )
+ elif from_token:
+ connection_position = from_token.connection_position
+ else:
+ # Initial sync without a `from_token` starts at `0`
+ connection_position = 0
+
+ sliding_sync_result = SlidingSyncResult(
+ next_pos=SlidingSyncStreamToken(to_token, connection_position),
+ lists=lists,
+ rooms=rooms,
+ extensions=extensions,
+ )
+
+ # Make it easy to find traces for syncs that aren't empty
+ set_tag(SynapseTags.RESULT_PREFIX + "result", bool(sliding_sync_result))
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "sync_config.user", user_id)
+
+ end_time_s = self.clock.time()
+ sync_processing_time.labels(from_token is not None).observe(
+ end_time_s - start_time_s
+ )
+
+ return sliding_sync_result
+
+ @trace
+ async def get_current_state_ids_at(
+ self,
+ room_id: str,
+ room_membership_for_user_at_to_token: RoomsForUserType,
+ state_filter: StateFilter,
+ to_token: StreamToken,
+ ) -> StateMap[str]:
+ """
+ Get current state IDs for the user in the room according to their membership. This
+        will be the current state at the time of their LEAVE/BAN; otherwise it will
+        be the current state <= to_token.
+
+ Args:
+ room_id: The room ID to fetch data for
+            room_membership_for_user_at_to_token: Membership information for the user
+                in the room at the time of `to_token`.
+            state_filter: The state filter used to look up the state in the room.
+ to_token: The point in the stream to sync up to.
+ """
+ state_ids: StateMap[str]
+ # People shouldn't see past their leave/ban event
+ if room_membership_for_user_at_to_token.membership in (
+ Membership.LEAVE,
+ Membership.BAN,
+ ):
+ # TODO: `get_state_ids_at(...)` doesn't take into account the "current
+ # state". Maybe we need to use
+ # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the
+ # current state at the time."
+ state_ids = await self.storage_controllers.state.get_state_ids_at(
+ room_id,
+ stream_position=to_token.copy_and_replace(
+ StreamKeyType.ROOM,
+ room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
+ ),
+ state_filter=state_filter,
+ # Partially-stated rooms should have all state events except for
+ # remote membership events. Since we've already excluded
+ # partially-stated rooms unless `required_state` only has
+ # `["m.room.member", "$LAZY"]` for membership, we should be able to
+ # retrieve everything requested. When we're lazy-loading, if there
+ # are some remote senders in the timeline, we should also have their
+ # membership event because we had to auth that timeline event. Plus
+ # we don't want to block the whole sync waiting for this one room.
+ await_full_state=False,
+ )
+ # Otherwise, we can get the latest current state in the room
+ else:
+ state_ids = await self.storage_controllers.state.get_current_state_ids(
+ room_id,
+ state_filter,
+ # Partially-stated rooms should have all state events except for
+ # remote membership events. Since we've already excluded
+ # partially-stated rooms unless `required_state` only has
+ # `["m.room.member", "$LAZY"]` for membership, we should be able to
+ # retrieve everything requested. When we're lazy-loading, if there
+ # are some remote senders in the timeline, we should also have their
+ # membership event because we had to auth that timeline event. Plus
+ # we don't want to block the whole sync waiting for this one room.
+ await_full_state=False,
+ )
+ # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
+
+ return state_ids
+
+ @trace
+ async def get_current_state_at(
+ self,
+ room_id: str,
+ room_membership_for_user_at_to_token: RoomsForUserType,
+ state_filter: StateFilter,
+ to_token: StreamToken,
+ ) -> StateMap[EventBase]:
+ """
+ Get current state for the user in the room according to their membership. This
+        will be the current state at the time of their LEAVE/BAN; otherwise it will
+        be the current state <= to_token.
+
+ Args:
+ room_id: The room ID to fetch data for
+            room_membership_for_user_at_to_token: Membership information for the user
+                in the room at the time of `to_token`.
+            state_filter: The state filter used to look up the state in the room.
+ to_token: The point in the stream to sync up to.
+ """
+ state_ids = await self.get_current_state_ids_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=state_filter,
+ to_token=to_token,
+ )
+
+ events = await self.store.get_events_as_list(list(state_ids.values()))
+
+ state_map = {}
+ for event in events:
+ state_map[(event.type, event.state_key)] = event
+
+ return state_map
+
+ @trace
+ async def get_current_state_deltas_for_room(
+ self,
+ room_id: str,
+ room_membership_for_user_at_to_token: RoomsForUserType,
+ from_token: RoomStreamToken,
+ to_token: RoomStreamToken,
+ ) -> List[StateDelta]:
+ """
+ Get the state deltas between two tokens taking into account the user's
+ membership. If the user is LEAVE/BAN, we will only get the state deltas up to
+ their LEAVE/BAN event (inclusive).
+
+ (> `from_token` and <= `to_token`)
+ """
+ membership = room_membership_for_user_at_to_token.membership
+ # We don't know how to handle `membership` values other than these. The
+ # code below would need to be updated.
+ assert membership in (
+ Membership.JOIN,
+ Membership.INVITE,
+ Membership.KNOCK,
+ Membership.LEAVE,
+ Membership.BAN,
+ )
+
+ # People shouldn't see past their leave/ban event
+ if membership in (
+ Membership.LEAVE,
+ Membership.BAN,
+ ):
+ to_bound = (
+ room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
+ )
+ # If we are participating in the room, we can get the latest current state in
+ # the room
+ elif membership == Membership.JOIN:
+ to_bound = to_token
+ # We can only rely on the stripped state included in the invite/knock event
+ # itself so there will never be any state deltas to send down.
+ elif membership in (Membership.INVITE, Membership.KNOCK):
+ return []
+ else:
+ # We don't know how to handle this type of membership yet
+ #
+ # FIXME: We should use `assert_never` here but for some reason
+ # the exhaustive matching doesn't recognize the `Never` here.
+ # assert_never(membership)
+ raise AssertionError(
+ f"Unexpected membership {membership} that we don't know how to handle yet"
+ )
+
+ return await self.store.get_current_state_deltas_for_room(
+ room_id=room_id,
+ from_token=from_token,
+ to_token=to_bound,
+ )
+
+ @trace
+ async def get_room_sync_data(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ room_id: str,
+ room_sync_config: RoomSyncConfig,
+ room_membership_for_user_at_to_token: RoomsForUserType,
+ from_token: Optional[SlidingSyncStreamToken],
+ to_token: StreamToken,
+ newly_joined: bool,
+ newly_left: bool,
+ is_dm: bool,
+ ) -> SlidingSyncResult.RoomResult:
+ """
+ Fetch room data for the sync response.
+
+ We fetch data according to the token range (> `from_token` and <= `to_token`).
+
+ Args:
+            sync_config: Sync configuration for this request
+ room_id: The room ID to fetch data for
+ room_sync_config: Config for what data we should fetch for a room in the
+ sync response.
+ room_membership_for_user_at_to_token: Membership information for the user
+ in the room at the time of `to_token`.
+ from_token: The point in the stream to sync from.
+ to_token: The point in the stream to sync up to.
+ newly_joined: If the user has newly joined the room
+ newly_left: If the user has newly left the room
+ is_dm: Whether the room is a DM room
+ """
+ user = sync_config.user
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "membership",
+ room_membership_for_user_at_to_token.membership,
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "timeline_limit",
+ room_sync_config.timeline_limit,
+ )
+
+ # Handle state resets. For example, if we see
+ # `room_membership_for_user_at_to_token.event_id=None and
+ # room_membership_for_user_at_to_token.membership is not None`, we should
+ # indicate to the client that a state reset happened. Perhaps we should indicate
+ # this by setting `initial: True` and empty `required_state: []`.
+ state_reset_out_of_room = False
+ if (
+ room_membership_for_user_at_to_token.event_id is None
+ and room_membership_for_user_at_to_token.membership is not None
+ ):
+ # We only expect the `event_id` to be `None` if you've been state reset out
+ # of the room (meaning you're no longer in the room). We could put this as
+ # part of the if-statement above but we want to handle every case where
+ # `event_id` is `None`.
+ assert room_membership_for_user_at_to_token.membership is Membership.LEAVE
+
+ state_reset_out_of_room = True
+
+ prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
+
+ # Determine whether we should limit the timeline to the token range.
+ #
+ # We should return historical messages (before token range) in the
+ # following cases because we want clients to be able to show a basic
+ # screen of information:
+ #
+ # - Initial sync (because no `from_token` to limit us anyway)
+ # - When users `newly_joined`
+ # - For an incremental sync where we haven't sent it down this
+ # connection before
+ #
+ # Relevant spec issue:
+ # https://github.com/matrix-org/matrix-spec/issues/1917
+ #
+ # XXX: Odd behavior - We also check if the `timeline_limit` has increased, if so
+ # we ignore the from bound for the timeline to send down a larger chunk of
+ # history and set `unstable_expanded_timeline` to true. This is only being added
+ # to match the behavior of the Sliding Sync proxy as we expect the ElementX
+ # client to feel a certain way and be able to trickle in a full page of timeline
+ # messages to fill up the screen. This is a bit different to the behavior of the
+ # Sliding Sync proxy (which sets initial=true, but then doesn't send down the
+ # full state again), but existing apps, e.g. ElementX, just need `limited` set.
+ # We don't explicitly set `limited` but this will be the case for any room that
+ # has more history than we're trying to pull out. Using
+ # `unstable_expanded_timeline` allows us to avoid contaminating what `initial`
+ # or `limited` mean for clients that interpret them correctly. In future this
+ # behavior is almost certainly going to change.
+ #
+ from_bound = None
+ initial = True
+ ignore_timeline_bound = False
+ if from_token and not newly_joined and not state_reset_out_of_room:
+ room_status = previous_connection_state.rooms.have_sent_room(room_id)
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ from_bound = from_token.stream_token.room_key
+ initial = False
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ from_bound = room_status.last_token
+ initial = False
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ from_bound = None
+ initial = True
+ else:
+ assert_never(room_status.status)
+
+ log_kv({"sliding_sync.room_status": room_status})
+
+ if prev_room_sync_config is not None:
+ # Check if the timeline limit has increased, if so ignore the
+ # timeline bound and record the change (see "XXX: Odd behavior"
+ # above).
+ if (
+ prev_room_sync_config.timeline_limit
+ < room_sync_config.timeline_limit
+ ):
+ ignore_timeline_bound = True
+
+ log_kv(
+ {
+ "sliding_sync.from_bound": from_bound,
+ "sliding_sync.initial": initial,
+ "sliding_sync.ignore_timeline_bound": ignore_timeline_bound,
+ }
+ )
+
+ # Assemble the list of timeline events
+ #
+ # FIXME: It would be nice to make the `rooms` response more uniform regardless of
+ # membership. Currently, we have to make all of these optional because
+ # `invite`/`knock` rooms only have `stripped_state`. See
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
+ timeline_events: List[EventBase] = []
+ bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
+ limited: Optional[bool] = None
+ prev_batch_token: Optional[StreamToken] = None
+ num_live: Optional[int] = None
+ if (
+ room_sync_config.timeline_limit > 0
+ # No timeline for invite/knock rooms (just `stripped_state`)
+ and room_membership_for_user_at_to_token.membership
+ not in (Membership.INVITE, Membership.KNOCK)
+ ):
+ limited = False
+ # We want to start off using the `to_token` (vs `from_token`) because we look
+ # backwards from the `to_token` up to the `timeline_limit` and we might not
+ # reach the `from_token` before we hit the limit. We will update the room stream
+ # position once we've fetched the events to point to the earliest event fetched.
+ prev_batch_token = to_token
+
+ # We're going to paginate backwards from the `to_token`
+ to_bound = to_token.room_key
+ # People shouldn't see past their leave/ban event
+ if room_membership_for_user_at_to_token.membership in (
+ Membership.LEAVE,
+ Membership.BAN,
+ ):
+ to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
+
+ timeline_from_bound = from_bound
+ if ignore_timeline_bound:
+ timeline_from_bound = None
+
+ # For initial `/sync` (and other historical scenarios mentioned above), we
+ # want to view a historical section of the timeline; to fetch events by
+ # `topological_ordering` (best representation of the room DAG as others were
+ # seeing it at the time). This also aligns with the order that `/messages`
+ # returns events in.
+ #
+ # For incremental `/sync`, we want to get all updates for rooms since
+ # the last `/sync` (regardless if those updates arrived late or happened
+ # a while ago in the past); to fetch events by `stream_ordering` (in the
+ # order they were received by the server).
+ #
+ # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
+ #
+ # FIXME: Using workaround for mypy,
+ # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and
+ # https://github.com/python/mypy/issues/17479
+ paginate_room_events_by_topological_ordering: PaginateFunction = (
+ self.store.paginate_room_events_by_topological_ordering
+ )
+ paginate_room_events_by_stream_ordering: PaginateFunction = (
+ self.store.paginate_room_events_by_stream_ordering
+ )
+ pagination_method: PaginateFunction = (
+            # Use `topological_ordering` for historical events
+ paginate_room_events_by_topological_ordering
+ if timeline_from_bound is None
+ # Use `stream_ordering` for updates
+ else paginate_room_events_by_stream_ordering
+ )
+ timeline_events, new_room_key, limited = await pagination_method(
+ room_id=room_id,
+ # The bounds are reversed so we can paginate backwards
+ # (from newer to older events) starting at to_bound.
+ # This ensures we fill the `limit` with the newest events first,
+ from_key=to_bound,
+ to_key=timeline_from_bound,
+ direction=Direction.BACKWARDS,
+ limit=room_sync_config.timeline_limit,
+ )
+
+ # We want to return the events in ascending order (the last event is the
+ # most recent).
+ timeline_events.reverse()
+
+ # Make sure we don't expose any events that the client shouldn't see
+ timeline_events = await filter_events_for_client(
+ self.storage_controllers,
+ user.to_string(),
+ timeline_events,
+ is_peeking=room_membership_for_user_at_to_token.membership
+ != Membership.JOIN,
+ filter_send_to_client=True,
+ )
+ # TODO: Filter out `EventTypes.CallInvite` in public rooms,
+ # see https://github.com/element-hq/synapse/issues/17359
+
+ # TODO: Handle timeline gaps (`get_timeline_gaps()`)
+
+ # Determine how many "live" events we have (events within the given token range).
+ #
+ # This is mostly useful to determine whether a given @mention event should
+ # make a noise or not. Clients cannot rely solely on the absence of
+        # `initial: true` to determine live events because a room outside the
+        # sliding window that bumps into the window due to an @mention will have
+        # `initial: true` yet contain a single live event (with potentially other
+        # old events in the timeline).
+ num_live = 0
+ if from_token is not None:
+ for timeline_event in reversed(timeline_events):
+                # These fields should be present for all persisted events
+ assert timeline_event.internal_metadata.stream_ordering is not None
+ assert timeline_event.internal_metadata.instance_name is not None
+
+ persisted_position = PersistedEventPosition(
+ instance_name=timeline_event.internal_metadata.instance_name,
+ stream=timeline_event.internal_metadata.stream_ordering,
+ )
+ if persisted_position.persisted_after(
+ from_token.stream_token.room_key
+ ):
+ num_live += 1
+ else:
+ # Since we're iterating over the timeline events in
+ # reverse-chronological order, we can break once we hit an event
+ # that's not live. In the future, we could potentially optimize
+ # this more with a binary search (bisect).
+ break
+
+ # If the timeline is `limited=True`, the client does not have all events
+ # necessary to calculate aggregations themselves.
+ if limited:
+ bundled_aggregations = (
+ await self.relations_handler.get_bundled_aggregations(
+ timeline_events, user.to_string()
+ )
+ )
+
+ # Update the `prev_batch_token` to point to the position that allows us to
+ # keep paginating backwards from the oldest event we return in the timeline.
+ prev_batch_token = prev_batch_token.copy_and_replace(
+ StreamKeyType.ROOM, new_room_key
+ )
+
+ # Figure out any stripped state events for invite/knocks. This allows the
+ # potential joiner to identify the room.
+ stripped_state: List[JsonDict] = []
+ if room_membership_for_user_at_to_token.membership in (
+ Membership.INVITE,
+ Membership.KNOCK,
+ ):
+            # This should never happen. If someone is invited to or has knocked on a
+            # room, then there should be an event for it.
+ assert room_membership_for_user_at_to_token.event_id is not None
+
+ invite_or_knock_event = await self.store.get_event(
+ room_membership_for_user_at_to_token.event_id
+ )
+
+ stripped_state = []
+ if invite_or_knock_event.membership == Membership.INVITE:
+ invite_state = invite_or_knock_event.unsigned.get(
+ "invite_room_state", []
+ )
+ if not isinstance(invite_state, list):
+ invite_state = []
+
+ stripped_state.extend(invite_state)
+ elif invite_or_knock_event.membership == Membership.KNOCK:
+ knock_state = invite_or_knock_event.unsigned.get("knock_room_state", [])
+ if not isinstance(knock_state, list):
+ knock_state = []
+
+ stripped_state.extend(knock_state)
+
+ stripped_state.append(strip_event(invite_or_knock_event))
+
+ # Get the changes to current state in the token range from the
+ # `current_state_delta_stream` table.
+ #
+ # For incremental syncs, we can do this first to determine if something relevant
+ # has changed and strategically avoid fetching other costly things.
+ room_state_delta_id_map: MutableStateMap[str] = {}
+ name_event_id: Optional[str] = None
+ membership_changed = False
+ name_changed = False
+ avatar_changed = False
+ if initial:
+ # Check whether the room has a name set
+ name_state_ids = await self.get_current_state_ids_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=StateFilter.from_types([(EventTypes.Name, "")]),
+ to_token=to_token,
+ )
+ name_event_id = name_state_ids.get((EventTypes.Name, ""))
+ else:
+ assert from_bound is not None
+
+            # TODO: Limit the number of state events we're about to send down
+            # for the room; if it's too many, we should change this to an
+            # `initial=True`?
+
+ # For the case of rejecting remote invites, the leave event won't be
+ # returned by `get_current_state_deltas_for_room`. This is due to the current
+            # state only being filled out for rooms the server is in, so it doesn't pick
+ # up out-of-band leaves (including locally rejected invites) as these events
+ # are outliers and not added to the `current_state_delta_stream`.
+ #
+ # We rely on being explicitly told that the room has been `newly_left` to
+ # ensure we extract the out-of-band leave.
+ if newly_left and room_membership_for_user_at_to_token.event_id is not None:
+ membership_changed = True
+ leave_event = await self.store.get_event(
+ room_membership_for_user_at_to_token.event_id
+ )
+ state_key = leave_event.get_state_key()
+ if state_key is not None:
+ room_state_delta_id_map[(leave_event.type, state_key)] = (
+ room_membership_for_user_at_to_token.event_id
+ )
+
+ deltas = await self.get_current_state_deltas_for_room(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ from_token=from_bound,
+ to_token=to_token.room_key,
+ )
+ for delta in deltas:
+ # TODO: Handle state resets where event_id is None
+ if delta.event_id is not None:
+ room_state_delta_id_map[(delta.event_type, delta.state_key)] = (
+ delta.event_id
+ )
+
+ if delta.event_type == EventTypes.Member:
+ membership_changed = True
+ elif delta.event_type == EventTypes.Name and delta.state_key == "":
+ name_changed = True
+ elif (
+ delta.event_type == EventTypes.RoomAvatar and delta.state_key == ""
+ ):
+ avatar_changed = True
+
+ # We only need the room summary for calculating heroes, however if we do
+ # fetch it then we can use it to calculate `joined_count` and
+ # `invited_count`.
+ room_membership_summary: Optional[Mapping[str, MemberSummary]] = None
+
+ # `heroes` are required if the room name is not set.
+ #
+ # Note: When you're the first one on your server to be invited to a new room
+ # over federation, we only have access to some stripped state in
+ # `event.unsigned.invite_room_state` which currently doesn't include `heroes`,
+ # see https://github.com/matrix-org/matrix-spec/issues/380. This means that
+        # clients won't be able to calculate the room name when necessary; it's
+        # just a pitfall we have to deal with until that spec issue is resolved.
+ hero_user_ids: List[str] = []
+ # TODO: Should we also check for `EventTypes.CanonicalAlias`
+ # (`m.room.canonical_alias`) as a fallback for the room name? see
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153
+ #
+ # We need to fetch the `heroes` if the room name is not set. But we only need to
+ # get them on initial syncs (or the first time we send down the room) or if the
+ # membership has changed which may change the heroes.
+        if name_event_id is None and (initial or membership_changed):
+ # We need the room summary to extract the heroes from
+ if room_membership_for_user_at_to_token.membership != Membership.JOIN:
+ # TODO: Figure out how to get the membership summary for left/banned rooms
+ # For invite/knock rooms we don't include the information.
+ room_membership_summary = {}
+ else:
+ room_membership_summary = await self.store.get_room_summary(room_id)
+ # TODO: Reverse/rewind back to the `to_token`
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me=user.to_string()
+ )
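+
+            # `extract_heroes_from_room_summary` mirrors the sync v2 `m.heroes`
+            # behaviour: pick a handful of other members that clients can use
+            # to compute a fallback room name.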
+
+ # Fetch the membership counts for rooms we're joined to.
+ #
+ # Similarly to other metadata, we only need to calculate the member
+ # counts if this is an initial sync or the memberships have changed.
+ joined_count: Optional[int] = None
+ invited_count: Optional[int] = None
+ if (
+ initial or membership_changed
+ ) and room_membership_for_user_at_to_token.membership == Membership.JOIN:
+ # If we have the room summary (because we calculated heroes above)
+ # then we can simply pull the counts from there.
+ if room_membership_summary is not None:
+ empty_membership_summary = MemberSummary([], 0)
+
+ joined_count = room_membership_summary.get(
+ Membership.JOIN, empty_membership_summary
+ ).count
+
+ invited_count = room_membership_summary.get(
+ Membership.INVITE, empty_membership_summary
+ ).count
+ else:
+ member_counts = await self.store.get_member_counts(room_id)
+ joined_count = member_counts.get(Membership.JOIN, 0)
+ invited_count = member_counts.get(Membership.INVITE, 0)
+
+ # Fetch the `required_state` for the room
+ #
+ # No `required_state` for invite/knock rooms (just `stripped_state`)
+ #
+ # FIXME: It would be nice to make the `rooms` response more uniform regardless
+ # of membership. Currently, we have to make this optional because
+ # `invite`/`knock` rooms only have `stripped_state`. See
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
+ #
+ # Calculate the `StateFilter` based on the `required_state` for the room
+ required_state_filter = StateFilter.none()
+ # The requested `required_state_map` with the lazy membership expanded and
+        # `$ME` replaced with the user's ID, so that on the next request we can
+        # tell which memberships we've already sent down to the client.
+ #
+ # Make a copy so we can modify it. Still need to be careful to make a copy of
+ # the state key sets if we want to add/remove from them. We could make a deep
+ # copy but this saves us some work.
+ expanded_required_state_map = dict(room_sync_config.required_state_map)
+ if room_membership_for_user_at_to_token.membership not in (
+ Membership.INVITE,
+ Membership.KNOCK,
+ ):
+ # If we have a double wildcard ("*", "*") in the `required_state`, we need
+ # to fetch all state for the room
+ #
+ # Note: MSC3575 describes different behavior to how we're handling things
+ # here but since it's not wrong to return more state than requested
+ # (`required_state` is just the minimum requested), it doesn't matter if we
+            # include more than the client wanted. This complexity is also under scrutiny,
+ # see
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050
+ #
+ # > One unique exception is when you request all state events via ["*", "*"]. When used,
+ # > all state events are returned by default, and additional entries FILTER OUT the returned set
+ # > of state events. These additional entries cannot use '*' themselves.
+ # > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member
+ # > event _except_ for @alice:example.com, and include every other state event.
+ # > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not
+ # > required as it would have been returned anyway.
+ # >
+ # > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575)
+ if StateValues.WILDCARD in room_sync_config.required_state_map.get(
+ StateValues.WILDCARD, set()
+ ):
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard",
+ True,
+ )
+ required_state_filter = StateFilter.all()
+ # TODO: `StateFilter` currently doesn't support wildcard event types. We're
+ # currently working around this by returning all state to the client but it
+ # would be nice to fetch less from the database and return just what the
+ # client wanted.
+ elif (
+ room_sync_config.required_state_map.get(StateValues.WILDCARD)
+ is not None
+ ):
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard_event_type",
+ True,
+ )
+ required_state_filter = StateFilter.all()
+ else:
+ required_state_types: List[Tuple[str, Optional[str]]] = []
+ num_wild_state_keys = 0
+ lazy_load_room_members = False
+ num_others = 0
+ for (
+ state_type,
+ state_key_set,
+ ) in room_sync_config.required_state_map.items():
+ for state_key in state_key_set:
+ if state_key == StateValues.WILDCARD:
+ num_wild_state_keys += 1
+ # `None` is a wildcard in the `StateFilter`
+ required_state_types.append((state_type, None))
+ # We need to fetch all relevant people when we're lazy-loading membership
+ elif (
+ state_type == EventTypes.Member
+ and state_key == StateValues.LAZY
+ ):
+ lazy_load_room_members = True
+
+ # Everyone in the timeline is relevant
+ timeline_membership: Set[str] = set()
+ if timeline_events is not None:
+ for timeline_event in timeline_events:
+ # Anyone who sent a message is relevant
+ timeline_membership.add(timeline_event.sender)
+
+                                    # We also care about invite, ban, and kick
+                                    # targets, etc.
+ if timeline_event.type == EventTypes.Member:
+ timeline_membership.add(
+ timeline_event.state_key
+ )
+
+ # Update the required state filter so we pick up the new
+ # membership
+ for user_id in timeline_membership:
+ required_state_types.append(
+ (EventTypes.Member, user_id)
+ )
+
+ # Add an explicit entry for each user in the timeline
+ #
+ # Make a new set or copy of the state key set so we can
+ # modify it without affecting the original
+ # `required_state_map`
+ expanded_required_state_map[EventTypes.Member] = (
+ expanded_required_state_map.get(
+ EventTypes.Member, set()
+ )
+ | timeline_membership
+ )
+ elif state_key == StateValues.ME:
+ num_others += 1
+ required_state_types.append((state_type, user.to_string()))
+ # Replace `$ME` with the user's ID so we can deduplicate
+ # when someone requests the same state with `$ME` or with
+ # their user ID.
+ #
+ # Make a new set or copy of the state key set so we can
+ # modify it without affecting the original
+ # `required_state_map`
+                            # Note: keyed by the requested `state_type` (not
+                            # always `m.room.member`) so the dedup works for
+                            # any event type requested with `$ME`.
+                            expanded_required_state_map[state_type] = (
+                                expanded_required_state_map.get(
+                                    state_type, set()
+                                )
+                                | {user.to_string()}
+                            )
+ else:
+ num_others += 1
+ required_state_types.append((state_type, state_key))
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX
+ + "required_state_wildcard_state_key_count",
+ num_wild_state_keys,
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "required_state_lazy",
+ lazy_load_room_members,
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "required_state_other_count",
+ num_others,
+ )
+
+ required_state_filter = StateFilter.from_types(required_state_types)
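+
+                # Illustrative example (hypothetical request): asking for
+                # `{"m.room.member": {"$LAZY"}, "m.room.topic": {""}}` in a room
+                # whose timeline includes events from @alice and @bob produces
+                # `required_state_types` entries for those two members plus
+                # `("m.room.topic", "")`, and the `m.room.member` entry of
+                # `expanded_required_state_map` additionally contains them.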
+
+ # We need this base set of info for the response so let's just fetch it along
+ # with the `required_state` for the room
+ hero_room_state = [
+ (EventTypes.Member, hero_user_id) for hero_user_id in hero_user_ids
+ ]
+ meta_room_state = list(hero_room_state)
+ if initial or name_changed:
+ meta_room_state.append((EventTypes.Name, ""))
+ if initial or avatar_changed:
+ meta_room_state.append((EventTypes.RoomAvatar, ""))
+
+ state_filter = StateFilter.all()
+ if required_state_filter != StateFilter.all():
+ state_filter = StateFilter(
+ types=StateFilter.from_types(
+ chain(meta_room_state, required_state_filter.to_types())
+ ).types,
+ include_others=required_state_filter.include_others,
+ )
+
+ # The required state map to store in the room sync config, if it has
+ # changed.
+ changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None
+
+ # We can return all of the state that was requested if this was the first
+ # time we've sent the room down this connection.
+ room_state: StateMap[EventBase] = {}
+ if initial:
+ room_state = await self.get_current_state_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=state_filter,
+ to_token=to_token,
+ )
+ else:
+ assert from_bound is not None
+
+ if prev_room_sync_config is not None:
+ # Check if there are any changes to the required state config
+ # that we need to handle.
+ changed_required_state_map, added_state_filter = (
+ _required_state_changes(
+ user.to_string(),
+ prev_required_state_map=prev_room_sync_config.required_state_map,
+ request_required_state_map=expanded_required_state_map,
+ state_deltas=room_state_delta_id_map,
+ )
+ )
+
+ if added_state_filter:
+ # Some state entries got added, so we pull out the current
+ # state for them. If we don't do this we'd only send down new deltas.
+ state_ids = await self.get_current_state_ids_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=added_state_filter,
+ to_token=to_token,
+ )
+ room_state_delta_id_map.update(state_ids)
+
+ events = await self.store.get_events(
+ state_filter.filter_state(room_state_delta_id_map).values()
+ )
+ room_state = {(s.type, s.state_key): s for s in events.values()}
+
+ # If the membership changed and we have to get heroes, get the remaining
+ # heroes from the state
+ if hero_user_ids:
+ hero_membership_state = await self.get_current_state_at(
+ room_id=room_id,
+ room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
+ state_filter=StateFilter.from_types(hero_room_state),
+ to_token=to_token,
+ )
+ room_state.update(hero_membership_state)
+
+ required_room_state: StateMap[EventBase] = {}
+ if required_state_filter != StateFilter.none():
+ required_room_state = required_state_filter.filter_state(room_state)
+
+ # Find the room name and avatar from the state
+ room_name: Optional[str] = None
+ # TODO: Should we also check for `EventTypes.CanonicalAlias`
+ # (`m.room.canonical_alias`) as a fallback for the room name? see
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153
+ name_event = room_state.get((EventTypes.Name, ""))
+ if name_event is not None:
+ room_name = name_event.content.get("name")
+
+ room_avatar: Optional[str] = None
+ avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
+ if avatar_event is not None:
+ room_avatar = avatar_event.content.get("url")
+
+ # Assemble heroes: extract the info from the state we just fetched
+ heroes: List[SlidingSyncResult.RoomResult.StrippedHero] = []
+ for hero_user_id in hero_user_ids:
+ member_event = room_state.get((EventTypes.Member, hero_user_id))
+ if member_event is not None:
+ heroes.append(
+ SlidingSyncResult.RoomResult.StrippedHero(
+ user_id=hero_user_id,
+ display_name=member_event.content.get("displayname"),
+ avatar_url=member_event.content.get("avatar_url"),
+ )
+ )
+
+ # Figure out the last bump event in the room. If the bump stamp hasn't
+ # changed we omit it from the response.
+ bump_stamp = None
+
+ always_return_bump_stamp = (
+ # We use the membership event position for any non-join
+ room_membership_for_user_at_to_token.membership != Membership.JOIN
+ # We didn't fetch any timeline events but we should still check for
+ # a bump_stamp that might be somewhere
+ or limited is None
+ # There might be a bump event somewhere before the timeline events
+ # that we fetched, that we didn't previously send down
+ or limited is True
+ # Always give the client some frame of reference if this is the
+ # first time they are seeing the room down the connection
+ or initial
+ )
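+
+        # E.g. (hypothetical): an invited room skips the join-only lookup below
+        # and falls straight through to the membership event position, while an
+        # initial sync of a joined room checks outside the timeline whenever
+        # the timeline itself contains no bump event.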
+
+ # If we're joined to the room, we need to find the last bump event before the
+ # `to_token`
+ if room_membership_for_user_at_to_token.membership == Membership.JOIN:
+ # Try and get a bump stamp
+ new_bump_stamp = await self._get_bump_stamp(
+ room_id,
+ to_token,
+ timeline_events,
+ check_outside_timeline=always_return_bump_stamp,
+ )
+ if new_bump_stamp is not None:
+ bump_stamp = new_bump_stamp
+
+ if bump_stamp is None and always_return_bump_stamp:
+            # By default, just fall back to the membership event position.
+ bump_stamp = room_membership_for_user_at_to_token.event_pos.stream
+
+ if bump_stamp is not None and bump_stamp < 0:
+ # We never want to send down negative stream orderings, as you can't
+ # sensibly compare positive and negative stream orderings (they have
+ # different meanings).
+ #
+ # A negative bump stamp here can only happen if the stream ordering
+ # of the membership event is negative (and there are no further bump
+ # stamps), which can happen if the server leaves and deletes a room,
+ # and then rejoins it.
+ #
+ # To deal with this, we just set the bump stamp to zero, which will
+ # shove this room to the bottom of the list. This is OK as the
+ # moment a new message happens in the room it will get put into a
+ # sensible order again.
+ bump_stamp = 0
+
+ room_sync_required_state_map_to_persist: Mapping[str, AbstractSet[str]] = (
+ expanded_required_state_map
+ )
+ if changed_required_state_map:
+ room_sync_required_state_map_to_persist = changed_required_state_map
+
+ # Record the `room_sync_config` if we're `ignore_timeline_bound` (which means
+ # that the `timeline_limit` has increased)
+ unstable_expanded_timeline = False
+ if ignore_timeline_bound:
+ # FIXME: We signal the fact that we're sending down more events to
+ # the client by setting `unstable_expanded_timeline` to true (see
+ # "XXX: Odd behavior" above).
+ unstable_expanded_timeline = True
+
+ new_connection_state.room_configs[room_id] = RoomSyncConfig(
+ timeline_limit=room_sync_config.timeline_limit,
+ required_state_map=room_sync_required_state_map_to_persist,
+ )
+ elif prev_room_sync_config is not None:
+ # If the result is `limited` then we need to record that the
+ # `timeline_limit` has been reduced, as when/if the client later requests
+ # more timeline then we have more data to send.
+ #
+ # Otherwise (when not `limited`) we don't need to record that the
+ # `timeline_limit` has been reduced, as the *effective* `timeline_limit`
+ # (i.e. the amount of timeline we have previously sent to the client) is at
+ # least the previous `timeline_limit`.
+ #
+ # This is to handle the case where the `timeline_limit` e.g. goes from 10 to
+ # 5 to 10 again (without any timeline gaps), where there's no point sending
+ # down the initial historical chunk events when the `timeline_limit` is
+ # increased as the client already has the 10 previous events. However, if
+            # the client has a gap in the timeline (i.e. `limited` is True), then we *do*
+ # need to record the reduced timeline.
+ #
+ # TODO: Handle timeline gaps (`get_timeline_gaps()`) - This is separate from
+ # the gaps we might see on the client because a response was `limited` we're
+ # talking about above.
+ if (
+ limited
+ and prev_room_sync_config.timeline_limit
+ > room_sync_config.timeline_limit
+ ):
+ new_connection_state.room_configs[room_id] = RoomSyncConfig(
+ timeline_limit=room_sync_config.timeline_limit,
+ required_state_map=room_sync_required_state_map_to_persist,
+ )
+
+ elif changed_required_state_map is not None:
+ new_connection_state.room_configs[room_id] = RoomSyncConfig(
+ timeline_limit=room_sync_config.timeline_limit,
+ required_state_map=room_sync_required_state_map_to_persist,
+ )
+
+ else:
+ new_connection_state.room_configs[room_id] = RoomSyncConfig(
+ timeline_limit=room_sync_config.timeline_limit,
+ required_state_map=room_sync_required_state_map_to_persist,
+ )
+
+ set_tag(SynapseTags.RESULT_PREFIX + "initial", initial)
+
+ return SlidingSyncResult.RoomResult(
+ name=room_name,
+ avatar=room_avatar,
+ heroes=heroes,
+ is_dm=is_dm,
+ initial=initial,
+ required_state=list(required_room_state.values()),
+ timeline_events=timeline_events,
+ bundled_aggregations=bundled_aggregations,
+ stripped_state=stripped_state,
+ prev_batch=prev_batch_token,
+ limited=limited,
+ unstable_expanded_timeline=unstable_expanded_timeline,
+ num_live=num_live,
+ bump_stamp=bump_stamp,
+ joined_count=joined_count,
+ invited_count=invited_count,
+ # TODO: These are just dummy values. We could potentially just remove these
+ # since notifications can only really be done correctly on the client anyway
+ # (encrypted rooms).
+ notification_count=0,
+ highlight_count=0,
+ )
+
+ @trace
+ async def _get_bump_stamp(
+ self,
+ room_id: str,
+ to_token: StreamToken,
+ timeline: List[EventBase],
+ check_outside_timeline: bool,
+ ) -> Optional[int]:
+ """Get a bump stamp for the room, if we have a bump event and it has
+ changed.
+
+ Args:
+ room_id
+            to_token: The upper bound of the token range to return.
+            timeline: The list of events we have fetched.
+            check_outside_timeline: Whether to check for a bump stamp in events
+                before the timeline if we didn't find one in the timeline
+                events.
+ """
+
+ # First check the timeline events we're returning to see if one of
+ # those matches. We iterate backwards and take the stream ordering
+ # of the first event that matches the bump event types.
+ for timeline_event in reversed(timeline):
+ if timeline_event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES:
+ new_bump_stamp = timeline_event.internal_metadata.stream_ordering
+
+ # All persisted events have a stream ordering
+ assert new_bump_stamp is not None
+
+ # If we've just joined a remote room, then the last bump event may
+ # have been backfilled (and so have a negative stream ordering).
+ # These negative stream orderings can't sensibly be compared, so
+ # instead we use the membership event position.
+ if new_bump_stamp > 0:
+ return new_bump_stamp
+
+ if not check_outside_timeline:
+            # If we don't need to look outside the timeline, then we know the
+            # bump stamp can't have changed.
+ return None
+
+ # We can quickly query for the latest bump event in the room using the
+ # sliding sync tables.
+ latest_room_bump_stamp = await self.store.get_latest_bump_stamp_for_room(
+ room_id
+ )
+
+ min_to_token_position = to_token.room_key.stream
+
+ # If we can rely on the new sliding sync tables and the `bump_stamp` is
+        # `None`, just fall back to the membership event position. This can happen
+ # when we've just joined a remote room and all the events are backfilled.
+ if (
+ # FIXME: The background job check can be removed once we bump
+ # `SCHEMA_COMPAT_VERSION` and run the foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots`
+ # (tracked by https://github.com/element-hq/synapse/issues/17623)
+ latest_room_bump_stamp is None
+ and await self.store.have_finished_sliding_sync_background_jobs()
+ ):
+ return None
+
+ # The `bump_stamp` stored in the database might be ahead of our token. Since
+ # `bump_stamp` is only a `stream_ordering` position, we can't be 100% sure
+ # that's before the `to_token` in all scenarios. The only scenario we can be
+        # that it's before the `to_token` in all scenarios. The only scenario we can be
+ # the token.
+ #
+ # We don't need to check if the background update has finished, as if the
+ # returned bump stamp is not None then it must be up to date.
+ elif (
+ latest_room_bump_stamp is not None
+ and latest_room_bump_stamp < min_to_token_position
+ ):
+ if latest_room_bump_stamp > 0:
+ return latest_room_bump_stamp
+ else:
+ return None
+
+ # Otherwise, if it's within or after the `to_token`, we need to find the
+ # last bump event before the `to_token`.
+ else:
+ last_bump_event_result = (
+ await self.store.get_last_event_pos_in_room_before_stream_ordering(
+ room_id,
+ to_token.room_key,
+ event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES,
+ )
+ )
+ if last_bump_event_result is not None:
+ _, new_bump_event_pos = last_bump_event_result
+
+ # If we've just joined a remote room, then the last bump event may
+ # have been backfilled (and so have a negative stream ordering).
+ # These negative stream orderings can't sensibly be compared, so
+ # instead we use the membership event position.
+ if new_bump_event_pos.stream > 0:
+ return new_bump_event_pos.stream
+
+ return None
+
+
+def _required_state_changes(
+ user_id: str,
+ *,
+ prev_required_state_map: Mapping[str, AbstractSet[str]],
+ request_required_state_map: Mapping[str, AbstractSet[str]],
+ state_deltas: StateMap[str],
+) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
+    """Calculates the changes in the required state room config between the
+    previous request and the current request.
+
+ This does two things. First, it calculates if we need to update the room
+ config due to changes to required state. Secondly, it works out which state
+ entries we need to pull from current state and return due to the state entry
+ now appearing in the required state when it previously wasn't (on top of the
+ state deltas).
+
+    This function also handles the case where a state entry is added, removed,
+    and then added again to the required state: in that case we only want to
+    re-send that entry down sync if it has changed.
+
+ Returns:
+ A 2-tuple of updated required state config (or None if there is no update)
+ and the state filter to use to fetch extra current state that we need to
+ return.
+ """
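+    # Illustrative example (hypothetical): if one request asks for
+    # `{"m.room.topic": {""}}`, a later request drops it, and a subsequent
+    # request re-adds it, we only re-send the current `m.room.topic` event if
+    # it appears in `state_deltas` (i.e. it changed while it was dropped).
+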
+ if prev_required_state_map == request_required_state_map:
+ # There has been no change. Return immediately.
+ return None, StateFilter.none()
+
+ prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set())
+ request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set())
+
+ # If we were previously fetching everything ("*", "*"), always update the effective
+    # room required state config to match the request. And since we were already
+    # fetching everything, we don't have to fetch anything now that they've
+ # narrowed.
+ if StateValues.WILDCARD in prev_wildcard:
+ return request_required_state_map, StateFilter.none()
+
+    # If an event type wildcard has been added or removed, we don't try and do
+ # anything fancy, and instead always update the effective room required
+ # state config to match the request.
+ if request_wildcard - prev_wildcard:
+ # Some keys were added, so we need to fetch everything
+ return request_required_state_map, StateFilter.all()
+ if prev_wildcard - request_wildcard:
+ # Keys were only removed, so we don't have to fetch everything.
+ return request_required_state_map, StateFilter.none()
+
+ # Contains updates to the required state map compared with the previous room
+    # config. This has the same format as `RoomSyncConfig.required_state_map`.
+ changes: Dict[str, AbstractSet[str]] = {}
+
+ # The set of types/state keys that we need to fetch and return to the
+ # client. Passed to `StateFilter.from_types(...)`
+ added: List[Tuple[str, Optional[str]]] = []
+
+    # Convert the list of state deltas to a map from event type to the
+    # state_keys that have changed.
+ changed_types_to_state_keys: Dict[str, Set[str]] = {}
+ for event_type, state_key in state_deltas:
+ changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)
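+    # E.g. (hypothetical): deltas for `("m.room.member", "@alice:x")` and
+    # `("m.room.member", "@bob:x")` collapse into
+    # `{"m.room.member": {"@alice:x", "@bob:x"}}`.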
+
+ # First we calculate what, if anything, has been *added*.
+ for event_type in (
+ prev_required_state_map.keys() | request_required_state_map.keys()
+ ):
+ old_state_keys = prev_required_state_map.get(event_type, set())
+ request_state_keys = request_required_state_map.get(event_type, set())
+ changed_state_keys = changed_types_to_state_keys.get(event_type, set())
+
+ if old_state_keys == request_state_keys:
+ # No change to this type
+ continue
+
+ if not request_state_keys - old_state_keys:
+ # Nothing *added*, so we skip. Removals happen below.
+ continue
+
+ # We only remove state keys from the effective state if they've been
+ # removed from the request *and* the state has changed. This ensures
+ # that if a client removes and then re-adds a state key, we only send
+        # down the associated current state event if it's changed (rather than
+ # sending down the same event twice).
+ invalidated_state_keys = (
+ old_state_keys - request_state_keys
+ ) & changed_state_keys
+
+ # Figure out which state keys we should remember sending down the connection
+ inheritable_previous_state_keys = (
+ # Retain the previous state_keys that we've sent down before.
+ # Wildcard and lazy state keys are not sticky from previous requests.
+ (old_state_keys - {StateValues.WILDCARD, StateValues.LAZY})
+ - invalidated_state_keys
+ )
+
+        # Always update `changes` to include the newly added keys (we've expanded
+        # the set of state keys): use the newly requested set plus whatever hasn't
+        # been invalidated from the previous set.
+ changes[event_type] = request_state_keys | inheritable_previous_state_keys
+ # Limit the number of state_keys we should remember sending down the connection
+        # for each (room_id, user_id). We don't want to store and pull out too much
+        # data from the database. This is a happy medium between remembering nothing and
+ # everything. We can avoid sending redundant state down the connection most of
+ # the time given that most rooms don't have 100 members anyway and it takes a
+ # while to cycle through 100 members.
+ #
+ # Only remember up to (MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER)
+ if len(changes[event_type]) > MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER:
+ # Reset back to only the requested state keys
+ changes[event_type] = request_state_keys
+
+ # Skip if there isn't any room to fill in the rest with previous state keys
+ if len(request_state_keys) < MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER:
+                # Fill the rest with previous state_keys. Ideally, we could sort
+                # these by recency, but it's just a set, so we just pick an
+                # arbitrary subset (good enough).
+ changes[event_type] = changes[event_type] | set(
+ itertools.islice(
+ inheritable_previous_state_keys,
+                        # Just taking the difference isn't perfect as there could
+                        # be overlap between the requested and previous keys, but
+                        # we take the easy route for now and avoid additional set
+                        # operations to figure it out.
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER
+ - len(request_state_keys),
+ )
+ )
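+
+        # Illustrative example (hypothetical limit of 100): a request for 30
+        # state keys plus 90 inheritable previous keys exceeds the limit, so we
+        # keep the 30 requested keys and fill the remaining 70 slots with an
+        # arbitrary subset of the previous keys.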
+
+ if StateValues.WILDCARD in old_state_keys:
+ # We were previously fetching everything for this type, so we don't need to
+ # fetch anything new.
+ continue
+
+ # Record the new state keys to fetch for this type.
+ if StateValues.WILDCARD in request_state_keys:
+ # If we have added a wildcard then we always just fetch everything.
+ added.append((event_type, None))
+ else:
+ for state_key in request_state_keys - old_state_keys:
+ if state_key == StateValues.ME:
+ added.append((event_type, user_id))
+ elif state_key == StateValues.LAZY:
+ # We handle lazy loading separately (outside this function),
+ # so don't need to explicitly add anything here.
+ #
+                    # `$LAZY` values should also be ignored for event types
+                    # other than membership.
+ pass
+ else:
+ added.append((event_type, state_key))
+
+ added_state_filter = StateFilter.from_types(added)
+
+ # Figure out what changes we need to apply to the effective required state
+ # config.
+ for event_type, changed_state_keys in changed_types_to_state_keys.items():
+ old_state_keys = prev_required_state_map.get(event_type, set())
+ request_state_keys = request_required_state_map.get(event_type, set())
+
+ if old_state_keys == request_state_keys:
+ # No change.
+ continue
+
+ # If we see the `user_id` as a state_key, also add "$ME" to the list of state
+ # that has changed to account for people requesting `required_state` with `$ME`
+ # or their user ID.
+ if user_id in changed_state_keys:
+ changed_state_keys.add(StateValues.ME)
+
+ # We only remove state keys from the effective state if they've been
+ # removed from the request *and* the state has changed. This ensures
+ # that if a client removes and then re-adds a state key, we only send
+        # down the associated current state event if it's changed (rather than
+ # sending down the same event twice).
+ invalidated_state_keys = (
+ old_state_keys - request_state_keys
+ ) & changed_state_keys
+
+        # If the request *added* state keys for this type, we don't need to do
+        # anything here; that case was already handled in the loop above.
+ if request_state_keys - old_state_keys:
+ continue
+
+ old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
+ request_state_key_wildcard = StateValues.WILDCARD in request_state_keys
+
+ if old_state_key_wildcard != request_state_key_wildcard:
+ # If a state_key wildcard has been added or removed, we always update the
+ # effective room required state config to match the request.
+ changes[event_type] = request_state_keys
+ continue
+
+ if event_type == EventTypes.Member:
+ old_state_key_lazy = StateValues.LAZY in old_state_keys
+ request_state_key_lazy = StateValues.LAZY in request_state_keys
+
+ if old_state_key_lazy != request_state_key_lazy:
+ # If a "$LAZY" has been added or removed we always update the effective room
+ # required state config to match the request.
+ changes[event_type] = request_state_keys
+ continue
+
+ # At this point there are no wildcards and no additions to the set of
+ # state keys requested, only deletions.
+ #
+ # We only remove state keys from the effective state if they've been
+ # removed from the request *and* the state has changed. This ensures
+ # that if a client removes and then re-adds a state key, we only send
+        # down the associated current state event if it's changed (rather than
+ # sending down the same event twice).
+ if invalidated_state_keys:
+ changes[event_type] = old_state_keys - invalidated_state_keys
+
+ if changes:
+ # Update the required state config based on the changes.
+ new_required_state_map = dict(prev_required_state_map)
+ for event_type, state_keys in changes.items():
+ if state_keys:
+ new_required_state_map[event_type] = state_keys
+ else:
+ # Remove entries with empty state keys.
+ new_required_state_map.pop(event_type, None)
+
+ return new_required_state_map, added_state_filter
+ else:
+ return None, added_state_filter
diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py
new file mode 100644
index 0000000000..077887ec32
--- /dev/null
+++ b/synapse/handlers/sliding_sync/extensions.py
@@ -0,0 +1,879 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import itertools
+import logging
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ ChainMap,
+ Dict,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Sequence,
+ Set,
+ cast,
+)
+
+from typing_extensions import assert_never
+
+from synapse.api.constants import AccountDataTypes, EduTypes
+from synapse.handlers.receipts import ReceiptEventSource
+from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.receipts import ReceiptInRoom
+from synapse.types import (
+ DeviceListUpdates,
+ JsonMapping,
+ MultiWriterStreamToken,
+ SlidingSyncStreamToken,
+ StrCollection,
+ StreamToken,
+)
+from synapse.types.handlers.sliding_sync import (
+ HaveSentRoomFlag,
+ MutablePerConnectionState,
+ OperationType,
+ PerConnectionState,
+ SlidingSyncConfig,
+ SlidingSyncResult,
+)
+from synapse.util.async_helpers import (
+ concurrently_execute,
+ gather_optional_coroutines,
+)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SlidingSyncExtensionHandler:
+ """Handles the extensions to sliding sync."""
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.device_handler = hs.get_device_handler()
+ self.push_rules_handler = hs.get_push_rules_handler()
+
+ @trace
+ async def get_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> SlidingSyncResult.Extensions:
+ """Handle extension requests.
+
+ Args:
+ sync_config: Sync configuration
+            previous_connection_state: Snapshot of the per-connection state at
+                the start of this request.
+            new_connection_state: A mutable copy of the per-connection state,
+                used to record updates to the state during this request.
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+            actual_room_response_map: A map of room ID to room results in the
+                Sliding Sync response.
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+
+ if sync_config.extensions is None:
+ return SlidingSyncResult.Extensions()
+
+ to_device_coro = None
+ if sync_config.extensions.to_device is not None:
+ to_device_coro = self.get_to_device_extension_response(
+ sync_config=sync_config,
+ to_device_request=sync_config.extensions.to_device,
+ to_token=to_token,
+ )
+
+ e2ee_coro = None
+ if sync_config.extensions.e2ee is not None:
+ e2ee_coro = self.get_e2ee_extension_response(
+ sync_config=sync_config,
+ e2ee_request=sync_config.extensions.e2ee,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ account_data_coro = None
+ if sync_config.extensions.account_data is not None:
+ account_data_coro = self.get_account_data_extension_response(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ account_data_request=sync_config.extensions.account_data,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ receipts_coro = None
+ if sync_config.extensions.receipts is not None:
+ receipts_coro = self.get_receipts_extension_response(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ actual_room_response_map=actual_room_response_map,
+ receipts_request=sync_config.extensions.receipts,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ typing_coro = None
+ if sync_config.extensions.typing is not None:
+ typing_coro = self.get_typing_extension_response(
+ sync_config=sync_config,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ actual_room_response_map=actual_room_response_map,
+ typing_request=sync_config.extensions.typing,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ (
+ to_device_response,
+ e2ee_response,
+ account_data_response,
+ receipts_response,
+ typing_response,
+ ) = await gather_optional_coroutines(
+ to_device_coro,
+ e2ee_coro,
+ account_data_coro,
+ receipts_coro,
+ typing_coro,
+ )
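+        # `gather_optional_coroutines` awaits the non-`None` coroutines
+        # concurrently; a `None` entry (an extension that wasn't requested)
+        # simply yields `None` in the corresponding result position.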
+
+ return SlidingSyncResult.Extensions(
+ to_device=to_device_response,
+ e2ee=e2ee_response,
+ account_data=account_data_response,
+ receipts=receipts_response,
+ typing=typing_response,
+ )
+
+ def find_relevant_room_ids_for_extension(
+ self,
+ requested_lists: Optional[StrCollection],
+ requested_room_ids: Optional[StrCollection],
+ actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: AbstractSet[str],
+ ) -> Set[str]:
+ """
+ Handle the reserved `lists`/`rooms` keys for extensions. Extensions should only
+ return results for rooms in the Sliding Sync response. This matches up the
+ requested rooms/lists with the actual lists/rooms in the Sliding Sync response.
+
+ {"lists": []} // Do not process any lists.
+ {"lists": ["rooms", "dms"]} // Process only a subset of lists.
+ {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.)
+
+ {"rooms": []} // Do not process any specific rooms.
+ {"rooms": ["!a:b", "!c:d"]} // Process only a subset of room subscriptions.
+ {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.)
+
+ Args:
+ requested_lists: The `lists` from the extension request.
+ requested_room_ids: The `rooms` from the extension request.
+ actual_lists: The actual lists from the Sliding Sync response.
+ actual_room_ids: The actual room subscriptions from the Sliding Sync request.
+ """
+
+        # We only want to include results for rooms that are already in the
+        # Sliding Sync response AND that were requested in the extension request.
+ relevant_room_ids: Set[str] = set()
+
+        # See which rooms from the room subscriptions we should process
+ if requested_room_ids is not None:
+ for room_id in requested_room_ids:
+ # A wildcard means we process all rooms from the room subscriptions
+ if room_id == "*":
+ relevant_room_ids.update(actual_room_ids)
+ break
+
+ if room_id in actual_room_ids:
+ relevant_room_ids.add(room_id)
+
+        # See which rooms from the sliding window lists we should process
+ if requested_lists is not None:
+ for list_key in requested_lists:
+            # Just for typing, since we reuse the variable name in multiple places
+ actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
+
+ # A wildcard means we process rooms from all lists
+ if list_key == "*":
+ for actual_list in actual_lists.values():
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ break
+
+ actual_list = actual_lists.get(list_key)
+ if actual_list is not None:
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ return relevant_room_ids
+
+ @trace
+ async def get_to_device_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+ to_token: StreamToken,
+ ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
+ """Handle to-device extension (MSC3885)
+
+ Args:
+ sync_config: Sync configuration
+ to_device_request: The to-device extension from the request
+ to_token: The point in the stream to sync up to.
+ """
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.requester.device_id
+
+ # Skip if the extension is not enabled
+ if not to_device_request.enabled:
+ return None
+
+ # Check that this request has a valid device ID (not all requests have
+        # to belong to a device, and so device_id may be None)
+ if device_id is None:
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{to_token.to_device_key}",
+ events=[],
+ )
+
+ since_stream_id = 0
+ if to_device_request.since is not None:
+ # We've already validated this is an int.
+ since_stream_id = int(to_device_request.since)
+
+ if to_token.to_device_key < since_stream_id:
+ # The since token is ahead of our current token, so we return an
+ # empty response.
+ logger.warning(
+ "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+ since_stream_id,
+ to_token.to_device_key,
+ )
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=to_device_request.since,
+ events=[],
+ )
+
+ # Delete everything before the given since token, as we know the
+ # device must have received them.
+ deleted = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=since_stream_id,
+ )
+
+ logger.debug(
+ "Deleted %d to-device messages up to %d for %s",
+ deleted,
+ since_stream_id,
+ user_id,
+ )
+
+ messages, stream_id = await self.store.get_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ from_stream_id=since_stream_id,
+ to_stream_id=to_token.to_device_key,
+ limit=min(to_device_request.limit, 100), # Limit to at most 100 events
+ )
+
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{stream_id}",
+ events=messages,
+ )
+
+ @trace
+ async def get_e2ee_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
+ """Handle E2EE device extension (MSC3884)
+
+ Args:
+ sync_config: Sync configuration
+ e2ee_request: The e2ee extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.requester.device_id
+
+ # Skip if the extension is not enabled
+ if not e2ee_request.enabled:
+ return None
+
+ device_list_updates: Optional[DeviceListUpdates] = None
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ device_list_updates = await self.device_handler.get_user_ids_changed(
+ user_id=user_id,
+ from_token=from_token.stream_token,
+ )
+
+ device_one_time_keys_count: Mapping[str, int] = {}
+ device_unused_fallback_key_types: Sequence[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ device_unused_fallback_key_types = (
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return SlidingSyncResult.Extensions.E2eeExtension(
+ device_list_updates=device_list_updates,
+ device_one_time_keys_count=device_one_time_keys_count,
+ device_unused_fallback_key_types=device_unused_fallback_key_types,
+ )
+
+ @trace
+ async def get_account_data_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
+ """Handle Account Data extension (MSC3959)
+
+ Args:
+            sync_config: Sync configuration
+            previous_connection_state: Snapshot of the per-connection state at
+                the start of this request.
+            new_connection_state: A mutable copy of the per-connection state,
+                used to record updates to the state.
+            actual_lists: Sliding window API. A map of list key to list results in the
+                Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+ account_data_request: The account_data extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+
+ # Skip if the extension is not enabled
+ if not account_data_request.enabled:
+ return None
+
+ global_account_data_map: Mapping[str, JsonMapping] = {}
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ global_account_data_map = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+
+ # TODO: This should take into account the `from_token` and `to_token`
+ have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
+ user_id, from_token.stream_token.push_rules_key
+ )
+ if have_push_rules_changed:
+ # TODO: This should take into account the `from_token` and `to_token`
+ global_account_data_map[
+ AccountDataTypes.PUSH_RULES
+ ] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
+ else:
+ # TODO: This should take into account the `to_token`
+ immutable_global_account_data_map = (
+ await self.store.get_global_account_data_for_user(user_id)
+ )
+
+ # Use a `ChainMap` to avoid copying the immutable data from the cache
+ global_account_data_map = ChainMap(
+ {
+ # TODO: This should take into account the `to_token`
+ AccountDataTypes.PUSH_RULES: await self.push_rules_handler.push_rules_for_user(
+ sync_config.user
+ )
+ },
+ # Cast is safe because `ChainMap` only mutates the top-most map,
+ # see https://github.com/python/typeshed/issues/8430
+ cast(
+ MutableMapping[str, JsonMapping], immutable_global_account_data_map
+ ),
+ )
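+
+            # Note on `ChainMap` semantics: lookups hit the first map before
+            # falling back to the second, so the freshly computed push rules
+            # take precedence while the cached account data is neither copied
+            # nor mutated.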
+
+ # Fetch room account data
+ #
+ account_data_by_room_map: MutableMapping[str, Mapping[str, JsonMapping]] = {}
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=account_data_request.lists,
+ requested_room_ids=account_data_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+ if len(relevant_room_ids) > 0:
+ # We need to handle the different cases depending on if we have sent
+ # down account data previously or not, so we split the relevant
+ # rooms up into different collections based on status.
+ live_rooms = set()
+ previously_rooms: Dict[str, int] = {}
+ initial_rooms = set()
+
+ for room_id in relevant_room_ids:
+ if not from_token:
+ initial_rooms.add(room_id)
+ continue
+
+ room_status = previous_connection_state.account_data.have_sent_room(
+ room_id
+ )
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ live_rooms.add(room_id)
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ previously_rooms[room_id] = room_status.last_token
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ initial_rooms.add(room_id)
+ else:
+ assert_never(room_status.status)
+
+ # We fetch all room account data since the from_token. This is so
+ # that we can record which rooms have updates that haven't been sent
+ # down.
+ #
+ # Mapping from room_id to mapping of `type` to `content` of room account
+ # data events.
+ all_updates_since_the_from_token: Mapping[
+ str, Mapping[str, JsonMapping]
+ ] = {}
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ all_updates_since_the_from_token = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+
+ # Add room tags
+ #
+ # TODO: This should take into account the `from_token` and `to_token`
+ tags_by_room = await self.store.get_updated_tags(
+ user_id, from_token.stream_token.account_data_key
+ )
+ for room_id, tags in tags_by_room.items():
+ all_updates_since_the_from_token.setdefault(room_id, {})[
+ AccountDataTypes.TAG
+ ] = {"tags": tags}
+
+ # For live rooms we just get the updates from `all_updates_since_the_from_token`
+ if live_rooms:
+ for room_id in all_updates_since_the_from_token.keys() & live_rooms:
+ account_data_by_room_map[room_id] = (
+ all_updates_since_the_from_token[room_id]
+ )
+
+ # For previously and initial rooms we query each room individually.
+ if previously_rooms or initial_rooms:
+
+ async def handle_previously(room_id: str) -> None:
+ # Either get updates or all account data in the room
+ # depending on if the room state is PREVIOUSLY or NEVER.
+ previous_token = previously_rooms.get(room_id)
+ if previous_token is not None:
+ room_account_data = await (
+ self.store.get_updated_room_account_data_for_user_for_room(
+ user_id=user_id,
+ room_id=room_id,
+ from_stream_id=previous_token,
+ to_stream_id=to_token.account_data_key,
+ )
+ )
+
+ # Add room tags
+ changed = await self.store.has_tags_changed_for_room(
+ user_id=user_id,
+ room_id=room_id,
+ from_stream_id=previous_token,
+ to_stream_id=to_token.account_data_key,
+ )
+ if changed:
+ # XXX: Ideally, this should take into account the `to_token`
+ # and return the set of tags at that time but we don't track
+ # changes to tags so we just have to return all tags for the
+ # room.
+ immutable_tag_map = await self.store.get_tags_for_room(
+ user_id, room_id
+ )
+ room_account_data[AccountDataTypes.TAG] = {
+ "tags": immutable_tag_map
+ }
+
+ # Only add an entry if there were any updates.
+ if room_account_data:
+ account_data_by_room_map[room_id] = room_account_data
+ else:
+ # TODO: This should take into account the `to_token`
+ immutable_room_account_data = (
+ await self.store.get_account_data_for_room(user_id, room_id)
+ )
+
+ # Add room tags
+ #
+ # XXX: Ideally, this should take into account the `to_token`
+ # and return the set of tags at that time but we don't track
+ # changes to tags so we just have to return all tags for the
+ # room.
+ immutable_tag_map = await self.store.get_tags_for_room(
+ user_id, room_id
+ )
+
+ account_data_by_room_map[room_id] = ChainMap(
+ {AccountDataTypes.TAG: {"tags": immutable_tag_map}}
+ if immutable_tag_map
+ else {},
+ # Cast is safe because `ChainMap` only mutates the top-most map,
+ # see https://github.com/python/typeshed/issues/8430
+ cast(
+ MutableMapping[str, JsonMapping],
+ immutable_room_account_data,
+ ),
+ )
+
+ # We handle these rooms concurrently to speed it up.
+ await concurrently_execute(
+ handle_previously,
+ previously_rooms.keys() | initial_rooms,
+ limit=20,
+ )
+
+            # Now record which rooms are now up to date, and which rooms have
+            # pending updates to send.
+ new_connection_state.account_data.record_sent_rooms(previously_rooms.keys())
+ new_connection_state.account_data.record_sent_rooms(initial_rooms)
+ missing_updates = (
+ all_updates_since_the_from_token.keys() - relevant_room_ids
+ )
+ if missing_updates:
+ # If we have missing updates then we must have had a from_token.
+ assert from_token is not None
+
+ new_connection_state.account_data.record_unsent_rooms(
+ missing_updates, from_token.stream_token.account_data_key
+ )
+
+ return SlidingSyncResult.Extensions.AccountDataExtension(
+ global_account_data_map=global_account_data_map,
+ account_data_by_room_map=account_data_by_room_map,
+ )
+
+ @trace
+ async def get_receipts_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
+ receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]:
+ """Handle Receipts extension (MSC3960)
+
+ Args:
+ sync_config: Sync configuration
+            previous_connection_state: Snapshot of the per-connection state at
+                the start of this request.
+            new_connection_state: A mutable copy of the per-connection
+                state, used to record updates to the state.
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+            actual_room_response_map: A map of room ID to room results in the
+                Sliding Sync response.
+            receipts_request: The receipts extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ # Skip if the extension is not enabled
+ if not receipts_request.enabled:
+ return None
+
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=receipts_request.lists,
+ requested_room_ids=receipts_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+
+ room_id_to_receipt_map: Dict[str, JsonMapping] = {}
+ if len(relevant_room_ids) > 0:
+ # We need to handle the different cases depending on if we have sent
+ # down receipts previously or not, so we split the relevant rooms
+ # up into different collections based on status.
+ live_rooms = set()
+ previously_rooms: Dict[str, MultiWriterStreamToken] = {}
+ initial_rooms = set()
+
+ for room_id in relevant_room_ids:
+ if not from_token:
+ initial_rooms.add(room_id)
+ continue
+
+ # If we're sending down the room from scratch again for some
+ # reason, we should always resend the receipts as well
+                # (regardless of whether we've sent them down before). This is to
+ # mimic the behaviour of what happens on initial sync, where you
+ # get a chunk of timeline with all of the corresponding receipts
+ # for the events in the timeline.
+ #
+ # We also resend down receipts when we "expand" the timeline,
+ # (see the "XXX: Odd behavior" in
+ # `synapse.handlers.sliding_sync`).
+ room_result = actual_room_response_map.get(room_id)
+ if room_result is not None:
+ if room_result.initial or room_result.unstable_expanded_timeline:
+ initial_rooms.add(room_id)
+ continue
+
+ room_status = previous_connection_state.receipts.have_sent_room(room_id)
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ live_rooms.add(room_id)
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ previously_rooms[room_id] = room_status.last_token
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ initial_rooms.add(room_id)
+ else:
+ assert_never(room_status.status)
+
+ # The set of receipts that we fetched. Private receipts need to be
+ # filtered out before returning.
+ fetched_receipts = []
+
+ # For live rooms we just fetch all receipts in those rooms since the
+ # `since` token.
+ if live_rooms:
+ assert from_token is not None
+ receipts = await self.store.get_linearized_receipts_for_rooms(
+ room_ids=live_rooms,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ fetched_receipts.extend(receipts)
+
+ # For rooms we've previously sent down, but aren't up to date, we
+ # need to use the from token from the room status.
+ if previously_rooms:
+ # Fetch any missing rooms concurrently.
+
+ async def handle_previously_room(room_id: str) -> None:
+ receipt_token = previously_rooms[room_id]
+ # TODO: Limit the number of receipts we're about to send down
+ # for the room; if it's too many we should TODO
+ previously_receipts = (
+ await self.store.get_linearized_receipts_for_room(
+ room_id=room_id,
+ from_key=receipt_token,
+ to_key=to_token.receipt_key,
+ )
+ )
+ fetched_receipts.extend(previously_receipts)
+
+ await concurrently_execute(
+ handle_previously_room, previously_rooms.keys(), 20
+ )
+
+ if initial_rooms:
+ # We also always send down receipts for the current user.
+ user_receipts = (
+ await self.store.get_linearized_receipts_for_user_in_rooms(
+ user_id=sync_config.user.to_string(),
+ room_ids=initial_rooms,
+ to_key=to_token.receipt_key,
+ )
+ )
+
+ # For rooms we haven't previously sent down, we could send all receipts
+ # from that room but we only want to include receipts for events
+ # in the timeline, to avoid bloating and blowing up the sync response
+ # as the number of users in the room increases (this behavior is part of the spec).
+ initial_rooms_and_event_ids = [
+ (room_id, event.event_id)
+ for room_id in initial_rooms
+ if room_id in actual_room_response_map
+ for event in actual_room_response_map[room_id].timeline_events
+ ]
+ initial_receipts = await self.store.get_linearized_receipts_for_events(
+ room_and_event_ids=initial_rooms_and_event_ids,
+ )
+
+ # Combine the receipts for a room and add them to
+ # `fetched_receipts`
+ for room_id in initial_receipts.keys() | user_receipts.keys():
+ receipt_content = ReceiptInRoom.merge_to_content(
+ list(
+ itertools.chain(
+ initial_receipts.get(room_id, []),
+ user_receipts.get(room_id, []),
+ )
+ )
+ )
+
+ fetched_receipts.append(
+ {
+ "room_id": room_id,
+ "type": EduTypes.RECEIPT,
+ "content": receipt_content,
+ }
+ )
+
+ fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
+ fetched_receipts, sync_config.user.to_string()
+ )
+
+ for receipt in fetched_receipts:
+ # These fields should exist for every receipt
+ room_id = receipt["room_id"]
+ type = receipt["type"]
+ content = receipt["content"]
+
+ room_id_to_receipt_map[room_id] = {"type": type, "content": content}
+
+ # Update the per-connection state to track which rooms we have sent
+ # all the receipts for.
+ new_connection_state.receipts.record_sent_rooms(previously_rooms.keys())
+ new_connection_state.receipts.record_sent_rooms(initial_rooms)
+
+ if from_token:
+ # Now find the set of rooms that may have receipts that we're not sending
+ # down. We only need to check rooms that we have previously returned
+ # receipts for (in `previous_connection_state`) because we only care about
+ # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just
+ # stay pointing at their previous position, so we don't need to waste time
+ # checking those. And since we default to `NEVER`, rooms that were never
+ # sent before don't need to be recorded, as we'll handle them correctly when
+ # they come into range for the first time.
+ rooms_no_receipts = [
+ room_id
+ for room_id, room_status in previous_connection_state.receipts._statuses.items()
+ if room_status.status == HaveSentRoomFlag.LIVE
+ and room_id not in relevant_room_ids
+ ]
+ changed_rooms = await self.store.get_rooms_with_receipts_between(
+ rooms_no_receipts,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ new_connection_state.receipts.record_unsent_rooms(
+ changed_rooms, from_token.stream_token.receipt_key
+ )
+
+ return SlidingSyncResult.Extensions.ReceiptsExtension(
+ room_id_to_receipt_map=room_id_to_receipt_map,
+ )
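For reference, each entry in the `room_id_to_receipt_map` built above is a bundled receipt EDU for a single room. A hypothetical entry (event, user, and room IDs are made up):

```python
room_id_to_receipt_map = {
    "!room:example.org": {
        "type": "m.receipt",
        "content": {
            "$someevent:example.org": {
                "m.read": {
                    "@alice:example.org": {"ts": 1_700_000_000_000},
                },
            },
        },
    },
}
```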
+
+ async def get_typing_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
+ typing_request: SlidingSyncConfig.Extensions.TypingExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]:
+ """Handle Typing Notification extension (MSC3961)
+
+ Args:
+ sync_config: Sync configuration
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+ actual_room_ids: The actual room IDs in the Sliding Sync response.
+ actual_room_response_map: A map of room ID to room results in the
+ Sliding Sync response.
+ typing_request: The typing extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ # Skip if the extension is not enabled
+ if not typing_request.enabled:
+ return None
+
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=typing_request.lists,
+ requested_room_ids=typing_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+
+ room_id_to_typing_map: Dict[str, JsonMapping] = {}
+ if len(relevant_room_ids) > 0:
+ # Note: We don't need to take connection tracking into account for typing
+ # notifications because the client will get anything that is still relevant
+ # and hasn't timed out by the time the room comes back into range. We treat
+ # the gap where the room fell out of range as long enough for any typing
+ # notifications to have timed out (it's not worth the 30 seconds of data we
+ # may have missed).
+ typing_source = self.event_sources.sources.typing
+ typing_notifications, _ = await typing_source.get_new_events(
+ user=sync_config.user,
+ from_key=(from_token.stream_token.typing_key if from_token else 0),
+ to_key=to_token.typing_key,
+ # This is a dummy value and isn't used in the function
+ limit=0,
+ room_ids=relevant_room_ids,
+ is_guest=False,
+ )
+
+ for typing_notification in typing_notifications:
+ # These fields should exist for every typing notification
+ room_id = typing_notification["room_id"]
+ type = typing_notification["type"]
+ content = typing_notification["content"]
+
+ room_id_to_typing_map[room_id] = {"type": type, "content": content}
+
+ return SlidingSyncResult.Extensions.TypingExtension(
+ room_id_to_typing_map=room_id_to_typing_map,
+ )
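Both extensions above scope their work through `find_relevant_room_ids_for_extension`, which intersects the lists/rooms the client asked the extension to cover with what is actually in the response. A simplified sketch of those semantics, assuming the MSC3575-style `"*"` wildcard (the real helper operates on richer list results than the plain room-ID sequences used here):

```python
from typing import AbstractSet, Mapping, Optional, Sequence, Set


def find_relevant_room_ids(
    requested_lists: Optional[Sequence[str]],
    requested_room_ids: Optional[Sequence[str]],
    actual_lists: Mapping[str, Sequence[str]],  # list key -> room IDs in that list
    actual_room_ids: AbstractSet[str],
) -> Set[str]:
    relevant: Set[str] = set()
    if requested_lists is not None:
        # "*" means "every list in the response".
        keys = actual_lists.keys() if "*" in requested_lists else requested_lists
        for list_key in keys:
            relevant.update(actual_lists.get(list_key, []))
    if requested_room_ids is not None:
        if "*" in requested_room_ids:
            relevant.update(actual_room_ids)
        else:
            # Only rooms that actually appear in the response are relevant.
            relevant.update(set(requested_room_ids) & actual_room_ids)
    return relevant
```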
diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py
new file mode 100644
index 0000000000..13e69f18a0
--- /dev/null
+++ b/synapse/handlers/sliding_sync/room_lists.py
@@ -0,0 +1,2304 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+
+import enum
+import logging
+from itertools import chain
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ Dict,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+ cast,
+)
+
+import attr
+from immutabledict import immutabledict
+from typing_extensions import assert_never
+
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventContentFields,
+ EventTypes,
+ Membership,
+)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import StrippedStateEvent
+from synapse.events.utils import parse_stripped_state_event
+from synapse.logging.opentracing import start_active_span, trace
+from synapse.storage.databases.main.state import (
+ ROOM_UNKNOWN_SENTINEL,
+ Sentinel as StateSentinel,
+)
+from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
+from synapse.storage.invite_rule import InviteRule
+from synapse.storage.roommember import (
+ RoomsForUser,
+ RoomsForUserSlidingSync,
+ RoomsForUserStateReset,
+)
+from synapse.types import (
+ MutableStateMap,
+ RoomStreamToken,
+ StateMap,
+ StrCollection,
+ StreamKeyType,
+ StreamToken,
+ UserID,
+)
+from synapse.types.handlers.sliding_sync import (
+ HaveSentRoomFlag,
+ OperationType,
+ PerConnectionState,
+ RoomSyncConfig,
+ SlidingSyncConfig,
+ SlidingSyncResult,
+)
+from synapse.types.state import StateFilter
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup and subsequent type narrowing.
+ UNSET_SENTINEL = object()
+
+
+# Helper definition for the types that we might return. We do this to avoid
+# copying data between types (which can be expensive for many rooms).
+RoomsForUserType = Union[RoomsForUserStateReset, RoomsForUser, RoomsForUserSlidingSync]
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class SlidingSyncInterestedRooms:
+ """The set of rooms and metadata a client is interested in based on their
+ sliding sync request.
+
+ Returned by `compute_interested_rooms`.
+
+ Attributes:
+ lists: A mapping from list name to the list result for the response
+ relevant_room_map: A map from rooms that match the sync request to
+ their room sync config.
+ relevant_rooms_to_send_map: Subset of `relevant_room_map` that
+ includes the rooms that *may* have relevant updates. Rooms not
+ in this map will definitely not have room updates (though
+ extensions may have updates in these rooms).
+ newly_joined_rooms: The set of rooms that were joined in the token range
+ and the user is still joined to at the end of this range.
+ newly_left_rooms: The set of rooms that we left in the token range
+ and are still "leave" at the end of this range.
+ dm_room_ids: The set of rooms the user considers as direct-message (DM) rooms
+ """
+
+ lists: Mapping[str, SlidingSyncResult.SlidingWindowList]
+ relevant_room_map: Mapping[str, RoomSyncConfig]
+ relevant_rooms_to_send_map: Mapping[str, RoomSyncConfig]
+ all_rooms: Set[str]
+ room_membership_for_user_map: Mapping[str, RoomsForUserType]
+
+ newly_joined_rooms: AbstractSet[str]
+ newly_left_rooms: AbstractSet[str]
+ dm_room_ids: AbstractSet[str]
+
+ @staticmethod
+ def empty() -> "SlidingSyncInterestedRooms":
+ return SlidingSyncInterestedRooms(
+ lists={},
+ relevant_room_map={},
+ relevant_rooms_to_send_map={},
+ all_rooms=set(),
+ room_membership_for_user_map={},
+ newly_joined_rooms=set(),
+ newly_left_rooms=set(),
+ dm_room_ids=set(),
+ )
+
+
+def filter_membership_for_sync(
+ *,
+ user_id: str,
+ room_membership_for_user: RoomsForUserType,
+ newly_left: bool,
+) -> bool:
+ """
+ Returns True if the membership event should be included in the sync response,
+ otherwise False.
+
+ Args:
+ user_id: The user ID that the membership applies to
+ room_membership_for_user: Membership information for the user in the room
+ newly_left: Whether the user newly left the room in the current token range
+ """
+
+ membership = room_membership_for_user.membership
+ sender = room_membership_for_user.sender
+
+ # We want to allow everything except rooms the user has left unless `newly_left`
+ # because we want everything that's *still* relevant to the user. We include
+ # `newly_left` rooms because the last event that the user should see is their own
+ # leave event.
+ #
+ # A leave != kick. This logic includes kicks (leave events where the sender is not
+ # the same user).
+ #
+ # When `sender=None`, it means that a state reset happened that removed the user
+ # from the room without a corresponding leave event. We can just remove the rooms
+ # since they are no longer relevant to the user but will still appear if they are
+ # `newly_left`.
+ return (
+ # Anything except leave events
+ membership != Membership.LEAVE
+ # Unless...
+ or newly_left
+ # Allow kicks
+ or (membership == Membership.LEAVE and sender not in (user_id, None))
+ )
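A quick illustration of how `filter_membership_for_sync` behaves, using a hypothetical minimal stand-in for the membership type (the predicate only reads `membership` and `sender`, and the plain strings below match the `Membership` constants):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeMembership:
    """Hypothetical stand-in: only the two fields the predicate reads."""

    membership: str
    sender: Optional[str]  # None means a state reset removed the user


alice = "@alice:example.org"

# Still-joined rooms are always included.
assert filter_membership_for_sync(
    user_id=alice,
    room_membership_for_user=FakeMembership("join", alice),
    newly_left=False,
)

# A self-leave from before the token range is excluded...
assert not filter_membership_for_sync(
    user_id=alice,
    room_membership_for_user=FakeMembership("leave", alice),
    newly_left=False,
)

# ...but included when it happened within the token range (`newly_left`).
assert filter_membership_for_sync(
    user_id=alice,
    room_membership_for_user=FakeMembership("leave", alice),
    newly_left=True,
)

# Kicks (a leave sent by someone else) are always included.
assert filter_membership_for_sync(
    user_id=alice,
    room_membership_for_user=FakeMembership("leave", "@bob:example.org"),
    newly_left=False,
)
```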
+
+
+class SlidingSyncRoomLists:
+ """Handles calculating the room lists from sliding sync requests"""
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+ self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ self.is_mine_id = hs.is_mine_id
+
+ async def compute_interested_rooms(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> SlidingSyncInterestedRooms:
+ """Fetch the set of rooms that match the request"""
+ has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
+ has_room_subscriptions = (
+ sync_config.room_subscriptions is not None
+ and len(sync_config.room_subscriptions) > 0
+ )
+
+ if not has_lists and not has_room_subscriptions:
+ return SlidingSyncInterestedRooms.empty()
+
+ if await self.store.have_finished_sliding_sync_background_jobs():
+ return await self._compute_interested_rooms_new_tables(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ to_token=to_token,
+ from_token=from_token,
+ )
+ else:
+ # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ # foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ # https://github.com/element-hq/synapse/issues/17623)
+ return await self._compute_interested_rooms_fallback(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ @trace
+ async def _compute_interested_rooms_new_tables(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> SlidingSyncInterestedRooms:
+ """Implementation of `compute_interested_rooms` using new sliding sync db tables."""
+ user_id = sync_config.user.to_string()
+
+ # Assemble sliding window lists
+ lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
+ # Keep track of the rooms that we can display and need to fetch more info about
+ relevant_room_map: Dict[str, RoomSyncConfig] = {}
+ # The set of room IDs of all rooms that could appear in any list. These
+ # include rooms that are outside the list ranges.
+ all_rooms: Set[str] = set()
+
+ # Note: this won't include rooms the user has left themselves. We add back
+ # `newly_left` rooms below. This is more efficient than fetching all rooms and
+ # then filtering out the old left rooms.
+ room_membership_for_user_map = (
+ await self.store.get_sliding_sync_rooms_for_user_from_membership_snapshots(
+ user_id
+ )
+ )
+ # To play nice with the rewind logic below, we need to go fetch the rooms the
+ # user has left themselves but only if it changed after the `to_token`.
+ #
+ # If a leave happens *after* the token range, we may still have been joined to
+ # the room (or had some other non-self-leave membership relevant to sync)
+ # before it, so we need to include it in the list of potentially relevant
+ # rooms and apply our rewind logic (outside of this function) to see if it's
+ # actually relevant.
+ #
+ # We do this separately from
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` as those results
+ # are cached and the `to_token` isn't very cache friendly (people are constantly
+ # requesting with new tokens) so we separate it out here.
+ self_leave_room_membership_for_user_map = (
+ await self.store.get_sliding_sync_self_leave_rooms_after_to_token(
+ user_id, to_token
+ )
+ )
+ if self_leave_room_membership_for_user_map:
+ # FIXME: It would be nice to avoid this copy but since
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it
+ # can't return a mutable value like a `dict`. We make the copy to get a
+ # mutable dict that we can change. We try to only make a copy when necessary
+ # (if we actually need to change something) as in most cases, the logic
+ # doesn't need to run.
+ room_membership_for_user_map = dict(room_membership_for_user_map)
+ room_membership_for_user_map.update(self_leave_room_membership_for_user_map)
+
+ # Remove invites from ignored users
+ ignored_users = await self.store.ignored_users(user_id)
+ invite_config = await self.store.get_invite_config_for_user(user_id)
+ if ignored_users:
+ # FIXME: It would be nice to avoid this copy but since
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it
+ # can't return a mutable value like a `dict`. We make the copy to get a
+ # mutable dict that we can change. We try to only make a copy when necessary
+ # (if we actually need to change something) as in most cases, the logic
+ # doesn't need to run.
+ room_membership_for_user_map = dict(room_membership_for_user_map)
+ # Make a copy so we don't run into an error: `dictionary changed size during
+ # iteration`, when we remove items
+ for room_id in list(room_membership_for_user_map.keys()):
+ room_for_user_sliding_sync = room_membership_for_user_map[room_id]
+ if (
+ room_for_user_sliding_sync.membership == Membership.INVITE
+ and room_for_user_sliding_sync.sender
+ and (
+ room_for_user_sliding_sync.sender in ignored_users
+ or invite_config.get_invite_rule(
+ room_for_user_sliding_sync.sender
+ )
+ == InviteRule.IGNORE
+ )
+ ):
+ room_membership_for_user_map.pop(room_id, None)
+
+ (
+ newly_joined_room_ids,
+ newly_left_room_map,
+ ) = await self._get_newly_joined_and_left_rooms(
+ user_id, from_token=from_token, to_token=to_token
+ )
+
+ changes = await self._get_rewind_changes_to_current_membership_to_token(
+ sync_config.user, room_membership_for_user_map, to_token=to_token
+ )
+ if changes:
+ # FIXME: It would be nice to avoid this copy but since
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it
+ # can't return a mutable value like a `dict`. We make the copy to get a
+ # mutable dict that we can change. We try to only make a copy when necessary
+ # (if we actually need to change something) as in most cases, the logic
+ # doesn't need to run.
+ room_membership_for_user_map = dict(room_membership_for_user_map)
+ for room_id, change in changes.items():
+ if change is None:
+ # Remove rooms that the user joined after the `to_token`
+ room_membership_for_user_map.pop(room_id, None)
+ continue
+
+ existing_room = room_membership_for_user_map.get(room_id)
+ if existing_room is not None:
+ # Update room membership events to the point in time of the `to_token`
+ room_for_user = RoomsForUserSlidingSync(
+ room_id=room_id,
+ sender=change.sender,
+ membership=change.membership,
+ event_id=change.event_id,
+ event_pos=change.event_pos,
+ room_version_id=change.room_version_id,
+ # We keep the state of the room though
+ has_known_state=existing_room.has_known_state,
+ room_type=existing_room.room_type,
+ is_encrypted=existing_room.is_encrypted,
+ )
+ if filter_membership_for_sync(
+ user_id=user_id,
+ room_membership_for_user=room_for_user,
+ newly_left=room_id in newly_left_room_map,
+ ):
+ room_membership_for_user_map[room_id] = room_for_user
+ else:
+ room_membership_for_user_map.pop(room_id, None)
+
+ # Add back `newly_left` rooms (rooms left in the from -> to token range).
+ #
+ # We do this because `get_sliding_sync_rooms_for_user_from_membership_snapshots(...)` doesn't include
+ # rooms that the user left themselves as it's more efficient to add them back
+ # here than to fetch all rooms and then filter out the old left rooms. The user
+ # only leaves a room once in a blue moon so this barely needs to run.
+ #
+ missing_newly_left_rooms = (
+ newly_left_room_map.keys() - room_membership_for_user_map.keys()
+ )
+ if missing_newly_left_rooms:
+ # FIXME: It would be nice to avoid this copy but since
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it
+ # can't return a mutable value like a `dict`. We make the copy to get a
+ # mutable dict that we can change. We try to only make a copy when necessary
+ # (if we actually need to change something) as in most cases, the logic
+ # doesn't need to run.
+ room_membership_for_user_map = dict(room_membership_for_user_map)
+ for room_id in missing_newly_left_rooms:
+ newly_left_room_for_user = newly_left_room_map[room_id]
+ # This should be a given
+ assert newly_left_room_for_user.membership == Membership.LEAVE
+
+ # Add back `newly_left` rooms
+ #
+ # Check for membership and state in the Sliding Sync tables as it's just
+ # another membership
+ newly_left_room_for_user_sliding_sync = (
+ await self.store.get_sliding_sync_room_for_user(user_id, room_id)
+ )
+ # If the membership exists, it's just the normal case of the user having
+ # left the room on their own
+ if newly_left_room_for_user_sliding_sync is not None:
+ if filter_membership_for_sync(
+ user_id=user_id,
+ room_membership_for_user=newly_left_room_for_user_sliding_sync,
+ newly_left=room_id in newly_left_room_map,
+ ):
+ room_membership_for_user_map[room_id] = (
+ newly_left_room_for_user_sliding_sync
+ )
+ else:
+ room_membership_for_user_map.pop(room_id, None)
+
+ change = changes.get(room_id)
+ if change is not None:
+ # Update room membership events to the point in time of the `to_token`
+ room_for_user = RoomsForUserSlidingSync(
+ room_id=room_id,
+ sender=change.sender,
+ membership=change.membership,
+ event_id=change.event_id,
+ event_pos=change.event_pos,
+ room_version_id=change.room_version_id,
+ # We keep the state of the room though
+ has_known_state=newly_left_room_for_user_sliding_sync.has_known_state,
+ room_type=newly_left_room_for_user_sliding_sync.room_type,
+ is_encrypted=newly_left_room_for_user_sliding_sync.is_encrypted,
+ )
+ if filter_membership_for_sync(
+ user_id=user_id,
+ room_membership_for_user=room_for_user,
+ newly_left=room_id in newly_left_room_map,
+ ):
+ room_membership_for_user_map[room_id] = room_for_user
+ else:
+ room_membership_for_user_map.pop(room_id, None)
+
+ # If we are `newly_left` from the room but can't find any membership,
+ # then we have been "state reset" out of the room
+ else:
+ # Get the state at the time. We can't read from the Sliding Sync
+ # tables because the user has no membership in the room according to
+ # the state (thanks to the state reset).
+ #
+ # Note: `room_type` never changes, so we can just get current room
+ # type
+ room_type = await self.store.get_room_type(room_id)
+ has_known_state = room_type is not ROOM_UNKNOWN_SENTINEL
+ if isinstance(room_type, StateSentinel):
+ room_type = None
+
+ # Get the encryption status at the time of the token
+ is_encrypted = await self.get_is_encrypted_for_room_at_token(
+ room_id,
+ newly_left_room_for_user.event_pos.to_room_stream_token(),
+ )
+
+ room_for_user = RoomsForUserSlidingSync(
+ room_id=room_id,
+ sender=newly_left_room_for_user.sender,
+ membership=newly_left_room_for_user.membership,
+ event_id=newly_left_room_for_user.event_id,
+ event_pos=newly_left_room_for_user.event_pos,
+ room_version_id=newly_left_room_for_user.room_version_id,
+ has_known_state=has_known_state,
+ room_type=room_type,
+ is_encrypted=is_encrypted,
+ )
+ if filter_membership_for_sync(
+ user_id=user_id,
+ room_membership_for_user=room_for_user,
+ newly_left=room_id in newly_left_room_map,
+ ):
+ room_membership_for_user_map[room_id] = room_for_user
+ else:
+ room_membership_for_user_map.pop(room_id, None)
+
+ dm_room_ids = await self._get_dm_rooms_for_user(user_id)
+
+ if sync_config.lists:
+ sync_room_map = room_membership_for_user_map
+ with start_active_span("assemble_sliding_window_lists"):
+ for list_key, list_config in sync_config.lists.items():
+ # Apply filters
+ filtered_sync_room_map = sync_room_map
+ if list_config.filters is not None:
+ filtered_sync_room_map = await self.filter_rooms_using_tables(
+ user_id,
+ sync_room_map,
+ previous_connection_state,
+ list_config.filters,
+ to_token,
+ dm_room_ids,
+ )
+
+ # Find which rooms are partially stated and may need to be filtered out
+ # depending on the `required_state` requested (see below).
+ partial_state_rooms = await self.store.get_partial_rooms()
+
+ # Since creating the `RoomSyncConfig` takes some work, let's just do it
+ # once.
+ room_sync_config = RoomSyncConfig.from_room_config(list_config)
+
+ # Exclude partially-stated rooms if we must wait for the room to be
+ # fully-stated
+ if room_sync_config.must_await_full_state(self.is_mine_id):
+ filtered_sync_room_map = {
+ room_id: room
+ for room_id, room in filtered_sync_room_map.items()
+ if room_id not in partial_state_rooms
+ }
+
+ all_rooms.update(filtered_sync_room_map)
+
+ ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
+
+ if list_config.ranges:
+ # Optimization: If we are asking for the full range, we don't
+ # need to sort the list.
+ if (
+ # We're looking for a single range that covers the entire list
+ len(list_config.ranges) == 1
+ # Range starts at 0
+ and list_config.ranges[0][0] == 0
+ # And the range extends to the end of the list or more. Each
+ # side is inclusive.
+ and list_config.ranges[0][1]
+ >= len(filtered_sync_room_map) - 1
+ ):
+ sorted_room_info: List[RoomsForUserType] = list(
+ filtered_sync_room_map.values()
+ )
+ else:
+ # Sort the list
+ sorted_room_info = await self.sort_rooms(
+ # Cast is safe because RoomsForUserSlidingSync is part
+ # of the `RoomsForUserType` union; mypy just can't narrow it here.
+ cast(
+ Dict[str, RoomsForUserType], filtered_sync_room_map
+ ),
+ to_token,
+ # We only need to sort the rooms up to the end
+ # of the largest range. Both sides of range are
+ # inclusive so we `+ 1`.
+ limit=max(range[1] + 1 for range in list_config.ranges),
+ )
+
+ for range in list_config.ranges:
+ room_ids_in_list: List[str] = []
+
+ # We're going to loop through the sorted list of rooms starting
+ # at the range start index and keep adding rooms until we fill
+ # up the range or run out of rooms.
+ #
+ # Both sides of range are inclusive so we `+ 1`
+ max_num_rooms = range[1] - range[0] + 1
+ for room_membership in sorted_room_info[range[0] :]:
+ room_id = room_membership.room_id
+
+ if len(room_ids_in_list) >= max_num_rooms:
+ break
+
+ # Take the superset of the `RoomSyncConfig` for each room.
+ #
+ # Update our `relevant_room_map` with the room we're going
+ # to display and need to fetch more info about.
+ existing_room_sync_config = relevant_room_map.get(
+ room_id
+ )
+ if existing_room_sync_config is not None:
+ room_sync_config = existing_room_sync_config.combine_room_sync_config(
+ room_sync_config
+ )
+
+ relevant_room_map[room_id] = room_sync_config
+
+ room_ids_in_list.append(room_id)
+
+ ops.append(
+ SlidingSyncResult.SlidingWindowList.Operation(
+ op=OperationType.SYNC,
+ range=range,
+ room_ids=room_ids_in_list,
+ )
+ )
+
+ lists[list_key] = SlidingSyncResult.SlidingWindowList(
+ count=len(filtered_sync_room_map),
+ ops=ops,
+ )
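The range handling above is worth spelling out: sliding sync ranges are inclusive on both ends, so a request for `[0, 9]` covers ten rooms, and only enough of the list needs sorting to reach the end of the largest range. A small worked example with hypothetical room IDs:

```python
sorted_rooms = [f"!room{i}:example.org" for i in range(25)]

requested_range = (0, 9)
# Both ends are inclusive, so this is 10 rooms, not 9.
max_num_rooms = requested_range[1] - requested_range[0] + 1
room_ids_in_list = sorted_rooms[requested_range[0] : requested_range[0] + max_num_rooms]
assert len(room_ids_in_list) == 10

# Only the first `max(end + 1)` entries across all ranges ever matter, so the
# sort can be limited to that prefix.
sort_limit = max(r[1] + 1 for r in [(0, 9), (10, 19)])
assert sort_limit == 20
```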
+
+ if sync_config.room_subscriptions:
+ with start_active_span("assemble_room_subscriptions"):
+ # FIXME: It would be nice to avoid this copy but since
+ # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it
+ # can't return a mutable value like a `dict`. We make the copy to get a
+ # mutable dict that we can change. We try to only make a copy when necessary
+ # (if we actually need to change something) as in most cases, the logic
+ # doesn't need to run.
+ room_membership_for_user_map = dict(room_membership_for_user_map)
+
+ # Find which rooms are partially stated and may need to be filtered out
+ # depending on the `required_state` requested (see below).
+ partial_state_rooms = await self.store.get_partial_rooms()
+
+ # Fetch any rooms that we have not already fetched from the database.
+ subscription_sliding_sync_rooms = (
+ await self.store.get_sliding_sync_room_for_user_batch(
+ user_id,
+ sync_config.room_subscriptions.keys()
+ - room_membership_for_user_map.keys(),
+ )
+ )
+ room_membership_for_user_map.update(subscription_sliding_sync_rooms)
+
+ for (
+ room_id,
+ room_subscription,
+ ) in sync_config.room_subscriptions.items():
+ # Check if we have a membership for the room, but didn't pull it out
+ # above. This could be e.g. a leave that we don't pull out by
+ # default.
+ current_room_entry = room_membership_for_user_map.get(room_id)
+ if not current_room_entry:
+ # TODO: Handle rooms the user isn't in.
+ continue
+
+ all_rooms.add(room_id)
+
+ # Take the superset of the `RoomSyncConfig` for each room.
+ room_sync_config = RoomSyncConfig.from_room_config(
+ room_subscription
+ )
+
+ # Exclude partially-stated rooms if we must wait for the room to be
+ # fully-stated
+ if room_sync_config.must_await_full_state(self.is_mine_id):
+ if room_id in partial_state_rooms:
+ continue
+
+ # Update our `relevant_room_map` with the room we're going to display
+ # and need to fetch more info about.
+ existing_room_sync_config = relevant_room_map.get(room_id)
+ if existing_room_sync_config is not None:
+ room_sync_config = (
+ existing_room_sync_config.combine_room_sync_config(
+ room_sync_config
+ )
+ )
+
+ relevant_room_map[room_id] = room_sync_config
+
+ # Filtered subset of `relevant_room_map` for rooms that may have updates
+ # (in the event stream)
+ relevant_rooms_to_send_map = await self._filter_relevant_rooms_to_send(
+ previous_connection_state, from_token, relevant_room_map
+ )
+
+ return SlidingSyncInterestedRooms(
+ lists=lists,
+ relevant_room_map=relevant_room_map,
+ relevant_rooms_to_send_map=relevant_rooms_to_send_map,
+ all_rooms=all_rooms,
+ room_membership_for_user_map=room_membership_for_user_map,
+ newly_joined_rooms=newly_joined_room_ids,
+ newly_left_rooms=set(newly_left_room_map),
+ dm_room_ids=dm_room_ids,
+ )
+
+ async def _compute_interested_rooms_fallback(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> SlidingSyncInterestedRooms:
+ """Fallback code when the database background updates haven't completed yet."""
+
+ (
+ room_membership_for_user_map,
+ newly_joined_room_ids,
+ newly_left_room_ids,
+ ) = await self.get_room_membership_for_user_at_to_token(
+ sync_config.user, to_token, from_token
+ )
+
+ dm_room_ids = await self._get_dm_rooms_for_user(sync_config.user.to_string())
+
+ # Assemble sliding window lists
+ lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
+ # Keep track of the rooms that we can display and need to fetch more info about
+ relevant_room_map: Dict[str, RoomSyncConfig] = {}
+ # The set of room IDs of all rooms that could appear in any list. These
+ # include rooms that are outside the list ranges.
+ all_rooms: Set[str] = set()
+
+ if sync_config.lists:
+ with start_active_span("assemble_sliding_window_lists"):
+ sync_room_map = await self.filter_rooms_relevant_for_sync(
+ user=sync_config.user,
+ room_membership_for_user_map=room_membership_for_user_map,
+ newly_left_room_ids=newly_left_room_ids,
+ )
+
+ for list_key, list_config in sync_config.lists.items():
+ # Apply filters
+ filtered_sync_room_map = sync_room_map
+ if list_config.filters is not None:
+ filtered_sync_room_map = await self.filter_rooms(
+ sync_config.user,
+ sync_room_map,
+ previous_connection_state,
+ list_config.filters,
+ to_token,
+ dm_room_ids,
+ )
+
+ # Find which rooms are partially stated and may need to be filtered out
+ # depending on the `required_state` requested (see below).
+ partial_state_rooms = await self.store.get_partial_rooms()
+
+ # Since creating the `RoomSyncConfig` takes some work, let's just do it
+ # once.
+ room_sync_config = RoomSyncConfig.from_room_config(list_config)
+
+ # Exclude partially-stated rooms if we must wait for the room to be
+ # fully-stated
+ if room_sync_config.must_await_full_state(self.is_mine_id):
+ filtered_sync_room_map = {
+ room_id: room
+ for room_id, room in filtered_sync_room_map.items()
+ if room_id not in partial_state_rooms
+ }
+
+ all_rooms.update(filtered_sync_room_map)
+
+ # Sort the list
+ sorted_room_info = await self.sort_rooms(
+ filtered_sync_room_map, to_token
+ )
+
+ ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
+ if list_config.ranges:
+ for range in list_config.ranges:
+ room_ids_in_list: List[str] = []
+
+ # We're going to loop through the sorted list of rooms starting
+ # at the range start index and keep adding rooms until we fill
+ # up the range or run out of rooms.
+ #
+ # Both sides of range are inclusive so we `+ 1`
+ max_num_rooms = range[1] - range[0] + 1
+ for room_membership in sorted_room_info[range[0] :]:
+ room_id = room_membership.room_id
+
+ if len(room_ids_in_list) >= max_num_rooms:
+ break
+
+ # Take the superset of the `RoomSyncConfig` for each room.
+ #
+ # Update our `relevant_room_map` with the room we're going
+ # to display and need to fetch more info about.
+ existing_room_sync_config = relevant_room_map.get(
+ room_id
+ )
+ if existing_room_sync_config is not None:
+ room_sync_config = existing_room_sync_config.combine_room_sync_config(
+ room_sync_config
+ )
+
+ relevant_room_map[room_id] = room_sync_config
+
+ room_ids_in_list.append(room_id)
+
+ ops.append(
+ SlidingSyncResult.SlidingWindowList.Operation(
+ op=OperationType.SYNC,
+ range=range,
+ room_ids=room_ids_in_list,
+ )
+ )
+
+ lists[list_key] = SlidingSyncResult.SlidingWindowList(
+ count=len(sorted_room_info),
+ ops=ops,
+ )
+
+ if sync_config.room_subscriptions:
+ with start_active_span("assemble_room_subscriptions"):
+ # Find which rooms are partially stated and may need to be filtered out
+ # depending on the `required_state` requested (see below).
+ partial_state_rooms = await self.store.get_partial_rooms()
+
+ for (
+ room_id,
+ room_subscription,
+ ) in sync_config.room_subscriptions.items():
+ room_membership_for_user_at_to_token = (
+ await self.check_room_subscription_allowed_for_user(
+ room_id=room_id,
+ room_membership_for_user_map=room_membership_for_user_map,
+ to_token=to_token,
+ )
+ )
+
+ # Skip this room if the user isn't allowed to see it
+ if not room_membership_for_user_at_to_token:
+ continue
+
+ all_rooms.add(room_id)
+
+ room_membership_for_user_map[room_id] = (
+ room_membership_for_user_at_to_token
+ )
+
+ # Take the superset of the `RoomSyncConfig` for each room.
+ room_sync_config = RoomSyncConfig.from_room_config(
+ room_subscription
+ )
+
+ # Exclude partially-stated rooms if we must wait for the room to be
+ # fully-stated
+ if room_sync_config.must_await_full_state(self.is_mine_id):
+ if room_id in partial_state_rooms:
+ continue
+
+ # Update our `relevant_room_map` with the room we're going to display
+ # and need to fetch more info about.
+ existing_room_sync_config = relevant_room_map.get(room_id)
+ if existing_room_sync_config is not None:
+ room_sync_config = (
+ existing_room_sync_config.combine_room_sync_config(
+ room_sync_config
+ )
+ )
+
+ relevant_room_map[room_id] = room_sync_config
+
+ # Filtered subset of `relevant_room_map` for rooms that may have updates
+ # (in the event stream)
+ relevant_rooms_to_send_map = await self._filter_relevant_rooms_to_send(
+ previous_connection_state, from_token, relevant_room_map
+ )
+
+ return SlidingSyncInterestedRooms(
+ lists=lists,
+ relevant_room_map=relevant_room_map,
+ relevant_rooms_to_send_map=relevant_rooms_to_send_map,
+ all_rooms=all_rooms,
+ room_membership_for_user_map=room_membership_for_user_map,
+ newly_joined_rooms=newly_joined_room_ids,
+ newly_left_rooms=newly_left_room_ids,
+ dm_room_ids=dm_room_ids,
+ )
+
+ async def _filter_relevant_rooms_to_send(
+ self,
+ previous_connection_state: PerConnectionState,
+ from_token: Optional[StreamToken],
+ relevant_room_map: Dict[str, RoomSyncConfig],
+ ) -> Dict[str, RoomSyncConfig]:
+ """Filters the `relevant_room_map` down to those rooms that may have
+ updates we need to fetch and return."""
+
+ # Filtered subset of `relevant_room_map` for rooms that may have updates
+ # (in the event stream)
+ relevant_rooms_to_send_map: Dict[str, RoomSyncConfig] = relevant_room_map
+ if relevant_room_map:
+ with start_active_span("filter_relevant_rooms_to_send"):
+ if from_token:
+ rooms_should_send = set()
+
+ # First we check if there are rooms that match a list/room
+ # subscription and have updates we need to send (i.e. either because
+ # we haven't sent the room down, or we have but there are missing
+ # updates).
+ for room_id, room_config in relevant_room_map.items():
+ prev_room_sync_config = (
+ previous_connection_state.room_configs.get(room_id)
+ )
+ if prev_room_sync_config is not None:
+ # Always include rooms whose timeline limit has increased.
+ # (see the "XXX: Odd behavior" described below)
+ if (
+ prev_room_sync_config.timeline_limit
+ < room_config.timeline_limit
+ ):
+ rooms_should_send.add(room_id)
+ continue
+
+ status = previous_connection_state.rooms.have_sent_room(room_id)
+ if (
+ # The room was never sent down before so the client needs to know
+ # about it regardless of any updates.
+ status.status == HaveSentRoomFlag.NEVER
+ # `PREVIOUSLY` literally means the "room was sent down before *AND*
+ # there are updates we haven't sent down" so we already know this
+ # room has updates.
+ or status.status == HaveSentRoomFlag.PREVIOUSLY
+ ):
+ rooms_should_send.add(room_id)
+ elif status.status == HaveSentRoomFlag.LIVE:
+ # We know that we've sent all updates up until `from_token`,
+ # so we just need to check if there have been updates since
+ # then.
+ pass
+ else:
+ assert_never(status.status)
+
+ # We only need to check for new events since any state changes
+ # will also come down as new events.
+ rooms_that_have_updates = (
+ self.store.get_rooms_that_might_have_updates(
+ relevant_room_map.keys(), from_token.room_key
+ )
+ )
+ rooms_should_send.update(rooms_that_have_updates)
+ relevant_rooms_to_send_map = {
+ room_id: room_sync_config
+ for room_id, room_sync_config in relevant_room_map.items()
+ if room_id in rooms_should_send
+ }
+
+ return relevant_rooms_to_send_map
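Condensed, the per-room decision in `_filter_relevant_rooms_to_send` looks like the following, reusing the simplified `Flag` from the earlier sketch; the final `LIVE` case defers to the bulk `get_rooms_that_might_have_updates` check against the event stream:

```python
def room_needs_sending(
    status_flag: Flag,
    prev_timeline_limit: int,
    new_timeline_limit: int,
) -> bool:
    # An increased timeline limit always forces the room down: the client has
    # asked for more history than we previously sent.
    if prev_timeline_limit < new_timeline_limit:
        return True
    if status_flag is Flag.NEVER:
        return True  # the client has never seen this room
    if status_flag is Flag.PREVIOUSLY:
        return True  # sent before, with known unsent updates
    # Flag.LIVE: up to date as of `from_token`; only send if the event stream
    # says there have been new events since then (checked in bulk).
    return False
```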
+
+ @trace
+ async def _get_rewind_changes_to_current_membership_to_token(
+ self,
+ user: UserID,
+ rooms_for_user: Mapping[str, RoomsForUserType],
+ to_token: StreamToken,
+ ) -> Mapping[str, Optional[RoomsForUser]]:
+ """
+ Takes the current set of rooms for a user (retrieved after the given
+ token), and returns the changes needed to "rewind" it to match the set of
+ memberships *at that token* (<= `to_token`).
+
+ Args:
+ user: User to fetch rooms for
+ rooms_for_user: The set of rooms for the user after the `to_token`.
+ to_token: The token to rewind to
+
+ Returns:
+ The changes to apply to rewind the current memberships.
+ """
+ # If the user has never joined any rooms before, we can just return an empty list
+ if not rooms_for_user:
+ return {}
+
+ user_id = user.to_string()
+
+ # Get the `RoomStreamToken` that represents the spot we queried up to when we got
+ # our membership snapshot from `get_rooms_for_local_user_where_membership_is()`.
+ #
+ # First, we need to get the max stream_ordering of each event persister instance
+ # that we queried events from.
+ instance_to_max_stream_ordering_map: Dict[str, int] = {}
+ for room_for_user in rooms_for_user.values():
+ instance_name = room_for_user.event_pos.instance_name
+ stream_ordering = room_for_user.event_pos.stream
+
+ current_instance_max_stream_ordering = (
+ instance_to_max_stream_ordering_map.get(instance_name)
+ )
+ if (
+ current_instance_max_stream_ordering is None
+ or stream_ordering > current_instance_max_stream_ordering
+ ):
+ instance_to_max_stream_ordering_map[instance_name] = stream_ordering
+
+ # Then assemble the `RoomStreamToken`
+ min_stream_pos = min(instance_to_max_stream_ordering_map.values())
+ membership_snapshot_token = RoomStreamToken(
+ # Minimum position in the `instance_map`
+ stream=min_stream_pos,
+ instance_map=immutabledict(
+ {
+ instance_name: stream_pos
+ for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+ if stream_pos > min_stream_pos
+ }
+ ),
+ )
+
+ # Since we fetched the user's room list at some point in time after the
+ # tokens, we need to revert/rewind some membership changes to match the point in
+ # time of the `to_token`. In particular, we need to make these fixups:
+ #
+ # - a) Remove rooms that the user joined after the `to_token`
+ # - b) Update room membership events to the point in time of the `to_token`
+
+ # Fetch membership changes that fall in the range from `to_token` up to
+ # `membership_snapshot_token`
+ #
+ # If our `to_token` is already the same or ahead of the latest room membership
+ # for the user, we don't need to do any "b)" fix-ups and can just straight-up
+ # use the room list from the snapshot as a base (nothing has changed)
+ current_state_delta_membership_changes_after_to_token = []
+ if not membership_snapshot_token.is_before_or_eq(to_token.room_key):
+ current_state_delta_membership_changes_after_to_token = (
+ await self.store.get_current_state_delta_membership_changes_for_user(
+ user_id,
+ from_key=to_token.room_key,
+ to_key=membership_snapshot_token,
+ excluded_room_ids=self.rooms_to_exclude_globally,
+ )
+ )
+
+ if not current_state_delta_membership_changes_after_to_token:
+ # There have been no membership changes, so we can early return.
+ return {}
+
+ # Otherwise we're about to make changes to `rooms_for_user`, so we turn
+ # it into a mutable dict.
+ changes: Dict[str, Optional[RoomsForUser]] = {}
+
+ # Assemble a list of the first membership event after the `to_token` so we can
+ # step backward to the previous membership that would apply to the from/to
+ # range.
+ first_membership_change_by_room_id_after_to_token: Dict[
+ str, CurrentStateDeltaMembership
+ ] = {}
+ for membership_change in current_state_delta_membership_changes_after_to_token:
+ # Only set if we haven't already set it
+ first_membership_change_by_room_id_after_to_token.setdefault(
+ membership_change.room_id, membership_change
+ )
+
+ # Since we fetched a snapshot of the user's room list at some point in time after
+ # the from/to tokens, we need to revert/rewind some membership changes to match
+ # the point in time of the `to_token`.
+ for (
+ room_id,
+ first_membership_change_after_to_token,
+ ) in first_membership_change_by_room_id_after_to_token.items():
+ # a) Remove rooms that the user joined after the `to_token`
+ if first_membership_change_after_to_token.prev_event_id is None:
+ changes[room_id] = None
+ # b) From the first membership event after the `to_token`, step backward to the
+ # previous membership that would apply to the from/to range.
+ else:
+ # We don't expect these fields to be `None` if we have a `prev_event_id`
+ # but we're being defensive since it's possible that the prev event was
+ # culled from the database.
+ if (
+ first_membership_change_after_to_token.prev_event_pos is not None
+ and first_membership_change_after_to_token.prev_membership
+ is not None
+ and first_membership_change_after_to_token.prev_sender is not None
+ ):
+ # We need to know the room version ID, which we can normally
+ # get from the current membership, but if we don't have
+ # that then we need to query the DB.
+ current_membership = rooms_for_user.get(room_id)
+ if current_membership is not None:
+ room_version_id = current_membership.room_version_id
+ else:
+ room_version_id = await self.store.get_room_version_id(room_id)
+
+ changes[room_id] = RoomsForUser(
+ room_id=room_id,
+ event_id=first_membership_change_after_to_token.prev_event_id,
+ event_pos=first_membership_change_after_to_token.prev_event_pos,
+ membership=first_membership_change_after_to_token.prev_membership,
+ sender=first_membership_change_after_to_token.prev_sender,
+ room_version_id=room_version_id,
+ )
+ else:
+ # If we can't find the previous membership event, we shouldn't
+ # include the room in the sync response since we can't determine the
+ # exact membership state and shouldn't rely on the current snapshot.
+ changes[room_id] = None
+
+ return changes
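To make the rewind concrete: for each room, the first membership change strictly after `to_token` tells us, via its `prev_*` fields, what the membership looked like *at* `to_token`. A compressed sketch of the per-room decision; field names follow `CurrentStateDeltaMembership`, while the argument and return types are simplified:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Change:
    """Simplified stand-in for `CurrentStateDeltaMembership`."""

    prev_event_id: Optional[str]
    prev_event_pos: Optional[int]
    prev_membership: Optional[str]
    prev_sender: Optional[str]


def rewind_room(change: Change) -> Optional[Tuple[str, int, str]]:
    """Membership to use at `to_token`, or None to drop the room."""
    if change.prev_event_id is None:
        # The user's first membership is after `to_token`, so at the token the
        # room didn't exist for them yet: remove it.
        return None
    if (
        change.prev_event_pos is None
        or change.prev_membership is None
        or change.prev_sender is None
    ):
        # Defensive: the prev event was culled from the database, so we can't
        # reconstruct the membership at `to_token` and shouldn't guess.
        return None
    # Otherwise, step back to the previous membership event.
    return (change.prev_event_id, change.prev_event_pos, change.prev_membership)
```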
+
+ @trace
+ async def get_room_membership_for_user_at_to_token(
+ self,
+ user: UserID,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
+ """
+ Fetch room IDs that the user has had membership in (the full room list including
+ long-lost left rooms that will be filtered, sorted, and sliced).
+
+ We're looking for rooms where the user has had any sort of membership in the
+ token range (> `from_token` and <= `to_token`)
+
+ In order for bans/kicks to not show up, you need to `/forget` those rooms. This
+ doesn't modify the event itself though and only adds the `forgotten` flag to the
+ `room_memberships` table in Synapse. There isn't a way to tell when a room was
+ forgotten at the moment so we can't factor it into the token range.
+
+ Args:
+ user: User to fetch rooms for
+ to_token: The token to fetch rooms up to.
+ from_token: The point in the stream to sync from.
+
+ Returns:
+ A 3-tuple of:
+ - A dictionary of room IDs that the user has had membership in along with
+ membership information in that room at the time of `to_token`.
+ - Set of newly joined rooms
+ - Set of newly left rooms
+ """
+ user_id = user.to_string()
+
+ # First grab a current snapshot rooms for the user
+ # (also handles forgotten rooms)
+ room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
+ user_id=user_id,
+ # We want to fetch any kind of membership (joined and left rooms) in order
+ # to get the `event_pos` of the latest room membership event for the
+ # user.
+ membership_list=Membership.LIST,
+ excluded_rooms=self.rooms_to_exclude_globally,
+ )
+
+ # We filter out unknown room versions before we try and load any
+ # metadata about the room. They shouldn't go down sync anyway, and their
+ # metadata may be in a broken state.
+ room_for_user_list = [
+ room_for_user
+ for room_for_user in room_for_user_list
+ if room_for_user.room_version_id in KNOWN_ROOM_VERSIONS
+ ]
+
+ # Remove invites from ignored users
+ ignored_users = await self.store.ignored_users(user_id)
+ if ignored_users:
+ room_for_user_list = [
+ room_for_user
+ for room_for_user in room_for_user_list
+ if not (
+ room_for_user.membership == Membership.INVITE
+ and room_for_user.sender in ignored_users
+ )
+ ]
+
+ (
+ newly_joined_room_ids,
+ newly_left_room_map,
+ ) = await self._get_newly_joined_and_left_rooms_fallback(
+ user_id, to_token=to_token, from_token=from_token
+ )
+
+ # If the user has never joined any rooms before, we can just return an empty
+ # list. We also have to check the `newly_left_room_map` in case someone was
+ # state reset out of all of the rooms they were in.
+ if not room_for_user_list and not newly_left_room_map:
+ return {}, set(), set()
+
+ # Since we fetched the user's room list at some point in time after the
+ # tokens, we need to revert/rewind some membership changes to match the point in
+ # time of the `to_token`.
+ rooms_for_user: Dict[str, RoomsForUserType] = {
+ room.room_id: room for room in room_for_user_list
+ }
+ changes = await self._get_rewind_changes_to_current_membership_to_token(
+ user, rooms_for_user, to_token
+ )
+ for room_id, change_room_for_user in changes.items():
+ if change_room_for_user is None:
+ rooms_for_user.pop(room_id, None)
+ else:
+ rooms_for_user[room_id] = change_room_for_user
+
+ # Ensure we have entries for rooms that the user has been "state reset"
+ # out of. These are rooms that appear in the `newly_left_room_map` but
+ # aren't in the `rooms_for_user` map.
+ for room_id, newly_left_room_for_user in newly_left_room_map.items():
+ # If we already know about the room, it's not a state reset
+ if room_id in rooms_for_user:
+ continue
+
+ # This should be true if it's a state reset
+ assert newly_left_room_for_user.membership == Membership.LEAVE
+ assert newly_left_room_for_user.event_id is None
+ assert newly_left_room_for_user.sender is None
+
+ rooms_for_user[room_id] = newly_left_room_for_user
+
+ return rooms_for_user, newly_joined_room_ids, set(newly_left_room_map)
+
+ @trace
+ async def _get_newly_joined_and_left_rooms(
+ self,
+ user_id: str,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]:
+ """Fetch the sets of rooms that the user newly joined or left in the
+ given token range.
+
+ Note: there may be rooms in the newly left rooms where the user was
+ "state reset" out of the room, and so that room would not be part of the
+ "current memberships" of the user.
+
+ Returns:
+ A 2-tuple of newly joined room IDs and a map of newly_left room
+ IDs to the `RoomsForUserStateReset` entry.
+
+ We're using `RoomsForUserStateReset` but that doesn't necessarily mean the
+ user was state reset out of the rooms. It's just that the `event_id`/`sender`
+ are optional, so we can't tell the difference between the server leaving the
+ room (because the user was the last participant and left) and the user being
+ state reset out of the room. To actually check for a state reset, you need to
+ check whether a membership still exists in the room.
+ """
+
+ newly_joined_room_ids: Set[str] = set()
+ newly_left_room_map: Dict[str, RoomsForUserStateReset] = {}
+
+ if not from_token:
+ return newly_joined_room_ids, newly_left_room_map
+
+ changes = await self.store.get_sliding_sync_membership_changes(
+ user_id,
+ from_key=from_token.room_key,
+ to_key=to_token.room_key,
+ excluded_room_ids=set(self.rooms_to_exclude_globally),
+ )
+
+ for room_id, entry in changes.items():
+ if entry.membership == Membership.JOIN:
+ newly_joined_room_ids.add(room_id)
+ elif entry.membership == Membership.LEAVE:
+ newly_left_room_map[room_id] = entry
+
+ return newly_joined_room_ids, newly_left_room_map
+
+ @trace
+ async def _get_newly_joined_and_left_rooms_fallback(
+ self,
+ user_id: str,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]:
+ """Fetch the sets of rooms that the user newly joined or left in the
+ given token range.
+
+ Note: there may be rooms in the newly left rooms where the user was
+ "state reset" out of the room, and so that room would not be part of the
+ "current memberships" of the user.
+
+ Returns:
+ A 2-tuple of newly joined room IDs and a map of newly_left room
+ IDs to the `RoomsForUserStateReset` entry.
+
+ We're using `RoomsForUserStateReset` but that doesn't necessarily mean the
+ user was state reset out of the rooms. It's just that the `event_id`/`sender`
+ are optional, so we can't tell the difference between the server leaving the
+ room (because the user was the last participant and left) and the user being
+ state reset out of the room. To actually check for a state reset, you need to
+ check whether a membership still exists in the room.
+ """
+ newly_joined_room_ids: Set[str] = set()
+ newly_left_room_map: Dict[str, RoomsForUserStateReset] = {}
+
+ # We need to figure out two things:
+ #
+ # - 1) Which rooms are `newly_left` (> `from_token` and <= `to_token`)
+ # - 2) Which rooms are `newly_joined` (> `from_token` and <= `to_token`)
+
+ # 1) Fetch membership changes that fall in the range from `from_token` up to `to_token`
+ current_state_delta_membership_changes_in_from_to_range = []
+ if from_token:
+ current_state_delta_membership_changes_in_from_to_range = (
+ await self.store.get_current_state_delta_membership_changes_for_user(
+ user_id,
+ from_key=from_token.room_key,
+ to_key=to_token.room_key,
+ excluded_room_ids=self.rooms_to_exclude_globally,
+ )
+ )
+
+ # 1) Assemble a map of the last membership event per room in the token range.
+ # Someone could have left and joined multiple times during the given range but
+ # we only care about the end result, so we grab the last one.
+ last_membership_change_by_room_id_in_from_to_range: Dict[
+ str, CurrentStateDeltaMembership
+ ] = {}
+ # We also want to assemble a list of the first membership events during the token
+ # range so we can step backward to the previous membership that would apply to
+ # before the token range to see if we have `newly_joined` the room.
+ first_membership_change_by_room_id_in_from_to_range: Dict[
+ str, CurrentStateDeltaMembership
+ ] = {}
+ # Keep track if the room has a non-join event in the token range so we can later
+ # tell if it was a `newly_joined` room. If the last membership event in the
+ # token range is a join and there is also some non-join in the range, we know
+ # they `newly_joined`.
+ has_non_join_event_by_room_id_in_from_to_range: Dict[str, bool] = {}
+ for (
+ membership_change
+ ) in current_state_delta_membership_changes_in_from_to_range:
+ room_id = membership_change.room_id
+
+ last_membership_change_by_room_id_in_from_to_range[room_id] = (
+ membership_change
+ )
+ # Only set if we haven't already set it
+ first_membership_change_by_room_id_in_from_to_range.setdefault(
+ room_id, membership_change
+ )
+
+ if membership_change.membership != Membership.JOIN:
+ has_non_join_event_by_room_id_in_from_to_range[room_id] = True
+
+ # 1) Fixup
+ #
+ # 2) We also want to assemble a list of possibly newly joined rooms. Someone
+ # could have left and joined multiple times during the given range but we only
+ # care about whether they are joined at the end of the token range, so we are
+ # working with the last membership event in the token range.
+ possibly_newly_joined_room_ids = set()
+ for (
+ last_membership_change_in_from_to_range
+ ) in last_membership_change_by_room_id_in_from_to_range.values():
+ room_id = last_membership_change_in_from_to_range.room_id
+
+ # 2)
+ if last_membership_change_in_from_to_range.membership == Membership.JOIN:
+ possibly_newly_joined_room_ids.add(room_id)
+
+ # 1) Figure out newly_left rooms (> `from_token` and <= `to_token`).
+ if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
+ # 1) Mark this room as `newly_left`
+ newly_left_room_map[room_id] = RoomsForUserStateReset(
+ room_id=room_id,
+ sender=last_membership_change_in_from_to_range.sender,
+ membership=Membership.LEAVE,
+ event_id=last_membership_change_in_from_to_range.event_id,
+ event_pos=last_membership_change_in_from_to_range.event_pos,
+ room_version_id=await self.store.get_room_version_id(room_id),
+ )
+
+ # 2) Figure out `newly_joined`
+ for room_id in possibly_newly_joined_room_ids:
+ has_non_join_in_from_to_range = (
+ has_non_join_event_by_room_id_in_from_to_range.get(room_id, False)
+ )
+ # If the last membership event in the token range is a join and there is
+ # also some non-join in the range, we know they `newly_joined`.
+ if has_non_join_in_from_to_range:
+ # We found a `newly_joined` room (we left and joined within the token range)
+ newly_joined_room_ids.add(room_id)
+ else:
+ prev_event_id = first_membership_change_by_room_id_in_from_to_range[
+ room_id
+ ].prev_event_id
+ prev_membership = first_membership_change_by_room_id_in_from_to_range[
+ room_id
+ ].prev_membership
+
+ if prev_event_id is None:
+ # We found a `newly_joined` room (we are joining the room for the
+ # first time within the token range)
+ newly_joined_room_ids.add(room_id)
+ # Last resort, we need to step back to the previous membership event
+ # just before the token range to see if we're joined then or not.
+ elif prev_membership != Membership.JOIN:
+ # We found a `newly_joined` room (we left before the token range
+ # and joined within the token range)
+ newly_joined_room_ids.add(room_id)
+
+ return newly_joined_room_ids, newly_left_room_map
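Condensed into one predicate, the `newly_joined` determination above has three cases. A sketch under the same simplifications (plain membership strings matching the `Membership` constants; the booleans stand in for lookups against the maps built above):

```python
from typing import Optional


def is_newly_joined(
    last_membership_in_range: str,  # membership at the end of the token range
    has_non_join_in_range: bool,  # any non-join event inside the range?
    prev_event_id_before_range: Optional[str],
    prev_membership_before_range: Optional[str],
) -> bool:
    if last_membership_in_range != "join":
        return False  # not joined at the end of the range
    if has_non_join_in_range:
        return True  # left (or was kicked, etc.) and re-joined within the range
    if prev_event_id_before_range is None:
        return True  # first-ever membership falls inside the range
    # Last resort: joined inside the range from a non-join state just before it.
    return prev_membership_before_range != "join"
```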
+
+ @trace
+ async def _get_dm_rooms_for_user(
+ self,
+ user_id: str,
+ ) -> AbstractSet[str]:
+ """Get the set of DM rooms for the user."""
+
+ # We're using global account data (`m.direct`) instead of checking for
+ # `is_direct` on membership events because that property only appears for
+ # the invitee membership event (doesn't show up for the inviter).
+ #
+ # We're unable to take `to_token` into account for global account data since
+ # we only keep track of the latest account data for the user.
+ dm_map = await self.store.get_global_account_data_by_type_for_user(
+ user_id, AccountDataTypes.DIRECT
+ )
+
+ # Flatten out the map. Account data is set by the client so it needs to be
+ # scrutinized.
+ dm_room_id_set = set()
+ if isinstance(dm_map, dict):
+ for room_ids in dm_map.values():
+ # Account data should be a list of room IDs. Ignore anything else
+ if isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ dm_room_id_set.add(room_id)
+
+ return dm_room_id_set
+
+ @trace
+ async def filter_rooms_relevant_for_sync(
+ self,
+ user: UserID,
+ room_membership_for_user_map: Dict[str, RoomsForUserType],
+ newly_left_room_ids: AbstractSet[str],
+ ) -> Dict[str, RoomsForUserType]:
+ """
+ Filter room IDs that should/can be listed for this user in the sync response (the
+ full room list that will be further filtered, sorted, and sliced).
+
+ We're looking for rooms where the user has the following state in the token
+ range (> `from_token` and <= `to_token`):
+
+ - `invite`, `join`, `knock`, `ban` membership events
+ - Kicks (`leave` membership events where `sender` is different from the
+ `user_id`/`state_key`)
+ - `newly_left` (rooms that were left during the given token range)
+ - In order for bans/kicks to not show up in sync, you need to `/forget` those
+ rooms. This doesn't modify the event itself though and only adds the
+ `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way
+ to tell when a room was forgotten at the moment so we can't factor it into the
+ from/to range.
+
+ Args:
+ user: User that is syncing
+ room_membership_for_user_map: Room membership for the user
+ newly_left_room_ids: The set of room IDs we have newly left
+
+ Returns:
+ A dictionary of room IDs that should be listed in the sync response along
+ with membership information in that room at the time of `to_token`.
+ """
+ user_id = user.to_string()
+
+ # Filter rooms to only what we're interested to sync with
+ filtered_sync_room_map = {
+ room_id: room_membership_for_user
+ for room_id, room_membership_for_user in room_membership_for_user_map.items()
+ if filter_membership_for_sync(
+ user_id=user_id,
+ room_membership_for_user=room_membership_for_user,
+ newly_left=room_id in newly_left_room_ids,
+ )
+ }
+
+ return filtered_sync_room_map
+
+ async def check_room_subscription_allowed_for_user(
+ self,
+ room_id: str,
+ room_membership_for_user_map: Dict[str, RoomsForUserType],
+ to_token: StreamToken,
+ ) -> Optional[RoomsForUserType]:
+ """
+ Check whether the user is allowed to see the room based on whether they have
+ ever had membership in the room or if the room is `world_readable`.
+
+ Similar to `check_user_in_room_or_world_readable(...)`
+
+ Args:
+ room_id: Room to check
+ room_membership_for_user_map: Room membership for the user at the time of
+ the `to_token` (<= `to_token`).
+ to_token: The token to fetch rooms up to.
+
+ Returns:
+ The room membership for the user if they are allowed to subscribe to the
+ room else `None`.
+ """
+
+ # We can first check if they are already allowed to see the room based
+ # on our previous work to assemble the `room_membership_for_user_map`.
+ #
+ # If they have had any membership in the room over time (up to the `to_token`),
+ # let them subscribe and see what they can.
+ existing_membership_for_user = room_membership_for_user_map.get(room_id)
+ if existing_membership_for_user is not None:
+ return existing_membership_for_user
+
+ # TODO: Handle `world_readable` rooms
+ return None
+
+ # If the room is `world_readable`, it doesn't matter whether they can join,
+ # everyone can see the room.
+ # not_in_room_membership_for_user = _RoomMembershipForUser(
+ # room_id=room_id,
+ # event_id=None,
+ # event_pos=None,
+ # membership=None,
+ # sender=None,
+ # newly_joined=False,
+ # newly_left=False,
+ # is_dm=False,
+ # )
+ # room_state = await self.get_current_state_at(
+ # room_id=room_id,
+ # room_membership_for_user_at_to_token=not_in_room_membership_for_user,
+ # state_filter=StateFilter.from_types(
+ # [(EventTypes.RoomHistoryVisibility, "")]
+ # ),
+ # to_token=to_token,
+ # )
+
+ # visibility_event = room_state.get((EventTypes.RoomHistoryVisibility, ""))
+ # if (
+ # visibility_event is not None
+ # and visibility_event.content.get("history_visibility")
+ # == HistoryVisibility.WORLD_READABLE
+ # ):
+ # return not_in_room_membership_for_user
+
+ # return None
+
+ @trace
+ async def _bulk_get_stripped_state_for_rooms_from_sync_room_map(
+ self,
+ room_ids: StrCollection,
+ sync_room_map: Dict[str, RoomsForUserType],
+ ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]:
+ """
+ Fetch stripped state for a list of room IDs. Stripped state is only
+ applicable to invite/knock rooms. Other rooms will have `None` as their
+ stripped state.
+
+ For invite rooms, we pull from `unsigned.invite_room_state`.
+ For knock rooms, we pull from `unsigned.knock_room_state`.
+
+ Args:
+ room_ids: Room IDs to fetch stripped state for
+ sync_room_map: Dictionary of room IDs to membership information in the
+ room at the time of `to_token`.
+
+ Returns:
+ Mapping from room_id to mapping of (type, state_key) to stripped state
+ event.
+ """
+ room_id_to_stripped_state_map: Dict[
+ str, Optional[StateMap[StrippedStateEvent]]
+ ] = {}
+
+ # Fetch what we haven't before
+ room_ids_to_fetch = [
+ room_id
+ for room_id in room_ids
+ if room_id not in room_id_to_stripped_state_map
+ ]
+
+ # Gather a list of event IDs we can grab stripped state from
+ invite_or_knock_event_ids: List[str] = []
+ for room_id in room_ids_to_fetch:
+ if sync_room_map[room_id].membership in (
+ Membership.INVITE,
+ Membership.KNOCK,
+ ):
+ event_id = sync_room_map[room_id].event_id
+ # If this is an invite/knock then there should be an event_id
+ assert event_id is not None
+ invite_or_knock_event_ids.append(event_id)
+ else:
+ room_id_to_stripped_state_map[room_id] = None
+
+ invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids)
+ for invite_or_knock_event in invite_or_knock_events.values():
+ room_id = invite_or_knock_event.room_id
+ membership = invite_or_knock_event.membership
+
+ raw_stripped_state_events = None
+ if membership == Membership.INVITE:
+ invite_room_state = invite_or_knock_event.unsigned.get(
+ "invite_room_state"
+ )
+ raw_stripped_state_events = invite_room_state
+ elif membership == Membership.KNOCK:
+ knock_room_state = invite_or_knock_event.unsigned.get(
+ "knock_room_state"
+ )
+ raw_stripped_state_events = knock_room_state
+ else:
+ raise AssertionError(
+ f"Unexpected membership {membership} (this is a problem with Synapse itself)"
+ )
+
+ stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None
+ # Scrutinize unsigned things. `raw_stripped_state_events` should be a list
+ # of stripped events
+ if raw_stripped_state_events is not None:
+ stripped_state_map = {}
+ if isinstance(raw_stripped_state_events, list):
+ for raw_stripped_event in raw_stripped_state_events:
+ stripped_state_event = parse_stripped_state_event(
+ raw_stripped_event
+ )
+ if stripped_state_event is not None:
+ stripped_state_map[
+ (
+ stripped_state_event.type,
+ stripped_state_event.state_key,
+ )
+ ] = stripped_state_event
+
+ room_id_to_stripped_state_map[room_id] = stripped_state_map
+
+ return room_id_to_stripped_state_map
+
+ @trace
+ async def _bulk_get_partial_current_state_content_for_rooms(
+ self,
+ content_type: Literal[
+ # `content.type` from `EventTypes.Create`
+ "room_type",
+ # `content.algorithm` from `EventTypes.RoomEncryption`
+ "room_encryption",
+ ],
+ room_ids: Set[str],
+ sync_room_map: Dict[str, RoomsForUserType],
+ to_token: StreamToken,
+ room_id_to_stripped_state_map: Dict[
+ str, Optional[StateMap[StrippedStateEvent]]
+ ],
+ ) -> Mapping[str, Union[Optional[str], StateSentinel]]:
+ """
+ Get the given state event content for a list of rooms. First we check the
+ current state of the room, then fall back to stripped state if available,
+ then historical state.
+
+ Args:
+ content_type: Which content to grab
+ room_ids: Room IDs to fetch the given content field for.
+ sync_room_map: Dictionary of room IDs to membership information in the
+ room at the time of `to_token`.
+ to_token: We filter based on the state of the room at this token
+ room_id_to_stripped_state_map: This does not need to be filled in before
+ calling this function. Mapping from room_id to mapping of (type, state_key)
+ to stripped state event. Modified in place when we fetch new rooms so we can
+ save work next time this function is called.
+
+ Returns:
+ A mapping from room ID to the state event content if the room has
+ the given state event (event_type, ""), otherwise `None`. Rooms unknown to
+ this server will return `ROOM_UNKNOWN_SENTINEL`.
+ """
+ room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {}
+
+ # As a bulk shortcut, use the current state if the server is participating in the
+ # room (meaning we have current state). Ideally, for leave/ban rooms, we would
+ # want the state at the time of the membership instead of current state to not
+ # leak anything, but we consider the create/encryption stripped state events to
+ # not be a secret given they are often set at the start of the room and they are
+ # normally handed out on invite/knock.
+ #
+ # Be mindful to only use this for non-sensitive details. For example, even
+ # though the room name/avatar/topic are also stripped state, it seems a lot
+ # more sensitive to leak their current state values.
+ #
+ # Since the store functions used below are cached, we need to make a
+ # mutable copy via `dict(...)`.
+ event_type = ""
+ event_content_field = ""
+ if content_type == "room_type":
+ event_type = EventTypes.Create
+ event_content_field = EventContentFields.ROOM_TYPE
+ room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids))
+ elif content_type == "room_encryption":
+ event_type = EventTypes.RoomEncryption
+ event_content_field = EventContentFields.ENCRYPTION_ALGORITHM
+ room_id_to_content = dict(
+ await self.store.bulk_get_room_encryption(room_ids)
+ )
+ else:
+ assert_never(content_type)
+
+ room_ids_with_results = [
+ room_id
+ for room_id, content_field in room_id_to_content.items()
+ if content_field is not ROOM_UNKNOWN_SENTINEL
+ ]
+
+ # We might not have current room state for remote invite/knocks if we are
+ # the first person on our server to see the room. The best we can do is look
+ # in the optional stripped state from the invite/knock event.
+ room_ids_without_results = room_ids.difference(
+ chain(
+ room_ids_with_results,
+ [
+ room_id
+ for room_id, stripped_state_map in room_id_to_stripped_state_map.items()
+ if stripped_state_map is not None
+ ],
+ )
+ )
+ room_id_to_stripped_state_map.update(
+ await self._bulk_get_stripped_state_for_rooms_from_sync_room_map(
+ room_ids_without_results, sync_room_map
+ )
+ )
+
+ # Update our `room_id_to_content` map based on the stripped state
+ # (applies to invite/knock rooms)
+ room_ids_without_stripped_state: Set[str] = set()
+ for room_id in room_ids_without_results:
+ stripped_state_map = room_id_to_stripped_state_map.get(
+ room_id, Sentinel.UNSET_SENTINEL
+ )
+ assert stripped_state_map is not Sentinel.UNSET_SENTINEL, (
+ f"Stripped state left unset for room {room_id}. "
+ + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` "
+ + "with that room_id. (this is a problem with Synapse itself)"
+ )
+
+ # If there is some stripped state, we assume the remote server passed *all*
+ # of the potential stripped state events for the room.
+ if stripped_state_map is not None:
+ create_stripped_event = stripped_state_map.get((EventTypes.Create, ""))
+ stripped_event = stripped_state_map.get((event_type, ""))
+ # Sanity check that we at least have the create event
+ if create_stripped_event is not None:
+ if stripped_event is not None:
+ room_id_to_content[room_id] = stripped_event.content.get(
+ event_content_field
+ )
+ else:
+ # Didn't see the state event we're looking for in the stripped
+ # state so we can assume the relevant content field is `None`.
+ room_id_to_content[room_id] = None
+ else:
+ room_ids_without_stripped_state.add(room_id)
+
+ # Last resort, we might not have current room state for rooms that the
+ # server has left (no one local is in the room) but we can look at the
+ # historical state.
+ #
+ # Update our `room_id_to_content` map based on the state at the time of
+ # the membership event.
+ for room_id in room_ids_without_stripped_state:
+ # TODO: It would be nice to look this up in a bulk way (N+1 queries)
+ #
+ # TODO: `get_state_at(...)` doesn't take into account the "current state".
+ room_state = await self.storage_controllers.state.get_state_at(
+ room_id=room_id,
+ stream_position=to_token.copy_and_replace(
+ StreamKeyType.ROOM,
+ sync_room_map[room_id].event_pos.to_room_stream_token(),
+ ),
+ state_filter=StateFilter.from_types(
+ [
+ (EventTypes.Create, ""),
+ (event_type, ""),
+ ]
+ ),
+ # Partially-stated rooms should have all state events except for
+ # remote membership events so we don't need to wait at all because
+ # we only want the create event and some non-member event.
+ await_full_state=False,
+ )
+ # We can use the create event as a canary to tell whether the server has
+ # seen the room before
+ create_event = room_state.get((EventTypes.Create, ""))
+ state_event = room_state.get((event_type, ""))
+
+ if create_event is None:
+ # Skip for unknown rooms
+ continue
+
+ if state_event is not None:
+ room_id_to_content[room_id] = state_event.content.get(
+ event_content_field
+ )
+ else:
+ # Didn't see the state event we're looking for in the historical
+ # state so we can assume the relevant content field is `None`.
+ room_id_to_content[room_id] = None
+
+ return room_id_to_content
+
+ @trace
+ async def filter_rooms(
+ self,
+ user: UserID,
+ sync_room_map: Dict[str, RoomsForUserType],
+ previous_connection_state: PerConnectionState,
+ filters: SlidingSyncConfig.SlidingSyncList.Filters,
+ to_token: StreamToken,
+ dm_room_ids: AbstractSet[str],
+ ) -> Dict[str, RoomsForUserType]:
+ """
+ Filter rooms based on the sync request.
+
+ Args:
+ user: User to filter rooms for
+ sync_room_map: Dictionary of room IDs to membership information in the
+ room at the time of `to_token`.
+ previous_connection_state: The per-connection state from the previous request
+ filters: Filters to apply
+ to_token: We filter based on the state of the room at this token
+ dm_room_ids: Set of room IDs that are DMs for the user
+
+ Returns:
+ A filtered dictionary of room IDs along with membership information in the
+ room at the time of `to_token`.
+ """
+ user_id = user.to_string()
+
+ room_id_to_stripped_state_map: Dict[
+ str, Optional[StateMap[StrippedStateEvent]]
+ ] = {}
+
+ filtered_room_id_set = set(sync_room_map.keys())
+
+ # Filter for Direct-Message (DM) rooms
+ if filters.is_dm is not None:
+ with start_active_span("filters.is_dm"):
+ if filters.is_dm:
+ # Only DM rooms please
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ if room_id in dm_room_ids
+ }
+ else:
+ # Only non-DM rooms please
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ if room_id not in dm_room_ids
+ }
+
+ if filters.spaces is not None:
+ with start_active_span("filters.spaces"):
+ raise NotImplementedError()
+
+ # Filter for encrypted rooms
+ if filters.is_encrypted is not None:
+ with start_active_span("filters.is_encrypted"):
+ room_id_to_encryption = (
+ await self._bulk_get_partial_current_state_content_for_rooms(
+ content_type="room_encryption",
+ room_ids=filtered_room_id_set,
+ to_token=to_token,
+ sync_room_map=sync_room_map,
+ room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+ )
+ )
+
+ # Make a copy so we don't run into a `Set changed size during
+ # iteration` error when we filter out and remove items
+ for room_id in filtered_room_id_set.copy():
+ encryption = room_id_to_encryption.get(
+ room_id, ROOM_UNKNOWN_SENTINEL
+ )
+
+ # Just remove rooms if we can't determine their encryption status
+ if encryption is ROOM_UNKNOWN_SENTINEL:
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ # If we're looking for encrypted rooms, filter out rooms that are not
+ # encrypted and vice versa
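+ # (Equivalently: the room is removed whenever
+ # `is_encrypted != filters.is_encrypted`.)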
+ is_encrypted = encryption is not None
+ if (filters.is_encrypted and not is_encrypted) or (
+ not filters.is_encrypted and is_encrypted
+ ):
+ filtered_room_id_set.remove(room_id)
+
+ # Filter for rooms that the user has been invited to
+ if filters.is_invite is not None:
+ with start_active_span("filters.is_invite"):
+ # Make a copy so we don't run into a `Set changed size during
+ # iteration` error when we filter out and remove items
+ for room_id in filtered_room_id_set.copy():
+ room_for_user = sync_room_map[room_id]
+ # If we're looking for invite rooms, filter out rooms that the user is
+ # not invited to and vice versa
+ if (
+ filters.is_invite
+ and room_for_user.membership != Membership.INVITE
+ ) or (
+ not filters.is_invite
+ and room_for_user.membership == Membership.INVITE
+ ):
+ filtered_room_id_set.remove(room_id)
+
+ # Filter by room type (space vs room, etc). A room must match one of the types
+ # provided in the list. `None` is a valid type for rooms which do not have a
+ # room type.
+ if filters.room_types is not None or filters.not_room_types is not None:
+ with start_active_span("filters.room_types"):
+ room_id_to_type = (
+ await self._bulk_get_partial_current_state_content_for_rooms(
+ content_type="room_type",
+ room_ids=filtered_room_id_set,
+ to_token=to_token,
+ sync_room_map=sync_room_map,
+ room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+ )
+ )
+
+ # Make a copy so we don't run into a `Set changed size during
+ # iteration` error when we filter out and remove items
+ for room_id in filtered_room_id_set.copy():
+ room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL)
+
+ # Just remove rooms if we can't determine their type
+ if room_type is ROOM_UNKNOWN_SENTINEL:
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ if (
+ filters.room_types is not None
+ and room_type not in filters.room_types
+ ):
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ if (
+ filters.not_room_types is not None
+ and room_type in filters.not_room_types
+ ):
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ if filters.room_name_like is not None:
+ with start_active_span("filters.room_name_like"):
+ # TODO: The room name is a bit more sensitive to leak than the
+ # create/encryption event. Maybe we should consider a better way to fetch
+ # historical state before implementing this.
+ #
+ # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms(
+ # content_type="room_name",
+ # room_ids=filtered_room_id_set,
+ # to_token=to_token,
+ # sync_room_map=sync_room_map,
+ # room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+ # )
+ raise NotImplementedError()
+
+ # Filter by room tags according to the user's account data
+ if filters.tags is not None or filters.not_tags is not None:
+ with start_active_span("filters.tags"):
+ # Fetch the user tags for their rooms
+ room_tags = await self.store.get_tags_for_user(user_id)
+ room_id_to_tag_name_set: Dict[str, Set[str]] = {
+ room_id: set(tags.keys()) for room_id, tags in room_tags.items()
+ }
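+ # For illustration (assumed shape): `room_tags` maps room ID -> tag name ->
+ # tag content, e.g. {"!abc:example.org": {"m.favourite": {"order": 0.5}}},
+ # which flattens here to {"!abc:example.org": {"m.favourite"}}.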
+
+ if filters.tags is not None:
+ tags_set = set(filters.tags)
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ # Remove rooms that don't have one of the tags in the filter
+ if room_id_to_tag_name_set.get(room_id, set()).intersection(
+ tags_set
+ )
+ }
+
+ if filters.not_tags is not None:
+ not_tags_set = set(filters.not_tags)
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ # Remove rooms if they have any of the tags in the filter
+ if not room_id_to_tag_name_set.get(room_id, set()).intersection(
+ not_tags_set
+ )
+ }
+
+ # Keep rooms that the user has been state reset out of, if we previously sent
+ # them down the connection. We want to make sure that we send these down to
+ # the client regardless of filters so they find out about the state reset.
+ #
+ # We don't always have access to the state in a room after being state reset if
+ # no one else locally on the server is participating in the room, so we patch
+ # these back in manually.
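+ # (For clarity, an assumption from the surrounding code: a `None` `event_id`
+ # in the membership map means the entry came from a state reset rather than
+ # an actual membership event.)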
+ state_reset_out_of_room_id_set = {
+ room_id
+ for room_id in sync_room_map.keys()
+ if sync_room_map[room_id].event_id is None
+ and previous_connection_state.rooms.have_sent_room(room_id).status
+ != HaveSentRoomFlag.NEVER
+ }
+
+ # Assemble a new sync room map but only with the `filtered_room_id_set`
+ return {
+ room_id: sync_room_map[room_id]
+ for room_id in filtered_room_id_set | state_reset_out_of_room_id_set
+ }
+
+ @trace
+ async def filter_rooms_using_tables(
+ self,
+ user_id: str,
+ sync_room_map: Mapping[str, RoomsForUserSlidingSync],
+ previous_connection_state: PerConnectionState,
+ filters: SlidingSyncConfig.SlidingSyncList.Filters,
+ to_token: StreamToken,
+ dm_room_ids: AbstractSet[str],
+ ) -> Dict[str, RoomsForUserSlidingSync]:
+ """
+ Filter rooms based on the sync request.
+
+ Args:
+ user_id: ID of the user to filter rooms for
+ sync_room_map: Dictionary of room IDs to membership information in the
+ room at the time of `to_token`.
+ previous_connection_state: The per-connection state from the previous request
+ filters: Filters to apply
+ to_token: We filter based on the state of the room at this token
+ dm_room_ids: Set of room IDs which are DMs
+
+ Returns:
+ A filtered dictionary of room IDs along with membership information in the
+ room at the time of `to_token`.
+ """
+
+ filtered_room_id_set = set(sync_room_map.keys())
+
+ # Filter for Direct-Message (DM) rooms
+ if filters.is_dm is not None:
+ with start_active_span("filters.is_dm"):
+ if filters.is_dm:
+ # Intersect with the DM room set
+ filtered_room_id_set &= dm_room_ids
+ else:
+ # Remove DMs
+ filtered_room_id_set -= dm_room_ids
+
+ if filters.spaces is not None:
+ with start_active_span("filters.spaces"):
+ raise NotImplementedError()
+
+ # Filter for encrypted rooms
+ if filters.is_encrypted is not None:
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ # Remove rooms if we can't figure out what the encryption status is
+ if sync_room_map[room_id].has_known_state
+ # Or remove if it doesn't match the filter
+ and sync_room_map[room_id].is_encrypted == filters.is_encrypted
+ }
+
+ # Filter for rooms that the user has been invited to
+ if filters.is_invite is not None:
+ with start_active_span("filters.is_invite"):
+ # Make a copy so we don't run into a `Set changed size during
+ # iteration` error when we filter out and remove items
+ for room_id in filtered_room_id_set.copy():
+ room_for_user = sync_room_map[room_id]
+ # If we're looking for invite rooms, filter out rooms that the user is
+ # not invited to and vice versa
+ if (
+ filters.is_invite
+ and room_for_user.membership != Membership.INVITE
+ ) or (
+ not filters.is_invite
+ and room_for_user.membership == Membership.INVITE
+ ):
+ filtered_room_id_set.remove(room_id)
+
+ # Filter by room type (space vs room, etc). A room must match one of the types
+ # provided in the list. `None` is a valid type for rooms which do not have a
+ # room type.
+ if filters.room_types is not None or filters.not_room_types is not None:
+ with start_active_span("filters.room_types"):
+ # Make a copy so we don't run into a `Set changed size during
+ # iteration` error when we filter out and remove items
+ for room_id in filtered_room_id_set.copy():
+ # Remove rooms if we can't figure out what room type it is
+ if not sync_room_map[room_id].has_known_state:
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ room_type = sync_room_map[room_id].room_type
+
+ if (
+ filters.room_types is not None
+ and room_type not in filters.room_types
+ ):
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ if (
+ filters.not_room_types is not None
+ and room_type in filters.not_room_types
+ ):
+ filtered_room_id_set.remove(room_id)
+ continue
+
+ if filters.room_name_like is not None:
+ with start_active_span("filters.room_name_like"):
+ # TODO: The room name is a bit more sensitive to leak than the
+ # create/encryption event. Maybe we should consider a better way to fetch
+ # historical state before implementing this.
+ #
+ # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms(
+ # content_type="room_name",
+ # room_ids=filtered_room_id_set,
+ # to_token=to_token,
+ # sync_room_map=sync_room_map,
+ # room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+ # )
+ raise NotImplementedError()
+
+ # Filter by room tags according to the user's account data
+ if filters.tags is not None or filters.not_tags is not None:
+ with start_active_span("filters.tags"):
+ # Fetch the user tags for their rooms
+ room_tags = await self.store.get_tags_for_user(user_id)
+ room_id_to_tag_name_set: Dict[str, Set[str]] = {
+ room_id: set(tags.keys()) for room_id, tags in room_tags.items()
+ }
+
+ if filters.tags is not None:
+ tags_set = set(filters.tags)
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ # Remove rooms that don't have one of the tags in the filter
+ if room_id_to_tag_name_set.get(room_id, set()).intersection(
+ tags_set
+ )
+ }
+
+ if filters.not_tags is not None:
+ not_tags_set = set(filters.not_tags)
+ filtered_room_id_set = {
+ room_id
+ for room_id in filtered_room_id_set
+ # Remove rooms if they have any of the tags in the filter
+ if not room_id_to_tag_name_set.get(room_id, set()).intersection(
+ not_tags_set
+ )
+ }
+
+ # Keep rooms that the user has been state reset out of, if we previously sent
+ # them down the connection. We want to make sure that we send these down to
+ # the client regardless of filters so they find out about the state reset.
+ #
+ # We don't always have access to the state in a room after being state reset if
+ # no one else locally on the server is participating in the room, so we patch
+ # these back in manually.
+ state_reset_out_of_room_id_set = {
+ room_id
+ for room_id in sync_room_map.keys()
+ if sync_room_map[room_id].event_id is None
+ and previous_connection_state.rooms.have_sent_room(room_id).status
+ != HaveSentRoomFlag.NEVER
+ }
+
+ # Assemble a new sync room map but only with the `filtered_room_id_set`
+ return {
+ room_id: sync_room_map[room_id]
+ for room_id in filtered_room_id_set | state_reset_out_of_room_id_set
+ }
+
+ @trace
+ async def sort_rooms(
+ self,
+ sync_room_map: Dict[str, RoomsForUserType],
+ to_token: StreamToken,
+ limit: Optional[int] = None,
+ ) -> List[RoomsForUserType]:
+ """
+ Sort by `stream_ordering` of the last event that the user should see in the
+ room. `stream_ordering` is unique so we get a stable sort.
+
+ If `limit` is specified then the sort may return fewer entries than are in
+ `sync_room_map`, but it will always return at least the top `limit` rooms.
+ This is useful as we don't always need to sort the full list, but are just
+ interested in the top N.
+
+ Args:
+ sync_room_map: Dictionary of room IDs to sort along with membership
+ information in the room at the time of `to_token`.
+ to_token: We sort based on the events in the room at this token (<= `to_token`)
+ limit: The number of rooms that we need to return from the top of the list.
+
+ Returns:
+ A sorted list of room IDs by `stream_ordering` along with membership information.
+ """
+
+ # Assemble a map of room ID to the `stream_ordering` of the last activity that the
+ # user should see in the room (<= `to_token`)
+ last_activity_in_room_map: Dict[str, int] = {}
+
+ # Same as above, except for positions that we know are in the event
+ # stream cache.
+ cached_positions: Dict[str, int] = {}
+
+ earliest_cache_position = (
+ self.store._events_stream_cache.get_earliest_known_position()
+ )
+
+ for room_id, room_for_user in sync_room_map.items():
+ if room_for_user.membership == Membership.JOIN:
+ # For joined rooms check the stream change cache.
+ cached_position = (
+ self.store._events_stream_cache.get_max_pos_of_last_change(room_id)
+ )
+ if cached_position is not None:
+ cached_positions[room_id] = cached_position
+ else:
+ # If the user has left/been invited/knocked/been banned from a
+ # room, they shouldn't see anything past that point.
+ #
+ # FIXME: It's possible that people should see beyond this point
+ # in invited/knocked cases if for example the room has
+ # `invite`/`world_readable` history visibility, see
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
+ last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
+
+ # If the stream position is in range of the stream change cache
+ # we can include it.
+ if room_for_user.event_pos.stream > earliest_cache_position:
+ cached_positions[room_id] = room_for_user.event_pos.stream
+
+ # If we are only asked for the top N rooms, and we have enough from
+ # looking in the stream change cache, then we can return early. This
+ # is because the cache must include all entries above
+ # `.get_earliest_known_position()`.
+ if limit is not None and len(cached_positions) >= limit:
+ # ... but first we need to handle the case where the cached max
+ # position is greater than the `to_token`, in which case we do
+ # actually need to query the DB. This should happen rarely, so we
+ # can do it in a loop.
+ for room_id, position in list(cached_positions.items()):
+ if position > to_token.room_key.stream:
+ result = await self.store.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, to_token.room_key
+ )
+ if (
+ result is not None
+ and result[1].stream > earliest_cache_position
+ ):
+ # We have a stream position in the cached range.
+ cached_positions[room_id] = result[1].stream
+ else:
+ # No position in the range, so we remove the entry.
+ cached_positions.pop(room_id)
+
+ if limit is not None and len(cached_positions) >= limit:
+ return sorted(
+ (
+ room
+ for room in sync_room_map.values()
+ if room.room_id in cached_positions
+ ),
+ # Sort by the last activity (stream_ordering) in the room
+ key=lambda room_info: cached_positions[room_info.room_id],
+ # We want descending order
+ reverse=True,
+ )
+
+ # For fully-joined rooms, we find the latest activity at/before the
+ # `to_token`.
+ joined_room_positions = (
+ await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
+ [
+ room_id
+ for room_id, room_for_user in sync_room_map.items()
+ if room_for_user.membership == Membership.JOIN
+ ],
+ to_token.room_key,
+ )
+ )
+
+ last_activity_in_room_map.update(joined_room_positions)
+
+ return sorted(
+ sync_room_map.values(),
+ # Sort by the last activity (stream_ordering) in the room
+ key=lambda room_info: last_activity_in_room_map[room_info.room_id],
+ # We want descending order
+ reverse=True,
+ )
+
+ async def get_is_encrypted_for_room_at_token(
+ self, room_id: str, to_token: RoomStreamToken
+ ) -> bool:
+ """Get if the room is encrypted at the time."""
+
+ # Fetch the current encryption state
+ state_ids = await self.store.get_partial_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.RoomEncryption, "")])
+ )
+ encryption_event_id = state_ids.get((EventTypes.RoomEncryption, ""))
+
+ # Now roll back the state by looking at the state deltas between
+ # to_token and now.
+ deltas = await self.store.get_current_state_deltas_for_room(
+ room_id,
+ from_token=to_token,
+ to_token=self.store.get_room_max_token(),
+ )
+
+ for delta in deltas:
+ if delta.event_type != EventTypes.RoomEncryption:
+ continue
+
+ # Found the first change; we look at the previous event ID to get the
+ # state at the `to_token`.
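+ # Illustrative example (not part of this change): if encryption was only
+ # enabled after `to_token`, the first encryption delta we see here has
+ # `prev_event_id = None`, so the room was unencrypted at `to_token`.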
+
+ if delta.prev_event_id is None:
+ # There is no prev event, so no encryption state event, so the room is not encrypted
+ return False
+
+ encryption_event_id = delta.prev_event_id
+ break
+
+ # We didn't find any encryption state, so the room isn't encrypted
+ if encryption_event_id is None:
+ return False
+
+ # We found encryption state, check if content has a non-null algorithm
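+ # e.g. (illustrative) `content = {"algorithm": "m.megolm.v1.aes-sha2"}` counts
+ # as encrypted, while a missing or null algorithm does not.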
+ encrypted_event = await self.store.get_event(encryption_event_id)
+ algorithm = encrypted_event.content.get(EventContentFields.ENCRYPTION_ALGORITHM)
+
+ return algorithm is not None
diff --git a/synapse/handlers/sliding_sync/store.py b/synapse/handlers/sliding_sync/store.py
new file mode 100644
index 0000000000..d24fccf76f
--- /dev/null
+++ b/synapse/handlers/sliding_sync/store.py
@@ -0,0 +1,128 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import attr
+
+from synapse.logging.opentracing import trace
+from synapse.storage.databases.main import DataStore
+from synapse.types import SlidingSyncStreamToken
+from synapse.types.handlers.sliding_sync import (
+ MutablePerConnectionState,
+ PerConnectionState,
+ SlidingSyncConfig,
+)
+
+if TYPE_CHECKING:
+ pass
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(auto_attribs=True)
+class SlidingSyncConnectionStore:
+ """In-memory store of per-connection state, including what rooms we have
+ previously sent down a sliding sync connection.
+
+ Note: This is NOT safe to run in a worker setup because connection positions will
+ point to different sets of rooms on different workers. e.g. for the same connection,
+ a connection position of 5 might have totally different states on worker A and
+ worker B.
+
+ One complication that we need to deal with here is needing to handle requests being
+ resent, i.e. if we sent down a room in a response that the client never received, we
+ must consider the room *not* sent when we get the request again.
+
+ This is handled by using an integer "token", which is returned to the client
+ as part of the sync token. For each connection we store a mapping from
+ tokens to the room states, and create a new entry when we send down new
+ rooms.
+
+ Note that for any given sliding sync connection we will only store a maximum
+ of two different tokens: the previous token from the request and a new token
+ sent in the response. When we receive a request with a given token, we then
+ clear out all other entries with a different token.
+
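+ For example (illustrative): a request at connection position 5 produces a
+ response at position 6. If the client never receives that response and
+ resends position 5, we still hold the state for 5 and can answer it again;
+ once a request arrives at position 6, the entry for 5 is cleared.
+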
+ Attributes:
+ store: The database store used to persist the per-connection state.
+ """
+
+ store: "DataStore"
+
+ async def get_and_clear_connection_positions(
+ self,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> PerConnectionState:
+ """Fetch the per-connection state for the token.
+
+ Raises:
+ SlidingSyncUnknownPosition if the connection position is unknown
+ """
+ # If this is our first request, there is no previous connection state to fetch out of the database
+ if from_token is None or from_token.connection_position == 0:
+ return PerConnectionState()
+
+ conn_id = sync_config.conn_id or ""
+
+ device_id = sync_config.requester.device_id
+ assert device_id is not None
+
+ return await self.store.get_and_clear_connection_positions(
+ sync_config.user.to_string(),
+ device_id,
+ conn_id,
+ from_token.connection_position,
+ )
+
+ @trace
+ async def record_new_state(
+ self,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken],
+ new_connection_state: MutablePerConnectionState,
+ ) -> int:
+ """Record updated per-connection state, returning the connection
+ position associated with the new state.
+ If there are no changes to the state this may return the same token as
+ the existing per-connection state.
+ """
+ if not new_connection_state.has_updates():
+ if from_token is not None:
+ return from_token.connection_position
+ else:
+ return 0
+
+ # A from token with a zero connection position means there was no
+ # previously stored connection state, so we treat a zero the same as
+ # there being no previous position.
+ previous_connection_position = None
+ if from_token is not None and from_token.connection_position != 0:
+ previous_connection_position = from_token.connection_position
+
+ conn_id = sync_config.conn_id or ""
+
+ device_id = sync_config.requester.device_id
+ assert device_id is not None
+
+ return await self.store.persist_per_connection_state(
+ sync_config.user.to_string(),
+ device_id,
+ conn_id,
+ previous_connection_position,
+ new_connection_state,
+ )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index ee74289b6c..2795b282e5 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -33,17 +33,17 @@ from typing import (
Mapping,
NoReturn,
Optional,
+ Protocol,
Set,
)
from urllib.parse import urlencode
import attr
-from typing_extensions import Protocol
from twisted.web.iweb import IRequest
from twisted.web.server import Request
-from synapse.api.constants import LoginType
+from synapse.api.constants import LoginType, ProfileFields
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
@@ -81,8 +81,7 @@ class SsoIdentityProvider(Protocol):
An Identity Provider, or IdP, is an external HTTP service which authenticates a user
to say whether they should be allowed to log in, or perform a given action.
- Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
- and CAS.
+ Synapse supports various implementations of IdPs, including OpenID Connect.
The main entry point is `handle_redirect_request`, which should return a URI to
redirect the user's browser to the IdP's authentication page.
@@ -97,7 +96,7 @@ class SsoIdentityProvider(Protocol):
def idp_id(self) -> str:
"""A unique identifier for this SSO provider
- Eg, "saml", "cas", "github"
+ Eg. "github"
"""
@property
@@ -157,7 +156,7 @@ class UserAttributes:
class UsernameMappingSession:
"""Data we track about SSO sessions"""
- # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+ # A unique identifier for this SSO provider, e.g. "oidc".
auth_provider_id: str
# An optional session ID from the IdP.
@@ -351,7 +350,7 @@ class SsoHandler:
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
remote_user_id: The user ID according to the remote IdP. This might
be an e-mail address, a GUID, or some other form. It must be
unique and immutable.
@@ -418,7 +417,7 @@ class SsoHandler:
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
remote_user_id: The unique identifier from the SSO provider.
@@ -634,7 +633,7 @@ class SsoHandler:
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
remote_user_id: The unique identifier from the SSO provider.
@@ -704,7 +703,7 @@ class SsoHandler:
including a non-empty localpart.
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
remote_user_id: The unique identifier from the SSO provider.
@@ -813,9 +812,10 @@ class SsoHandler:
# bail if user already has the same avatar
profile = await self._profile_handler.get_profile(user_id)
- if profile["avatar_url"] is not None:
- server_name = profile["avatar_url"].split("/")[-2]
- media_id = profile["avatar_url"].split("/")[-1]
+ if ProfileFields.AVATAR_URL in profile:
+ avatar_url_parts = profile[ProfileFields.AVATAR_URL].split("/")
+ server_name = avatar_url_parts[-2]
+ media_id = avatar_url_parts[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id) # type: ignore[has-type]
if media is not None and upload_name == media.upload_name:
@@ -855,12 +855,12 @@ class SsoHandler:
Given an SSO ID, retrieve the user ID for it and complete UIA.
Note that this requires that the user is mapped in the "user_external_ids"
- table. This will be the case if they have ever logged in via SAML or OIDC in
+ table. This will be the case if they have ever logged in via OIDC in
recentish synapse versions, but may not be for older users.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
remote_user_id: The unique identifier from the SSO provider.
ui_auth_session_id: The ID of the user-interactive auth session.
request: The request to complete.
@@ -1184,16 +1184,16 @@ class SsoHandler:
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
- "oidc" or "saml".
+ "oidc".
auth_provider_session_id: The session ID from the provider to logout
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""
# It is expected that this is the main process.
- assert isinstance(
- self._device_handler, DeviceHandler
- ), "revoking SSO sessions can only be called on the main process"
+ assert isinstance(self._device_handler, DeviceHandler), (
+ "revoking SSO sessions can only be called on the main process"
+ )
# Invalidate any running user-mapping sessions
to_delete = []
@@ -1276,12 +1276,16 @@ def _check_attribute_requirement(
return False
# If the requirement is None, the attribute existing is enough.
- if req.value is None:
+ if req.value is None and req.one_of is None:
return True
values = attributes[req.attribute]
if req.value in values:
return True
+ if req.one_of:
+ for value in req.one_of:
+ if value in values:
+ return True
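+    # (Illustrative, assumed config shape) e.g. `one_of: ["staff", "admin"]`
+    # passes if the SSO attribute contains either value.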
logger.info(
"SSO attribute %s did not match required value '%s' (was '%s')",
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6af2eeb75f..c6f2c38d8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -66,6 +66,7 @@ from synapse.logging.opentracing import (
from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.databases.main.stream import PaginateFunction
+from synapse.storage.invite_rule import InviteRule
from synapse.storage.roommember import MemberSummary
from synapse.types import (
DeviceListUpdates,
@@ -86,7 +87,7 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
-from synapse.util.metrics import Measure, measure_func
+from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -143,6 +144,7 @@ class SyncConfig:
filter_collection: FilterCollection
is_guest: bool
device_id: Optional[str]
+ use_state_after: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -183,10 +185,7 @@ class JoinedSyncResult:
to tell if room needs to be part of the sync result.
"""
return bool(
- self.timeline
- or self.state
- or self.ephemeral
- or self.account_data
+ self.timeline or self.state or self.ephemeral or self.account_data
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
@@ -575,10 +574,10 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result: Union[SyncResult, E2eeSyncResult] = (
- await self.current_sync_for_user(
- sync_config, sync_version, since_token, full_state=full_state
- )
+ result: Union[
+ SyncResult, E2eeSyncResult
+ ] = await self.current_sync_for_user(
+ sync_config, sync_version, since_token, full_state=full_state
)
else:
# Otherwise, we wait for something to happen and report it to the user.
@@ -673,10 +672,10 @@ class SyncHandler:
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
- sync_result: Union[SyncResult, E2eeSyncResult] = (
- await self.generate_sync_result(
- sync_config, since_token, full_state
- )
+ sync_result: Union[
+ SyncResult, E2eeSyncResult
+ ] = await self.generate_sync_result(
+ sync_config, since_token, full_state
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_version == SyncVersion.E2EE_SYNC:
@@ -909,7 +908,7 @@ class SyncHandler:
# Use `stream_ordering` for updates
else paginate_room_events_by_stream_ordering
)
- events, end_key = await pagination_method(
+ events, end_key, limited = await pagination_method(
room_id=room_id,
# The bounds are reversed so we can paginate backwards
# (from newer to older events) starting at to_bound.
@@ -917,9 +916,7 @@ class SyncHandler:
from_key=end_key,
to_key=since_key,
direction=Direction.BACKWARDS,
- # We add one so we can determine if there are enough events to saturate
- # the limit or not (see `limited`)
- limit=load_limit + 1,
+ limit=load_limit,
)
# We want to return the events in ascending order (the last event is the
# most recent).
@@ -974,9 +971,6 @@ class SyncHandler:
loaded_recents.extend(recents)
recents = loaded_recents
- if len(events) <= load_limit:
- limited = False
- break
max_repeat -= 1
if len(recents) > timeline_limit:
@@ -1149,6 +1143,7 @@ class SyncHandler:
since_token: Optional[StreamToken],
end_token: StreamToken,
full_state: bool,
+ joined: bool,
) -> MutableStateMap[EventBase]:
"""Works out the difference in state between the end of the previous sync and
the start of the timeline.
@@ -1163,6 +1158,7 @@ class SyncHandler:
the point just after their leave event.
full_state: Whether to force returning the full state.
`lazy_load_members` still applies when `full_state` is `True`.
+ joined: whether the user is currently joined to the room
Returns:
The state to return in the sync response for the room.
@@ -1238,11 +1234,12 @@ class SyncHandler:
if full_state:
state_ids = await self._compute_state_delta_for_full_sync(
room_id,
- sync_config.user,
+ sync_config,
batch,
end_token,
members_to_fetch,
timeline_state,
+ joined,
)
else:
# If this is an initial sync then full_state should be set, and
@@ -1252,6 +1249,7 @@ class SyncHandler:
state_ids = await self._compute_state_delta_for_incremental_sync(
room_id,
+ sync_config,
batch,
since_token,
end_token,
@@ -1324,20 +1322,24 @@ class SyncHandler:
async def _compute_state_delta_for_full_sync(
self,
room_id: str,
- syncing_user: UserID,
+ sync_config: SyncConfig,
batch: TimelineBatch,
end_token: StreamToken,
members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str],
+ joined: bool,
) -> StateMap[str]:
"""Calculate the state events to be included in a full sync response.
As with `_compute_state_delta_for_incremental_sync`, the result will include
the membership events for the senders of each event in `members_to_fetch`.
+ Note that whether this returns the state at the start or the end of the
+ batch depends on `sync_config.use_state_after` (c.f. MSC4222).
+
Args:
room_id: The room we are calculating for.
- syncing_user: The user that is calling `/sync`.
+ sync_config: The sync configuration, including the user that is calling `/sync`.
batch: The timeline batch for the room that will be sent to the user.
end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
@@ -1346,10 +1348,11 @@ class SyncHandler:
events in the timeline.
timeline_state: The contribution to the room state from state events in
`batch`. Only contains the last event for any given state key.
+ joined: whether the user is currently joined to the room
Returns:
A map from (type, state_key) to event_id, for each event that we believe
- should be included in the `state` part of the sync response.
+ should be included in the `state` or `state_after` part of the sync response.
"""
if members_to_fetch is not None:
# Lazy-loading of membership events is enabled.
@@ -1367,7 +1370,7 @@ class SyncHandler:
# is no guarantee that our membership will be in the auth events of
# timeline events when the room is partial stated.
state_filter = StateFilter.from_lazy_load_member_list(
- members_to_fetch.union((syncing_user.to_string(),))
+ members_to_fetch.union((sync_config.user.to_string(),))
)
# We are happy to use partial state to compute the `/sync` response.
@@ -1381,6 +1384,61 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
+ # Check if we want to return the state at the start or the end of the
+ # timeline. If at the end, we can just use the current state.
+ if sync_config.use_state_after:
+ # If we're getting the state at the end of the timeline, we can just
+ # use the current state of the room (and roll back any changes
+ # between when we fetched the current state and `end_token`).
+ #
+ # For rooms we're not joined to, there might be a very large number
+ # of deltas between `end_token` and "now", and so instead we fetch
+ # the state at the end of the timeline.
+ if joined:
+ state_ids = await self._state_storage_controller.get_current_state_ids(
+ room_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
+ )
+
+ # Now roll back the state by looking at the state deltas between
+ # end_token and now.
+ deltas = await self.store.get_current_state_deltas_for_room(
+ room_id,
+ from_token=end_token.room_key,
+ to_token=self.store.get_room_max_token(),
+ )
+ if deltas:
+ mutable_state_ids = dict(state_ids)
+
+ # We iterate over the deltas backwards so that if there are
+ # multiple changes of the same type/state_key we'll
+ # correctly pick the earliest delta.
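+ # Illustrative example: if the topic changed twice after `end_token`
+ # (A -> B -> C), walking the deltas backwards rewinds C -> B and then
+ # B -> A, leaving A, the value at `end_token`.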
+ for delta in reversed(deltas):
+ if delta.prev_event_id:
+ mutable_state_ids[(delta.event_type, delta.state_key)] = (
+ delta.prev_event_id
+ )
+ elif (delta.event_type, delta.state_key) in mutable_state_ids:
+ mutable_state_ids.pop((delta.event_type, delta.state_key))
+
+ state_ids = mutable_state_ids
+
+ return state_ids
+
+ else:
+ # Just use state groups to get the state at the end of the
+ # timeline, i.e. the state at the leave/etc event.
+ state_at_timeline_end = (
+ await self._state_storage_controller.get_state_ids_at(
+ room_id,
+ stream_position=end_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
+ )
+ )
+ return state_at_timeline_end
+
state_at_timeline_end = await self._state_storage_controller.get_state_ids_at(
room_id,
stream_position=end_token,
@@ -1413,6 +1471,7 @@ class SyncHandler:
async def _compute_state_delta_for_incremental_sync(
self,
room_id: str,
+ sync_config: SyncConfig,
batch: TimelineBatch,
since_token: StreamToken,
end_token: StreamToken,
@@ -1427,8 +1486,12 @@ class SyncHandler:
(`compute_state_delta`) is responsible for keeping track of which membership
events we have already sent to the client, and hence ripping them out.
+ Note that whether this returns the state at the start or the end of the
+ batch depends on `sync_config.use_state_after` (c.f. MSC4222).
+
Args:
room_id: The room we are calculating for.
+ sync_config: The sync configuration
batch: The timeline batch for the room that will be sent to the user.
since_token: Token of the end of the previous batch.
end_token: Token of the end of the current batch. Normally this will be
@@ -1441,7 +1504,7 @@ class SyncHandler:
Returns:
A map from (type, state_key) to event_id, for each event that we believe
- should be included in the `state` part of the sync response.
+ should be included in the `state` or `state_after` part of the sync response.
"""
if members_to_fetch is not None:
# Lazy-loading is enabled. Only return the state that is needed.
@@ -1453,6 +1516,51 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
+ # Check if we want to return the state at the start or the end of the
+ # timeline. If at the end, we can just use the current state delta stream.
+ if sync_config.use_state_after:
+ delta_state_ids: MutableStateMap[str] = {}
+
+ if members_to_fetch:
+ # We're lazy-loading, so the client might need some more member
+ # events to understand the events in this timeline. So we always
+ # fish out all the member events corresponding to the timeline
+ # here. The caller will then dedupe any redundant ones.
+ member_ids = await self._state_storage_controller.get_current_state_ids(
+ room_id=room_id,
+ state_filter=StateFilter.from_types(
+ (EventTypes.Member, member) for member in members_to_fetch
+ ),
+ await_full_state=await_full_state,
+ )
+ delta_state_ids.update(member_ids)
+
+ # We don't do LL filtering for incremental syncs - see
+ # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
+ # N.B. this slows down incr syncs as we are now processing way more
+ # state in the server than if we were LLing.
+ #
+ # i.e. we return all state deltas, including membership changes that
+ # we'd normally exclude due to LL.
+ deltas = await self.store.get_current_state_deltas_for_room(
+ room_id=room_id,
+ from_token=since_token.room_key,
+ to_token=end_token.room_key,
+ )
+ for delta in deltas:
+ if delta.event_id is None:
+ # There was a state reset and this state entry is no longer
+ # present, but we have no way of informing the client about
+ # this, so we just skip it for now.
+ continue
+
+ # Note that deltas are in stream ordering, so if there are
+ # multiple deltas for a given type/state_key we'll always pick
+ # the latest one.
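+ # Illustrative example: deltas [topic -> B, topic -> C] between the
+ # tokens resolve to C, the value at `end_token`.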
+ delta_state_ids[(delta.event_type, delta.state_key)] = delta.event_id
+
+ return delta_state_ids
+
# For a non-gappy sync if the events in the timeline are simply a linear
# chain (i.e. no merging/branching of the graph), then we know the state
# delta between the end of the previous sync and start of the new one is
@@ -1488,13 +1596,16 @@ class SyncHandler:
# timeline here. The caller will then dedupe any redundant
# ones.
- state_ids = await self._state_storage_controller.get_state_ids_for_event(
- batch.events[0].event_id,
- # we only want members!
- state_filter=StateFilter.from_types(
- (EventTypes.Member, member) for member in members_to_fetch
- ),
- await_full_state=False,
+ state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[0].event_id,
+ # we only want members!
+ state_filter=StateFilter.from_types(
+ (EventTypes.Member, member)
+ for member in members_to_fetch
+ ),
+ await_full_state=False,
+ )
)
return state_ids
@@ -1779,8 +1890,15 @@ class SyncHandler:
)
if include_device_list_updates:
- device_lists = await self._generate_sync_entry_for_device_list(
- sync_result_builder,
+ # include_device_list_updates can only be True if we have a
+ # since token.
+ assert since_token is not None
+
+ device_lists = await self._device_handler.generate_sync_entry_for_device_list(
+ user_id=user_id,
+ since_token=since_token,
+ now_token=sync_result_builder.now_token,
+ joined_room_ids=sync_result_builder.joined_room_ids,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
@@ -1892,8 +2010,14 @@ class SyncHandler:
newly_left_users,
) = sync_result_builder.calculate_user_changes()
- device_lists = await self._generate_sync_entry_for_device_list(
- sync_result_builder,
+ # include_device_list_updates can only be True if we have a
+ # since token.
+ assert since_token is not None
+ device_lists = await self._device_handler.generate_sync_entry_for_device_list(
+ user_id=user_id,
+ since_token=since_token,
+ now_token=sync_result_builder.now_token,
+ joined_room_ids=sync_result_builder.joined_room_ids,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
@@ -2070,94 +2194,6 @@ class SyncHandler:
return sync_result_builder
- @measure_func("_generate_sync_entry_for_device_list")
- async def _generate_sync_entry_for_device_list(
- self,
- sync_result_builder: "SyncResultBuilder",
- newly_joined_rooms: AbstractSet[str],
- newly_joined_or_invited_or_knocked_users: AbstractSet[str],
- newly_left_rooms: AbstractSet[str],
- newly_left_users: AbstractSet[str],
- ) -> DeviceListUpdates:
- """Generate the DeviceListUpdates section of sync
-
- Args:
- sync_result_builder
- newly_joined_rooms: Set of rooms user has joined since previous sync
- newly_joined_or_invited_or_knocked_users: Set of users that have joined,
- been invited to a room or are knocking on a room since
- previous sync.
- newly_left_rooms: Set of rooms user has left since previous sync
- newly_left_users: Set of users that have left a room we're in since
- previous sync
- """
-
- user_id = sync_result_builder.sync_config.user.to_string()
- since_token = sync_result_builder.since_token
- assert since_token is not None
-
- # Take a copy since these fields will be mutated later.
- newly_joined_or_invited_or_knocked_users = set(
- newly_joined_or_invited_or_knocked_users
- )
- newly_left_users = set(newly_left_users)
-
- # We want to figure out what user IDs the client should refetch
- # device keys for, and which users we aren't going to track changes
- # for anymore.
- #
- # For the first step we check:
- # a. if any users we share a room with have updated their devices,
- # and
- # b. we also check if we've joined any new rooms, or if a user has
- # joined a room we're in.
- #
- # For the second step we just find any users we no longer share a
- # room with by looking at all users that have left a room plus users
- # that were in a room we've left.
-
- users_that_have_changed = set()
-
- joined_room_ids = sync_result_builder.joined_room_ids
-
- # Step 1a, check for changes in devices of users we share a room
- # with
- users_that_have_changed = (
- await self._device_handler.get_device_changes_in_shared_rooms(
- user_id,
- joined_room_ids,
- from_token=since_token,
- now_token=sync_result_builder.now_token,
- )
- )
-
- # Step 1b, check for newly joined rooms
- for room_id in newly_joined_rooms:
- joined_users = await self.store.get_users_in_room(room_id)
- newly_joined_or_invited_or_knocked_users.update(joined_users)
-
- # TODO: Check that these users are actually new, i.e. either they
- # weren't in the previous sync *or* they left and rejoined.
- users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
-
- user_signatures_changed = await self.store.get_users_whose_signatures_changed(
- user_id, since_token.device_list_key
- )
- users_that_have_changed.update(user_signatures_changed)
-
- # Now find users that we no longer track
- for room_id in newly_left_rooms:
- left_users = await self.store.get_users_in_room(room_id)
- newly_left_users.update(left_users)
-
- # Remove any users that we still share a room with.
- left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
- for user_id, entries in left_users_rooms.items():
- if any(rid in joined_room_ids for rid in entries):
- newly_left_users.discard(user_id)
-
- return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users)
-
@trace
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
@@ -2241,18 +2277,18 @@ class SyncHandler:
if push_rules_changed:
global_account_data = dict(global_account_data)
- global_account_data[AccountDataTypes.PUSH_RULES] = (
- await self._push_rules_handler.push_rules_for_user(sync_config.user)
- )
+ global_account_data[
+ AccountDataTypes.PUSH_RULES
+ ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
else:
all_global_account_data = await self.store.get_global_account_data_for_user(
user_id
)
global_account_data = dict(all_global_account_data)
- global_account_data[AccountDataTypes.PUSH_RULES] = (
- await self._push_rules_handler.push_rules_for_user(sync_config.user)
- )
+ global_account_data[
+ AccountDataTypes.PUSH_RULES
+ ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
account_data_for_user = (
await sync_config.filter_collection.filter_global_account_data(
@@ -2514,6 +2550,7 @@ class SyncHandler:
room_entries: List[RoomSyncResultBuilder] = []
invited: List[InvitedSyncResult] = []
knocked: List[KnockedSyncResult] = []
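+ # Load the user's invite filtering configuration once up front; the
+ # invite handling below consults it for each pending invite.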
+ invite_config = await self.store.get_invite_config_for_user(user_id)
for room_id, events in mem_change_events_by_room_id.items():
# The body of this loop will add this room to at least one of the five lists
# above. Things get messy if you've e.g. joined, left, joined then left the
@@ -2596,7 +2633,11 @@ class SyncHandler:
# Only bother if we're still currently invited
should_invite = last_non_join.membership == Membership.INVITE
if should_invite:
- if last_non_join.sender not in ignored_users:
+ if (
+ last_non_join.sender not in ignored_users
+ and invite_config.get_invite_rule(last_non_join.sender)
+ != InviteRule.IGNORE
+ ):
invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
if invite_room_sync:
invited.append(invite_room_sync)
@@ -2683,7 +2724,7 @@ class SyncHandler:
newly_joined = room_id in newly_joined_rooms
if room_entry:
- events, start_key = room_entry
+ events, start_key, _ = room_entry
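+ # room_entry carries a third element that is not needed here, so
+ # discard it.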
# We want to return the events in ascending order (the last event is the
# most recent).
events.reverse()
@@ -2751,6 +2792,7 @@ class SyncHandler:
membership_list=Membership.LIST,
excluded_rooms=sync_result_builder.excluded_room_ids,
)
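+ # As above, fetch the user's invite filtering rules once before iterating.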
+ invite_config = await self.store.get_invite_config_for_user(user_id)
room_entries = []
invited = []
@@ -2776,6 +2818,8 @@ class SyncHandler:
elif event.membership == Membership.INVITE:
if event.sender in ignored_users:
continue
+ if invite_config.get_invite_rule(event.sender) == InviteRule.IGNORE:
+ continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership == Membership.KNOCK:
@@ -2947,6 +2991,7 @@ class SyncHandler:
since_token,
room_builder.end_token,
full_state=full_state,
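+ # Flag whether this entry is for a room the user has joined,
+ # based on the room builder's type.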
+ joined=room_builder.rtype == "joined",
)
else:
# An out of band room won't have any state changes.
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 32dca8c43b..477961d78c 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -157,104 +157,6 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
)
-class _BaseThreepidAuthChecker:
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.store = hs.get_datastores().main
-
- async def _check_threepid(self, medium: str, authdict: dict) -> dict:
- if "threepid_creds" not in authdict:
- raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
-
- threepid_creds = authdict["threepid_creds"]
-
- identity_handler = self.hs.get_identity_handler()
-
- logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
-
- # msisdns are currently always verified via the IS
- if medium == "msisdn":
- if not self.hs.config.registration.account_threepid_delegate_msisdn:
- raise SynapseError(
- 400, "Phone number verification is not enabled on this homeserver"
- )
- threepid = await identity_handler.threepid_from_creds(
- self.hs.config.registration.account_threepid_delegate_msisdn,
- threepid_creds,
- )
- elif medium == "email":
- if self.hs.config.email.can_verify_email:
- threepid = None
- row = await self.store.get_threepid_validation_session(
- medium,
- threepid_creds["client_secret"],
- sid=threepid_creds["sid"],
- validated=True,
- )
-
- if row:
- threepid = {
- "medium": row.medium,
- "address": row.address,
- "validated_at": row.validated_at,
- }
-
- # Valid threepid returned, delete from the db
- await self.store.delete_threepid_session(threepid_creds["sid"])
- else:
- raise SynapseError(
- 400, "Email address verification is not enabled on this homeserver"
- )
- else:
- # this can't happen!
- raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
-
- if not threepid:
- raise LoginError(
- 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
- )
-
- if threepid["medium"] != medium:
- raise LoginError(
- 401,
- "Expecting threepid of type '%s', got '%s'"
- % (medium, threepid["medium"]),
- errcode=Codes.UNAUTHORIZED,
- )
-
- threepid["threepid_creds"] = authdict["threepid_creds"]
-
- return threepid
-
-
-class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
- AUTH_TYPE = LoginType.EMAIL_IDENTITY
-
- def __init__(self, hs: "HomeServer"):
- UserInteractiveAuthChecker.__init__(self, hs)
- _BaseThreepidAuthChecker.__init__(self, hs)
-
- def is_enabled(self) -> bool:
- return self.hs.config.email.can_verify_email
-
- async def check_auth(self, authdict: dict, clientip: str) -> Any:
- return await self._check_threepid("email", authdict)
-
-
-class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
- AUTH_TYPE = LoginType.MSISDN
-
- def __init__(self, hs: "HomeServer"):
- UserInteractiveAuthChecker.__init__(self, hs)
- _BaseThreepidAuthChecker.__init__(self, hs)
-
- def is_enabled(self) -> bool:
- return bool(self.hs.config.registration.account_threepid_delegate_msisdn)
-
- async def check_auth(self, authdict: dict, clientip: str) -> Any:
- return await self._check_threepid("msisdn", authdict)
-
-
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.REGISTRATION_TOKEN
@@ -263,7 +165,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
self.hs = hs
self._enabled = bool(
hs.config.registration.registration_requires_token
- ) or bool(hs.config.registration.enable_registration_token_3pid_bypass)
+ )
self.store = hs.get_datastores().main
def is_enabled(self) -> bool:
@@ -325,8 +227,6 @@ INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
- EmailIdentityAuthChecker,
- MsisdnAuthChecker,
RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a343637b82..33edef5f14 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -26,7 +26,13 @@ from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from twisted.internet.interfaces import IDelayedCall
import synapse.metrics
-from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
+from synapse.api.constants import (
+ EventTypes,
+ HistoryVisibility,
+ JoinRules,
+ Membership,
+ ProfileFields,
+)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -102,6 +108,9 @@ class UserDirectoryHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.worker.should_update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
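+ # Whether users from other homeservers should be excluded from this
+ # server's user directory.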
+ self.exclude_remote_users = (
+ hs.config.userdirectory.user_directory_exclude_remote_users
+ )
self.show_locked_users = hs.config.userdirectory.show_locked_users
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._hs = hs
@@ -161,7 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler):
non_spammy_users = []
for user in results["results"]:
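+ # Pass the searching user's ID as well, so spam checker modules
+ # can take the requester into account.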
if not await self._spam_checker_module_callbacks.check_username_for_spam(
- user
+ user, user_id
):
non_spammy_users.append(user)
results["results"] = non_spammy_users
@@ -756,6 +765,10 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_profile_in_user_dir(
user_id,
- display_name=non_null_str_or_none(profile.get("displayname")),
- avatar_url=non_null_str_or_none(profile.get("avatar_url")),
+ display_name=non_null_str_or_none(
+ profile.get(ProfileFields.DISPLAYNAME)
+ ),
+ avatar_url=non_null_str_or_none(
+ profile.get(ProfileFields.AVATAR_URL)
+ ),
)
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 7e578cf462..e58a416026 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -19,6 +19,7 @@
#
#
+import logging
import random
from types import TracebackType
from typing import (
@@ -183,7 +184,7 @@ class WorkerLocksHandler:
return
def _wake_all_locks(
- locks: Collection[Union[WaitingLock, WaitingMultiLock]]
+ locks: Collection[Union[WaitingLock, WaitingMultiLock]],
) -> None:
for lock in locks:
deferred = lock.deferred
@@ -269,6 +270,10 @@ class WaitingLock:
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
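+ # The retry interval doubles each time; once it exceeds 5 * 2**7 = 640s
+ # (roughly ten minutes), a deadlock is likely, so warn.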
+ if self._retry_interval > 5 * 2**7: # ~10 minutes
+ logging.warning(
+ f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock."
+ )
return next * random.uniform(0.9, 1.1)
@@ -344,4 +349,8 @@ class WaitingMultiLock:
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
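+ # Same excessive-backoff warning as in WaitingLock above.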
+ if self._retry_interval > 5 * 2**7: # ~10 minutes
+ logging.warning(
+ f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock."
+ )
return next * random.uniform(0.9, 1.1)