diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6e6eaf3805..9a1aea083f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -26,13 +26,18 @@ from synapse.api.errors import (
Codes,
InvalidClientTokenError,
MissingClientTokenError,
+ UnstableSpecAuthError,
)
from synapse.appservice import ApplicationService
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import active_span, force_tracing, start_active_span
-from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, UserID, create_requester
+from synapse.logging.opentracing import (
+ active_span,
+ force_tracing,
+ start_active_span,
+ trace,
+)
+from synapse.types import Requester, create_requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -64,14 +69,14 @@ class Auth:
async def check_user_in_room(
self,
room_id: str,
- user_id: str,
+ requester: Requester,
allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
- user_id: The user to check.
+ requester: The user making the request, according to the access token.
current_state: Optional map of the current state of the room.
If provided then that map is used to check whether they are a
@@ -88,6 +93,7 @@ class Auth:
membership event ID of the user.
"""
+ user_id = requester.user.to_string()
(
membership,
member_event_id,
@@ -106,8 +112,11 @@ class Auth:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return membership, member_event_id
-
- raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
+ raise UnstableSpecAuthError(
+ 403,
+ "User %s not in room %s" % (user_id, room_id),
+ errcode=Codes.NOT_JOINED,
+ )
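# Illustrative sketch (not part of the diff): a call site migrated to the new
# `Requester`-based signature. `auth` is the `Auth` instance defined above;
# the helper name is hypothetical.

from typing import Optional, Tuple

from synapse.types import Requester


async def membership_or_403(
    auth: "Auth", room_id: str, requester: Requester
) -> Tuple[str, Optional[str]]:
    # Raises UnstableSpecAuthError(403, errcode=Codes.NOT_JOINED) if the
    # requester is not, and never was, in the room.
    return await auth.check_user_in_room(
        room_id, requester, allow_departed_users=True
    )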
async def get_user_by_req(
self,
@@ -173,96 +182,69 @@ class Auth:
access_token = self.get_access_token_from_request(request)
- (
- user_id,
- device_id,
- app_service,
- ) = await self._get_appservice_user_id_and_device_id(request)
- if user_id and app_service:
- if ip_addr and self._track_appservice_user_ips:
- await self.store.insert_client_ip(
- user_id=user_id,
- access_token=access_token,
- ip=ip_addr,
- user_agent=user_agent,
- device_id="dummy-device"
- if device_id is None
- else device_id, # stubbed
- )
-
- requester = create_requester(
- user_id, app_service=app_service, device_id=device_id
+ # First check if it could be a request from an appservice
+ requester = await self._get_appservice_user(request)
+ if not requester:
+ # If not, it should be from a regular user
+ requester = await self.get_user_by_access_token(
+ access_token, allow_expired=allow_expired
)
- request.requester = user_id
- return requester
-
- user_info = await self.get_user_by_access_token(
- access_token, allow_expired=allow_expired
- )
- token_id = user_info.token_id
- is_guest = user_info.is_guest
- shadow_banned = user_info.shadow_banned
-
- # Deny the request if the user account has expired.
- if not allow_expired:
- if await self._account_validity_handler.is_user_expired(
- user_info.user_id
- ):
- # Raise the error if either an account validity module has determined
- # the account has expired, or the legacy account validity
- # implementation is enabled and determined the account has expired
- raise AuthError(
- 403,
- "User account has expired",
- errcode=Codes.EXPIRED_ACCOUNT,
- )
-
- device_id = user_info.device_id
-
- if access_token and ip_addr:
+ # Deny the request if the user account has expired.
+ # This check is only done for regular users, not appservice ones.
+ if not allow_expired:
+ if await self._account_validity_handler.is_user_expired(
+ requester.user.to_string()
+ ):
+ # Raise the error if either an account validity module has determined
+ # the account has expired, or the legacy account validity
+ # implementation is enabled and determined the account has expired
+ raise AuthError(
+ 403,
+ "User account has expired",
+ errcode=Codes.EXPIRED_ACCOUNT,
+ )
+
+ if ip_addr and (
+ not requester.app_service or self._track_appservice_user_ips
+ ):
+ # XXX(quenting): I'm 95% confident that we could skip setting the
+ # device_id to "dummy-device" for appservices, and that the only impact
+ # would be some rows which would not deduplicate in the 'user_ips'
+ # table during the transition
+ recorded_device_id = (
+ "dummy-device"
+ if requester.device_id is None and requester.app_service is not None
+ else requester.device_id
+ )
await self.store.insert_client_ip(
- user_id=user_info.token_owner,
+ user_id=requester.authenticated_entity,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
- device_id=device_id,
+ device_id=recorded_device_id,
)
+
# Track also the puppeted user client IP if enabled and the user is puppeting
if (
- user_info.user_id != user_info.token_owner
+ requester.user.to_string() != requester.authenticated_entity
and self._track_puppeted_user_ips
):
await self.store.insert_client_ip(
- user_id=user_info.user_id,
+ user_id=requester.user.to_string(),
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
- device_id=device_id,
+ device_id=requester.device_id,
)
- if is_guest and not allow_guest:
+ if requester.is_guest and not allow_guest:
raise AuthError(
403,
"Guest access not allowed",
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- # Mark the token as used. This is used to invalidate old refresh
- # tokens after some time.
- if not user_info.token_used and token_id is not None:
- await self.store.mark_access_token_as_used(token_id)
-
- requester = create_requester(
- user_info.user_id,
- token_id,
- is_guest,
- shadow_banned,
- device_id,
- app_service=app_service,
- authenticated_entity=user_info.token_owner,
- )
-
request.requester = requester
return requester
except KeyError:
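# Illustrative sketch (not part of the diff): the refactored decision flow of
# `get_user_by_req`, reduced to its essentials. The function name is
# hypothetical; the calls are the ones used above.

async def resolve_requester(auth: "Auth", request: "SynapseRequest") -> "Requester":
    access_token = auth.get_access_token_from_request(request)
    # Appservice tokens are checked first; None means "not an appservice".
    requester = await auth._get_appservice_user(request)
    if requester is None:
        # Ordinary user or guest macaroon: the lookup now returns a Requester
        # directly, marking opaque access tokens as used along the way.
        requester = await auth.get_user_by_access_token(access_token)
    # Account expiry, client-IP tracking and guest checks then all read from
    # this single Requester rather than from a TokenLookupResult.
    return requester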
@@ -299,9 +281,7 @@ class Auth:
403, "Application service has not registered this user (%s)" % user_id
)
- async def _get_appservice_user_id_and_device_id(
- self, request: Request
- ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+ async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
"""
Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request
@@ -316,15 +296,13 @@ class Auth:
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
Returns:
- 3-tuple of
- (user ID?, device ID?, application service?)
+ the application service `Requester` of that request
Postconditions:
- - If an application service is returned, so is a user ID
- - A user ID is never returned without an application service
- - A device ID is never returned without a user ID or an application service
- - The returned application service, if present, is permitted to control the
- returned user ID.
+ - The `app_service` field in the returned `Requester` is set
+ - The `user_id` field in the returned `Requester` is either the application
+ service sender or the controlled user set by the `user_id` URI parameter
+ - The returned application service is permitted to control the returned user ID.
- The returned device ID, if present, has been checked to be a valid device ID
for the returned user ID.
"""
@@ -334,12 +312,12 @@ class Auth:
self.get_access_token_from_request(request)
)
if app_service is None:
- return None, None, None
+ return None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist:
- return None, None, None
+ return None
# This will always be set by the time Twisted calls us.
assert request.args is not None
@@ -373,13 +351,15 @@ class Auth:
Codes.EXCLUSIVE,
)
- return effective_user_id, effective_device_id, app_service
+ return create_requester(
+ effective_user_id, app_service=app_service, device_id=effective_device_id
+ )
async def get_user_by_access_token(
self,
token: str,
allow_expired: bool = False,
- ) -> TokenLookupResult:
+ ) -> Requester:
"""Validate access token and get user_id from it
Args:
@@ -396,9 +376,9 @@ class Auth:
# First look in the database to see if the access token is present
# as an opaque token.
- r = await self.store.get_user_by_access_token(token)
- if r:
- valid_until_ms = r.valid_until_ms
+ user_info = await self.store.get_user_by_access_token(token)
+ if user_info:
+ valid_until_ms = user_info.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@@ -410,7 +390,20 @@ class Auth:
msg="Access token has expired", soft_logout=True
)
- return r
+ # Mark the token as used. This is used to invalidate old refresh
+ # tokens after some time.
+ await self.store.mark_access_token_as_used(user_info.token_id)
+
+ requester = create_requester(
+ user_id=user_info.user_id,
+ access_token_id=user_info.token_id,
+ is_guest=user_info.is_guest,
+ shadow_banned=user_info.shadow_banned,
+ device_id=user_info.device_id,
+ authenticated_entity=user_info.token_owner,
+ )
+
+ return requester
# If the token isn't found in the database, then it could still be a
# macaroon for a guest, so we check that here.
@@ -436,11 +429,12 @@ class Auth:
"Guest access token used for regular user"
)
- return TokenLookupResult(
+ return create_requester(
user_id=user_id,
is_guest=True,
# all guests get the same device id
device_id=GUEST_DEVICE_ID,
+ authenticated_entity=user_id,
)
except (
pymacaroons.exceptions.MacaroonException,
@@ -463,32 +457,33 @@ class Auth:
request.requester = create_requester(service.sender, app_service=service)
return service
- async def is_server_admin(self, user: UserID) -> bool:
+ async def is_server_admin(self, requester: Requester) -> bool:
"""Check if the given user is a local server admin.
Args:
- user: user to check
+ requester: The user making the request, according to the access token.
Returns:
True if the user is an admin
"""
- return await self.store.is_server_admin(user)
+ return await self.store.is_server_admin(requester.user)
- async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
+ async def check_can_change_room_list(
+ self, room_id: str, requester: Requester
+ ) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
Args:
- room_id
- user
+ room_id: The room to check.
+ requester: The user making the request, according to the access token.
"""
- is_admin = await self.is_server_admin(user)
+ is_admin = await self.is_server_admin(requester)
if is_admin:
return True
- user_id = user.to_string()
- await self.check_user_in_room(room_id, user_id)
+ await self.check_user_in_room(room_id, requester)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
@@ -507,7 +502,9 @@ class Auth:
send_level = event_auth.get_send_level(
EventTypes.CanonicalAlias, "", power_level_event
)
- user_level = event_auth.get_user_power_level(user_id, auth_events)
+ user_level = event_auth.get_user_power_level(
+ requester.user.to_string(), auth_events
+ )
return user_level >= send_level
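# Illustrative sketch (not part of the diff): the moderator check that
# `check_can_change_room_list` now ends with, isolated. It reuses the
# `event_auth` helpers called above; parameter types are elided.

from synapse import event_auth
from synapse.api.constants import EventTypes


def can_edit_published_room_list(user_id, power_level_event, auth_events) -> bool:
    # The user may edit the directory entry iff they could (in principle)
    # change the room's canonical alias.
    send_level = event_auth.get_send_level(
        EventTypes.CanonicalAlias, "", power_level_event
    )
    user_level = event_auth.get_user_power_level(user_id, auth_events)
    return user_level >= send_level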
@@ -563,17 +560,18 @@ class Auth:
return query_params[0].decode("ascii")
+ @trace
async def check_user_in_room_or_world_readable(
- self, room_id: str, user_id: str, allow_departed_users: bool = False
+ self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
Args:
- room_id: room to check
- user_id: user to check
- allow_departed_users: if True, accept users that were previously
- members but have now departed
+ room_id: The room to check.
+ requester: The user making the request, according to the access token.
+ allow_departed_users: If True, accept users that were previously
+ members but have now departed.
Returns:
Resolves to the current membership of the user in the room and the
@@ -588,7 +586,7 @@ class Auth:
# * The user is a guest user, and has joined the room
# else it will throw.
return await self.check_user_in_room(
- room_id, user_id, allow_departed_users=allow_departed_users
+ room_id, requester, allow_departed_users=allow_departed_users
)
except AuthError:
visibility = await self._storage_controllers.state.get_current_state_event(
@@ -600,8 +598,9 @@ class Auth:
== HistoryVisibility.WORLD_READABLE
):
return Membership.JOIN, None
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
- % (user_id, room_id),
+ % (requester.user, room_id),
+ errcode=Codes.NOT_JOINED,
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 2653764119..c73aea622a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -216,11 +216,11 @@ class EventContentFields:
MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
- MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
+ MSC2716_NEXT_BATCH_ID: Final = "next_batch_id"
# Used on "batch" events to indicate which insertion event it connects to
- MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
+ MSC2716_BATCH_ID: Final = "batch_id"
# For "marker" events
- MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
+ MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference"
# The authorising user for joining a restricted room.
AUTHORISING_USER: Final = "join_authorised_via_users_server"
@@ -257,7 +257,8 @@ class GuestAccess:
class ReceiptTypes:
READ: Final = "m.read"
- READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
+ READ_PRIVATE: Final = "m.read.private"
+ UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
FULLY_READ: Final = "m.fully_read"
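# Illustrative sketch (not part of the diff): during the MSC2285 transition a
# server plausibly has to treat the stable and unstable identifiers as
# equivalent. The helper name is hypothetical.

from synapse.api.constants import ReceiptTypes


def is_private_read_receipt(receipt_type: str) -> bool:
    return receipt_type in (
        ReceiptTypes.READ_PRIVATE,           # "m.read.private" (stable)
        ReceiptTypes.UNSTABLE_READ_PRIVATE,  # "org.matrix.msc2285.read.private"
    )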
@@ -268,4 +269,4 @@ class PublicRoomsFilterFields:
"""
GENERIC_SEARCH_TERM: Final = "generic_search_term"
- ROOM_TYPES: Final = "org.matrix.msc3827.room_types"
+ ROOM_TYPES: Final = "room_types"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 1c74e131f2..e6dea89c6d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -26,6 +26,7 @@ from twisted.web import http
from synapse.util import json_decoder
if typing.TYPE_CHECKING:
+ from synapse.config.homeserver import HomeServerConfig
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -80,6 +81,12 @@ class Codes(str, Enum):
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ # Part of MSC3848
+ # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
+ ALREADY_JOINED = "ORG.MATRIX.MSC3848.ALREADY_JOINED"
+ NOT_JOINED = "ORG.MATRIX.MSC3848.NOT_JOINED"
+ INSUFFICIENT_POWER = "ORG.MATRIX.MSC3848.INSUFFICIENT_POWER"
+
# The account has been suspended on the server.
# By opposition to `USER_DEACTIVATED`, this is a reversible measure
# that can possibly be appealed and reverted.
@@ -167,7 +174,7 @@ class SynapseError(CodeMessageException):
else:
self._additional_fields = dict(additional_fields)
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)
@@ -213,7 +220,7 @@ class ConsentNotGivenError(SynapseError):
)
self._consent_uri = consent_uri
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
@@ -307,6 +314,37 @@ class AuthError(SynapseError):
super().__init__(code, msg, errcode, additional_fields)
+class UnstableSpecAuthError(AuthError):
+ """An error raised when a new error code is being proposed to replace a previous one.
+ This error will return a "org.matrix.unstable.errcode" property with the new error code,
+ with the previous error code still being defined in the "errcode" property.
+
+ This error will include `org.matrix.msc3848.unstable.errcode` in the C-S error body.
+ """
+
+ def __init__(
+ self,
+ code: int,
+ msg: str,
+ errcode: str,
+ previous_errcode: str = Codes.FORBIDDEN,
+ additional_fields: Optional[dict] = None,
+ ):
+ self.previous_errcode = previous_errcode
+ super().__init__(code, msg, errcode, additional_fields)
+
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+ fields = {}
+ if config is not None and config.experimental.msc3848_enabled:
+ fields["org.matrix.msc3848.unstable.errcode"] = self.errcode
+ return cs_error(
+ self.msg,
+ self.previous_errcode,
+ **fields,
+ **self._additional_fields,
+ )
+
+
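# Illustrative sketch (not part of the diff): the error body produced by the
# class above, with and without the experimental flag. `config` stands in for
# a HomeServerConfig with `experimental.msc3848_enabled = True`.

err = UnstableSpecAuthError(403, "User not in room", errcode=Codes.NOT_JOINED)

err.error_dict(None)
# -> {"errcode": "M_FORBIDDEN", "error": "User not in room"}

err.error_dict(config)
# -> {"errcode": "M_FORBIDDEN", "error": "User not in room",
#     "org.matrix.msc3848.unstable.errcode": "ORG.MATRIX.MSC3848.NOT_JOINED"}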
class InvalidClientCredentialsError(SynapseError):
"""An error raised when there was a problem with the authorisation credentials
in a client request.
@@ -338,8 +376,8 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout
- def error_dict(self) -> "JsonDict":
- d = super().error_dict()
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+ d = super().error_dict(config)
d["soft_logout"] = self._soft_logout
return d
@@ -362,7 +400,7 @@ class ResourceLimitError(SynapseError):
self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode)
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(
self.msg,
self.errcode,
@@ -397,7 +435,7 @@ class InvalidCaptchaError(SynapseError):
super().__init__(code, msg, errcode)
self.error_url = error_url
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url)
@@ -414,7 +452,7 @@ class LimitExceededError(SynapseError):
super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@@ -429,7 +467,7 @@ class RoomKeysVersionError(SynapseError):
super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
self.current_version = current_version
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, current_version=self.current_version)
@@ -469,7 +507,7 @@ class IncompatibleRoomVersionError(SynapseError):
self._room_version = room_version
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version)
@@ -515,7 +553,7 @@ class UnredactedContentDeletedError(SynapseError):
)
self.content_keep_ms = content_keep_ms
- def error_dict(self) -> "JsonDict":
+ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
extra = {}
if self.content_keep_ms is not None:
extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms}
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index f43965c1c8..044c7d4926 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,7 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import RateLimitConfig
+from synapse.config.ratelimiting import RatelimitSettings
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -314,8 +314,8 @@ class RequestRatelimiter:
self,
store: DataStore,
clock: Clock,
- rc_message: RateLimitConfig,
- rc_admin_redaction: Optional[RateLimitConfig],
+ rc_message: RatelimitSettings,
+ rc_admin_redaction: Optional[RatelimitSettings],
):
self.store = store
self.clock = clock
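# Illustrative sketch (not part of the diff): the renamed class in use. Only
# the name changes (RateLimitConfig -> RatelimitSettings); construction and
# the derived attributes are unchanged.

from synapse.config.ratelimiting import RatelimitSettings

rc_message = RatelimitSettings({}, defaults={"per_second": 0.2, "burst_count": 10.0})
assert rc_message.per_second == 0.2
assert rc_message.burst_count == 10.0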
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 00e81b3afc..a0e4ab6db6 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -269,24 +269,6 @@ class RoomVersions:
msc3787_knock_restricted_join_rule=False,
msc3667_int_only_power_levels=False,
)
- MSC2716v3 = RoomVersion(
- "org.matrix.msc2716v3",
- RoomDisposition.UNSTABLE,
- EventFormatVersions.V3,
- StateResolutionVersions.V2,
- enforce_key_validity=True,
- special_case_aliases_auth=False,
- strict_canonicaljson=True,
- limit_notifications_power_levels=True,
- msc2176_redaction_rules=False,
- msc3083_join_rules=False,
- msc3375_redaction_rules=False,
- msc2403_knocking=True,
- msc2716_historical=True,
- msc2716_redactions=True,
- msc3787_knock_restricted_join_rule=False,
- msc3667_int_only_power_levels=False,
- )
MSC3787 = RoomVersion(
"org.matrix.msc3787",
RoomDisposition.UNSTABLE,
@@ -323,6 +305,24 @@ class RoomVersions:
msc3787_knock_restricted_join_rule=True,
msc3667_int_only_power_levels=True,
)
+ MSC2716v4 = RoomVersion(
+ "org.matrix.msc2716v4",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
+ msc2403_knocking=True,
+ msc2716_historical=True,
+ msc2716_redactions=True,
+ msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
+ )
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
@@ -338,9 +338,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V7,
RoomVersions.V8,
RoomVersions.V9,
- RoomVersions.MSC2716v3,
RoomVersions.MSC3787,
RoomVersions.V10,
+ RoomVersions.MSC2716v4,
)
}
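# Illustrative sketch (not part of the diff): the net effect on room-version
# lookup; the unstable MSC2716 room version moves from v3 to v4.

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

assert KNOWN_ROOM_VERSIONS.get("org.matrix.msc2716v3") is None
v4 = KNOWN_ROOM_VERSIONS["org.matrix.msc2716v4"]
assert v4.msc2716_historical and v4.msc2716_redactions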
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 87f82bd9a5..8a583d3ec6 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -28,19 +28,22 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.events import EventBase
from synapse.handlers.admin import ExfiltrationWriter
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
+)
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.types import StateMap
from synapse.util import SYNAPSE_VERSION
from synapse.util.logcontext import LoggingContext
@@ -49,16 +52,17 @@ logger = logging.getLogger("synapse.app.admin_cmd")
class AdminCmdSlavedStore(
- SlavedReceiptsStore,
- SlavedAccountDataStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
SlavedFilteringStore,
- SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedPushRuleStore,
SlavedEventStore,
- BaseSlavedStore,
+ TagsWorkerStore,
+ DeviceInboxWorkerStore,
+ AccountDataWorkerStore,
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
+ RegistrationWorkerStore,
+ ReceiptsWorkerStore,
RoomWorkerStore,
):
def __init__(
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 4a987fb759..30e21d9707 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -48,20 +48,12 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client import (
account_data,
@@ -100,8 +92,15 @@ from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
+)
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
from synapse.storage.databases.main.lock import LockStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
@@ -110,11 +109,15 @@ from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import PresenceStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.room_batch import RoomBatchStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -227,11 +230,11 @@ class GenericWorkerSlavedStore(
UIAuthWorkerStore,
EndToEndRoomKeyStore,
PresenceStore,
- SlavedDeviceInboxStore,
+ DeviceInboxWorkerStore,
SlavedDeviceStore,
- SlavedReceiptsStore,
SlavedPushRuleStore,
- SlavedAccountDataStore,
+ TagsWorkerStore,
+ AccountDataWorkerStore,
SlavedPusherStore,
CensorEventsStore,
ClientIpWorkerStore,
@@ -239,19 +242,20 @@ class GenericWorkerSlavedStore(
SlavedKeyStore,
RoomWorkerStore,
RoomBatchStore,
- DirectoryStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedProfileStore,
+ DirectoryWorkerStore,
+ ApplicationServiceTransactionWorkerStore,
+ ApplicationServiceWorkerStore,
+ ProfileWorkerStore,
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
ServerMetricsStore,
+ ReceiptsWorkerStore,
+ RegistrationWorkerStore,
SearchStore,
TransactionWorkerStore,
LockStore,
SessionStore,
- BaseSlavedStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
@@ -437,6 +441,13 @@ def start(config_options: List[str]) -> None:
"synapse.app.user_dir",
)
+ if config.experimental.faster_joins_enabled:
+ raise ConfigError(
+ "You have enabled the experimental `faster_joins` config option, but it is "
+ "not compatible with worker deployments yet. Please disable `faster_joins` "
+ "or run Synapse as a single process deployment instead."
+ )
+
synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6bafa7d3f3..68993d91a9 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -219,7 +219,10 @@ class SynapseHomeServer(HomeServer):
resources.update({"/_matrix/consent": consent_resource})
if name == "federation":
- resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+ federation_resource: Resource = TransportLayerServer(self)
+ if compress:
+ federation_resource = gz_wrap(federation_resource)
+ resources.update({FEDERATION_PREFIX: federation_resource})
if name == "openid":
resources.update(
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index d1335e77cd..b3972ede96 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -23,7 +23,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'account_validity' section. Support for this setting has been deprecated and will be
removed in a future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 3ead80d985..a3af35b7c4 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -52,7 +52,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'email' section. Support for this setting has been deprecated and will be removed in a
future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
@@ -85,14 +85,19 @@ class EmailConfig(Config):
if email_config is None:
email_config = {}
+ self.force_tls = email_config.get("force_tls", False)
self.email_smtp_host = email_config.get("smtp_host", "localhost")
- self.email_smtp_port = email_config.get("smtp_port", 25)
+ self.email_smtp_port = email_config.get(
+ "smtp_port", 465 if self.force_tls else 25
+ )
self.email_smtp_user = email_config.get("smtp_user", None)
self.email_smtp_pass = email_config.get("smtp_pass", None)
self.require_transport_security = email_config.get(
"require_transport_security", False
)
self.enable_smtp_tls = email_config.get("enable_tls", True)
+ if self.force_tls and not self.enable_smtp_tls:
+ raise ConfigError("email.force_tls requires email.enable_tls to be true")
if self.require_transport_security and not self.enable_smtp_tls:
raise ConfigError(
"email.require_transport_security requires email.enable_tls to be true"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index ee443cea00..c1ff417539 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
# MSC2716 (importing historical messages)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
- # MSC2285 (private read receipts)
+ # MSC2285 (unstable private read receipts)
self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
# MSC3244 (room version capabilities)
@@ -88,5 +88,8 @@ class ExperimentalConfig(Config):
# MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
- # MSC3827: Filtering of /publicRooms by room type
- self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False)
+ # MSC3848: Introduce errcodes for specific event sending failures
+ self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+ # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
+ self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5a91917b4a..1ed001e105 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -21,7 +21,7 @@ from synapse.types import JsonDict
from ._base import Config
-class RateLimitConfig:
+class RatelimitSettings:
def __init__(
self,
config: Dict[str, float],
@@ -34,7 +34,7 @@ class RateLimitConfig:
@attr.s(auto_attribs=True)
-class FederationRateLimitConfig:
+class FederationRatelimitSettings:
window_size: int = 1000
sleep_limit: int = 10
sleep_delay: int = 500
@@ -50,11 +50,11 @@ class RatelimitConfig(Config):
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
if "rc_message" in config:
- self.rc_message = RateLimitConfig(
+ self.rc_message = RatelimitSettings(
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
)
else:
- self.rc_message = RateLimitConfig(
+ self.rc_message = RatelimitSettings(
{
"per_second": config.get("rc_messages_per_second", 0.2),
"burst_count": config.get("rc_message_burst_count", 10.0),
@@ -64,9 +64,9 @@ class RatelimitConfig(Config):
# Load the new-style federation config, if it exists. Otherwise, fall
# back to the old method.
if "rc_federation" in config:
- self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
+ self.rc_federation = FederationRatelimitSettings(**config["rc_federation"])
else:
- self.rc_federation = FederationRateLimitConfig(
+ self.rc_federation = FederationRatelimitSettings(
**{
k: v
for k, v in {
@@ -80,17 +80,17 @@ class RatelimitConfig(Config):
}
)
- self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_registration = RatelimitSettings(config.get("rc_registration", {}))
- self.rc_registration_token_validity = RateLimitConfig(
+ self.rc_registration_token_validity = RatelimitSettings(
config.get("rc_registration_token_validity", {}),
defaults={"per_second": 0.1, "burst_count": 5},
)
rc_login_config = config.get("rc_login", {})
- self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
- self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
- self.rc_login_failed_attempts = RateLimitConfig(
+ self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {}))
+ self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {}))
+ self.rc_login_failed_attempts = RatelimitSettings(
rc_login_config.get("failed_attempts", {})
)
@@ -101,20 +101,20 @@ class RatelimitConfig(Config):
rc_admin_redaction = config.get("rc_admin_redaction")
self.rc_admin_redaction = None
if rc_admin_redaction:
- self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+ self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction)
- self.rc_joins_local = RateLimitConfig(
+ self.rc_joins_local = RatelimitSettings(
config.get("rc_joins", {}).get("local", {}),
defaults={"per_second": 0.1, "burst_count": 10},
)
- self.rc_joins_remote = RateLimitConfig(
+ self.rc_joins_remote = RatelimitSettings(
config.get("rc_joins", {}).get("remote", {}),
defaults={"per_second": 0.01, "burst_count": 10},
)
# Track the rate of joins to a given room. If there are too many, temporarily
# prevent local joins and remote joins via this server.
- self.rc_joins_per_room = RateLimitConfig(
+ self.rc_joins_per_room = RatelimitSettings(
config.get("rc_joins_per_room", {}),
defaults={"per_second": 1, "burst_count": 10},
)
@@ -124,31 +124,31 @@ class RatelimitConfig(Config):
# * For requests received over federation this is keyed by the origin.
#
# Note that this isn't exposed in the configuration as it is obscure.
- self.rc_key_requests = RateLimitConfig(
+ self.rc_key_requests = RatelimitSettings(
config.get("rc_key_requests", {}),
defaults={"per_second": 20, "burst_count": 100},
)
- self.rc_3pid_validation = RateLimitConfig(
+ self.rc_3pid_validation = RatelimitSettings(
config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5},
)
- self.rc_invites_per_room = RateLimitConfig(
+ self.rc_invites_per_room = RatelimitSettings(
config.get("rc_invites", {}).get("per_room", {}),
defaults={"per_second": 0.3, "burst_count": 10},
)
- self.rc_invites_per_user = RateLimitConfig(
+ self.rc_invites_per_user = RatelimitSettings(
config.get("rc_invites", {}).get("per_user", {}),
defaults={"per_second": 0.003, "burst_count": 5},
)
- self.rc_invites_per_issuer = RateLimitConfig(
+ self.rc_invites_per_issuer = RatelimitSettings(
config.get("rc_invites", {}).get("per_issuer", {}),
defaults={"per_second": 0.3, "burst_count": 10},
)
- self.rc_third_party_invite = RateLimitConfig(
+ self.rc_third_party_invite = RatelimitSettings(
config.get("rc_third_party_invite", {}),
defaults={
"per_second": self.rc_message.per_second,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 685a0423c5..a888d976f2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -60,7 +60,6 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
if "email" in account_threepid_delegates:
raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
- # self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 2178cbf983..a452cc3a49 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -26,7 +26,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'sso' section. Support for this setting has been deprecated and will be removed in a
future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 7520647d1e..23b799ac32 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict
+from synapse.logging.opentracing import trace
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ logger = logging.getLogger(__name__)
Hasher = Callable[[bytes], "hashlib._Hash"]
+@trace
def check_event_content_hash(
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> bool:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 965cb265da..389b0c5d53 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -30,7 +30,13 @@ from synapse.api.constants import (
JoinRules,
Membership,
)
-from synapse.api.errors import AuthError, EventSizeError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ EventSizeError,
+ SynapseError,
+ UnstableSpecAuthError,
+)
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -291,7 +297,11 @@ def check_state_dependent_auth_rules(
invite_level = get_named_level(auth_dict, "invite", 0)
if user_level < invite_level:
- raise AuthError(403, "You don't have permission to invite users")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to invite users",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
else:
logger.debug("Allowing! %s", event)
return
@@ -474,7 +484,11 @@ def _is_membership_change_allowed(
return
if not caller_in_room: # caller isn't joined
- raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
+ raise UnstableSpecAuthError(
+ 403,
+ "%s not in room %s." % (event.user_id, event.room_id),
+ errcode=Codes.NOT_JOINED,
+ )
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
@@ -484,10 +498,18 @@ def _is_membership_change_allowed(
if target_banned:
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room.
- raise AuthError(403, "%s is already in the room." % target_user_id)
+ raise UnstableSpecAuthError(
+ 403,
+ "%s is already in the room." % target_user_id,
+ errcode=Codes.ALREADY_JOINED,
+ )
else:
if user_level < invite_level:
- raise AuthError(403, "You don't have permission to invite users")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to invite users",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and:
# * They are not banned.
@@ -549,15 +571,27 @@ def _is_membership_change_allowed(
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
- raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot unban user %s." % (target_user_id,),
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif target_user_id != event.user_id:
kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
- raise AuthError(403, "You cannot kick user %s." % target_user_id)
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot kick user %s." % target_user_id,
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
- raise AuthError(403, "You don't have permission to ban")
+ raise UnstableSpecAuthError(
+ 403,
+ "You don't have permission to ban",
+ errcode=Codes.INSUFFICIENT_POWER,
+ )
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK and (
not room_version.msc3787_knock_restricted_join_rule
@@ -567,7 +601,11 @@ def _is_membership_change_allowed(
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")
elif target_in_room:
- raise AuthError(403, "You cannot knock on a room you are already in")
+ raise UnstableSpecAuthError(
+ 403,
+ "You cannot knock on a room you are already in",
+ errcode=Codes.ALREADY_JOINED,
+ )
elif caller_invited:
raise AuthError(403, "You are already invited to this room")
elif target_banned:
@@ -638,10 +676,11 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"You don't have permission to post that to the room. "
+ "user_level (%d) < send_level (%d)" % (user_level, send_level),
+ errcode=Codes.INSUFFICIENT_POWER,
)
# Check state_key
@@ -716,9 +755,10 @@ def check_historical(
historical_level = get_named_level(auth_events, "historical", 100)
if user_level < historical_level:
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
'You don\'t have permission to send historical-related events ("insertion", "batch", and "marker")',
+ errcode=Codes.INSUFFICIENT_POWER,
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index b700cbbfa1..d3c8083e4a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
import attr
from frozendict import frozendict
-from typing_extensions import Literal
from synapse.appservice import ApplicationService
from synapse.events import EventBase
@@ -33,7 +32,7 @@ class EventContext:
Holds information relevant to persisting an event
Attributes:
- rejected: A rejection reason if the event was rejected, else False
+ rejected: A rejection reason if the event was rejected, else None
_state_group: The ID of the state group for this event. Note that state events
are persisted with a state group which includes the new event, so this is
@@ -85,7 +84,7 @@ class EventContext:
"""
_storage: "StorageControllers"
- rejected: Union[Literal[False], str] = False
+ rejected: Optional[str] = None
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None
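# Illustrative sketch (not part of the diff): `rejected` switches from
# `Union[Literal[False], str]` to `Optional[str]`. Both sentinels are falsy,
# so `if context.rejected:` behaves identically before and after; only
# identity comparisons against False would need updating.

def is_rejected(context: "EventContext") -> bool:
    return context.rejected is not None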
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 4a3bfb38f1..623a2c71ea 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -32,6 +32,7 @@ from typing_extensions import Literal
import synapse
from synapse.api.errors import Codes
+from synapse.logging.opentracing import trace
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
@@ -378,6 +379,7 @@ class SpamChecker:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
+ @trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
) -> Union[Tuple[Codes, JsonDict], str]:
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ac91c5eb57..71853caad8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -161,7 +161,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
- add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
+ add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2522bf78fc..4269a98db2 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict
+from synapse.logging.opentracing import log_kv, trace
from synapse.types import JsonDict, get_domain_from_id
if TYPE_CHECKING:
@@ -55,6 +56,7 @@ class FederationBase:
self._clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
+ @trace
async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
) -> EventBase:
@@ -97,17 +99,36 @@ class FederationBase:
"Event %s seems to have been redacted; using our redacted copy",
pdu.event_id,
)
+ log_kv(
+ {
+ "message": "Event seems to have been redacted; using our redacted copy",
+ "event_id": pdu.event_id,
+ }
+ )
else:
logger.warning(
"Event %s content has been tampered, redacting",
pdu.event_id,
)
+ log_kv(
+ {
+ "message": "Event content has been tampered, redacting",
+ "event_id": pdu.event_id,
+ }
+ )
return redacted_event
spam_check = await self.spam_checker.check_event_for_spam(pdu)
if spam_check != self.spam_checker.NOT_SPAM:
logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+ log_kv(
+ {
+ "message": "Event contains spam, redacting (to save disk space) "
+ "as well as soft-failing (to stop using the event in prev_events)",
+ "event_id": pdu.event_id,
+ }
+ )
# we redact (to save disk space) as well as soft-failing (to stop
# using the event in prev_events).
redacted_event = prune_event(pdu)
@@ -117,6 +138,7 @@ class FederationBase:
return pdu
+@trace
async def _check_sigs_on_pdu(
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
) -> None:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 842f5327c2..7ee2974bb1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -233,6 +234,8 @@ class FederationClient(FederationBase):
destination, content, timeout
)
+ @trace
+ @tag_args
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]:
@@ -335,6 +338,8 @@ class FederationClient(FederationBase):
return None
+ @trace
+ @tag_args
async def get_pdu(
self,
destinations: Iterable[str],
@@ -403,9 +408,9 @@ class FederationClient(FederationBase):
# Prime the cache
self._get_pdu_cache[event.event_id] = event
- # FIXME: We should add a `break` here to avoid calling every
- # destination after we already found a PDU (will follow-up
- # in a separate PR)
+ # Now that we have an event, we can break out of this
+ # loop and stop asking other destinations.
+ break
except SynapseError as e:
logger.info(
@@ -446,6 +451,8 @@ class FederationClient(FederationBase):
return event_copy
+ @trace
+ @tag_args
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> Tuple[List[str], List[str]]:
@@ -465,6 +472,23 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "state_event_ids",
+ str(state_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "state_event_ids.length",
+ str(len(state_event_ids)),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "auth_event_ids",
+ str(auth_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "auth_event_ids.length",
+ str(len(auth_event_ids)),
+ )
+
if not isinstance(state_event_ids, list) or not isinstance(
auth_event_ids, list
):
@@ -472,6 +496,8 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ @trace
+ @tag_args
async def get_room_state(
self,
destination: str,
@@ -531,6 +557,7 @@ class FederationClient(FederationBase):
return valid_state_events, valid_auth_events
+ @trace
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
@@ -560,11 +587,15 @@ class FederationClient(FederationBase):
Returns:
A list of PDUs that have valid signatures and hashes.
"""
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "pdus.length",
+ str(len(pdus)),
+ )
# We limit how many PDUs we check at once, as if we try to do hundreds
# of thousands of PDUs at once we see large memory spikes.
- valid_pdus = []
+ valid_pdus: List[EventBase] = []
async def _execute(pdu: EventBase) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -580,6 +611,8 @@ class FederationClient(FederationBase):
return valid_pdus
+ @trace
+ @tag_args
async def _check_sigs_and_hash_and_fetch_one(
self,
pdu: EventBase,
@@ -612,16 +645,27 @@ class FederationClient(FederationBase):
except InvalidEventSignatureError as e:
logger.warning(
"Signature on retrieved event %s was invalid (%s). "
- "Checking local store/orgin server",
+ "Checking local store/origin server",
pdu.event_id,
e,
)
+ log_kv(
+ {
+ "message": "Signature on retrieved event was invalid. "
+ "Checking local store/origin server",
+ "event_id": pdu.event_id,
+ "InvalidEventSignatureError": e,
+ }
+ )
# Check local db.
res = await self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
+ # If the PDU fails its signature check and we don't have it in our
+ # database, we then request it from sender's server (if that is not the
+ # same as `origin`).
pdu_origin = get_domain_from_id(pdu.sender)
if not res and pdu_origin != origin:
try:
@@ -725,6 +769,12 @@ class FederationClient(FederationBase):
if failover_errcodes is None:
failover_errcodes = ()
+ if not destinations:
+ # Give a bit of a clearer message if no servers were specified at all.
+ raise SynapseError(
+ 502, f"Failed to {description} via any server: No servers specified."
+ )
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -774,7 +824,7 @@ class FederationClient(FederationBase):
"Failed to %s via %s", description, destination, exc_info=True
)
- raise SynapseError(502, "Failed to %s via any server" % (description,))
+ raise SynapseError(502, f"Failed to {description} via any server")
async def make_membership_event(
self,
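# Illustrative sketch (not part of the diff): the new early exit, isolated.
# An empty destination list previously fell through the retry loop and
# surfaced only the generic "Failed to ... via any server" error.

from typing import Collection

from synapse.api.errors import SynapseError


def assert_destinations(destinations: Collection[str], description: str) -> None:
    if not destinations:
        raise SynapseError(
            502, f"Failed to {description} via any server: No servers specified."
        )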
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ae550d3f4d..75fbc6073d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -61,7 +61,12 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
-from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.opentracing import (
+ log_kv,
+ start_active_span_from_edu,
+ tag_args,
+ trace,
+)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -469,7 +474,7 @@ class FederationServer(FederationBase):
)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
+ pdu_results[event_id] = e.error_dict(self.hs.config)
return
for pdu in pdus_by_room[room_id]:
@@ -547,6 +552,8 @@ class FederationServer(FederationBase):
return 200, resp
+ @trace
+ @tag_args
async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
@@ -569,6 +576,8 @@ class FederationServer(FederationBase):
return 200, resp
+ @trace
+ @tag_args
async def _on_state_ids_request_compute(
self, room_id: str, event_id: str
) -> JsonDict:
@@ -843,8 +852,25 @@ class FederationServer(FederationBase):
Codes.BAD_JSON,
)
+ # Note that get_room_version raises if the room does not exist here.
room_version = await self.store.get_room_version(room_id)
+ if await self.store.is_partial_state_room(room_id):
+ # If our server is still only partially joined, we can't give a complete
+ # response to /send_join, /send_knock or /send_leave.
+ # This is because we will not be able to provide the server list (for partial
+ # joins) or the full state (for full joins).
+ # Return a 404 as we would if we weren't in the room at all.
+ logger.info(
+ f"Rejecting /send_{membership_type} to %s because it's a partial state room",
+ room_id,
+ )
+ raise SynapseError(
+ 404,
+ f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
+ errcode=Codes.NOT_FOUND,
+ )
+
if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
raise SynapseError(
403,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 84100a5a52..bb0f8d6b7b 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -309,7 +309,7 @@ class BaseFederationServlet:
raise
# update the active opentracing span with the authenticated entity
- set_tag("authenticated_entity", origin)
+ set_tag("authenticated_entity", str(origin))
# if the origin is authenticated and whitelisted, use its span context
# as the parent.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3d83236b0c..0327fc57a4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -280,7 +280,7 @@ class AuthHandler:
that it isn't stolen by re-authenticating them.
Args:
- requester: The user, as given by the access token
+ requester: The user making the request, according to the access token.
request: The request sent by the client.
@@ -565,7 +565,7 @@ class AuthHandler:
except LoginError as e:
# this step failed. Merge the error dict into the response
# so that the client can have another go.
- errordict = e.error_dict()
+ errordict = e.error_dict(self.hs.config)
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
@@ -1435,20 +1435,25 @@ class AuthHandler:
access_token: access token to be deleted
"""
- user_info = await self.auth.get_user_by_access_token(access_token)
+ token = await self.store.get_user_by_access_token(access_token)
+ if not token:
+ # At this point, the token should already have been fetched once by
+ # the caller, so this should not happen, unless there is a race
+ # condition between two delete requests.
+ raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
await self.store.delete_access_token(access_token)
# see if any modules want to know about this
await self.password_auth_provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
+ user_id=token.user_id,
+ device_id=token.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
- if user_info.token_id is not None:
+ if token.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_info.user_id, (user_info.token_id,)
+ token.user_id, (token.token_id,)
)
async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..f5c586f657 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,8 +119,8 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- set_tag("device", device)
- set_tag("ips", ips)
+ set_tag("device", str(device))
+ set_tag("ips", str(ips))
return device
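These `str(...)` wrappers (here and in the e2e_keys and federation hunks below) exist because OpenTracing tag values are expected to be primitives; passing dicts or rich objects can break span serialisation. The rule being applied, as a sketch:

    def _tag_value(value: object) -> object:
        # OpenTracing tags should be str/bool/int/float; coerce anything else.
        if isinstance(value, (str, bool, int, float)):
            return value
        return str(value)

    # e.g. set_tag("device", _tag_value(device)) rather than passing the dict.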
@@ -170,7 +171,7 @@ class DeviceWorkerHandler:
"""
set_tag("user_id", user_id)
- set_tag("from_token", from_token)
+ set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -747,7 +748,13 @@ def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+ device.update(
+ {
+ "last_seen_user_agent": ip.get("user_agent"),
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ }
+ )
class DeviceListUpdater:
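To illustrate the new MSC3852 field, a self-contained sketch of the enrichment with made-up values:

    device = {"user_id": "@alice:example.com", "device_id": "ABCDEFGH"}
    client_ips = {
        ("@alice:example.com", "ABCDEFGH"): {
            "user_agent": "Element/1.11 (Android)",
            "last_seen": 1_660_000_000_000,
            "ip": "203.0.113.5",
        }
    }

    ip = client_ips.get((device["user_id"], device["device_id"]), {})
    device.update(
        {
            "last_seen_user_agent": ip.get("user_agent"),
            "last_seen_ts": ip.get("last_seen"),
            "last_seen_ip": ip.get("ip"),
        }
    )
    assert device["last_seen_user_agent"] == "Element/1.11 (Android)"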
@@ -795,7 +802,7 @@ class DeviceListUpdater:
"""
set_tag("origin", origin)
- set_tag("edu_content", edu_content)
+ set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 09a7a4b238..948f66a94d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService
from synapse.module_api import NOT_SPAM
from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -133,7 +133,7 @@ class DirectoryHandler:
else:
# Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room).
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = await self.store.get_rooms_for_user(user_id)
@@ -197,7 +197,7 @@ class DirectoryHandler:
user_id = requester.user.to_string()
try:
- can_delete = await self._user_can_delete_alias(room_alias, user_id)
+ can_delete = await self._user_can_delete_alias(room_alias, requester)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@@ -400,7 +400,9 @@ class DirectoryHandler:
# either no interested services, or no service with an exclusive lock
return True
- async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
+ async def _user_can_delete_alias(
+ self, alias: RoomAlias, requester: Requester
+ ) -> bool:
"""Determine whether a user can delete an alias.
One of the following must be true:
@@ -413,7 +415,7 @@ class DirectoryHandler:
"""
creator = await self.store.get_room_alias_creator(alias.to_string())
- if creator == user_id:
+ if creator == requester.user.to_string():
return True
# Resolve the alias to the corresponding room.
@@ -422,9 +424,7 @@ class DirectoryHandler:
if not room_id:
return False
- return await self.auth.check_can_change_room_list(
- room_id, UserID.from_string(user_id)
- )
+ return await self.auth.check_can_change_room_list(room_id, requester)
async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
@@ -463,7 +463,7 @@ class DirectoryHandler:
raise SynapseError(400, "Unknown room")
can_change_room_list = await self.auth.check_can_change_room_list(
- room_id, requester.user
+ room_id, requester
)
if not can_change_room_list:
raise AuthError(
@@ -528,10 +528,8 @@ class DirectoryHandler:
Get a list of the aliases that currently point to this room on this server
"""
# allow access to server admins and current members of the room
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
if not is_admin:
- await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string()
- )
+ await self.auth.check_user_in_room_or_world_readable(room_id, requester)
return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84c28c480e..c938339ddd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -138,8 +138,8 @@ class E2eKeysHandler:
else:
remote_queries[user_id] = device_ids
- set_tag("local_key_query", local_query)
- set_tag("remote_key_query", remote_queries)
+ set_tag("local_key_query", str(local_query))
+ set_tag("remote_key_query", str(remote_queries))
# First get local devices.
# A map of destination -> failure response.
@@ -343,7 +343,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
- set_tag("reason", failure)
+ set_tag("reason", str(failure))
return
@@ -405,7 +405,7 @@ class E2eKeysHandler:
Returns:
A map from user_id -> device_id -> device details
"""
- set_tag("local_query", query)
+ set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []
result_dict: Dict[str, Dict[str, dict]] = {}
@@ -477,8 +477,8 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
- set_tag("local_key_query", local_query)
- set_tag("remote_key_query", remote_queries)
+ set_tag("local_key_query", str(local_query))
+ set_tag("remote_key_query", str(remote_queries))
results = await self.store.claim_e2e_one_time_keys(local_query)
@@ -508,7 +508,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
- set_tag("reason", failure)
+ set_tag("reason", str(failure))
await make_deferred_yieldable(
defer.gatherResults(
@@ -611,7 +611,7 @@ class E2eKeysHandler:
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
- set_tag("one_time_key_counts", result)
+ set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 446f509bdc..28dc08c22a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
from typing_extensions import Literal
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
user_id, version, room_id, session_id
)
- log_kv(results)
+ log_kv(cast(JsonDict, results))
return results
@trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 11a005f0bf..4bb4d09d4a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,6 +32,7 @@ from typing import (
)
import attr
+from prometheus_client import Histogram
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -59,6 +60,7 @@ from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import NOT_SPAM
from synapse.replication.http.federation import (
@@ -78,6 +80,29 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Added to debug performance and track progress on optimizations
+backfill_processing_before_timer = Histogram(
+ "synapse_federation_backfill_processing_before_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.1,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 15.0,
+ 20.0,
+ 30.0,
+ 40.0,
+ 60.0,
+ 80.0,
+ "+Inf",
+ ),
+)
+
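A minimal sketch of how such a histogram is fed (illustrative names; the buckets above skew towards multi-second values because backfill is known to be slow):

    import time
    from prometheus_client import Histogram

    example_timer = Histogram(
        "example_processing_time_seconds",
        "sec",
        buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, "+Inf"),
    )

    start = time.monotonic()
    sum(range(1000))  # stand-in for the measured work
    example_timer.observe(time.monotonic() - start)

    # Equivalently, as a context manager:
    with example_timer.time():
        sum(range(1000))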
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
@@ -137,6 +162,7 @@ class FederationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -180,6 +206,7 @@ class FederationHandler:
"resume_sync_partial_state_room", self._resume_sync_partial_state_room
)
+ @trace
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -195,12 +222,39 @@ class FederationHandler:
return. This is used as part of the heuristic to decide if we
should back paginate.
"""
+ # Starting the processing time here so we can include the room backfill
+ # linearizer lock queue in the timing
+ processing_start_time = self.clock.time_msec()
+
async with self._room_backfill.queue(room_id):
- return await self._maybe_backfill_inner(room_id, current_depth, limit)
+ return await self._maybe_backfill_inner(
+ room_id,
+ current_depth,
+ limit,
+ processing_start_time=processing_start_time,
+ )
async def _maybe_backfill_inner(
- self, room_id: str, current_depth: int, limit: int
+ self,
+ room_id: str,
+ current_depth: int,
+ limit: int,
+ *,
+ processing_start_time: int,
) -> bool:
+ """
+ Checks whether the `current_depth` is at or approaching any backfill
+ points in the room and if so, will backfill. We only care about
+ checking backfill points that happened before the `current_depth`
+ (meaning less than or equal to the `current_depth`).
+
+ Args:
+ room_id: The room to backfill in.
+ current_depth: The depth to check at for any upcoming backfill points.
+ limit: The max number of events to request from the remote federated server.
+ processing_start_time: The time when `maybe_backfill` started
+ processing. Only used for timing.
+ """
backwards_extremities = [
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
@@ -368,6 +422,14 @@ class FederationHandler:
logger.debug(
"_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "extremities_to_request",
+ str(extremities_to_request),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+ str(len(extremities_to_request)),
+ )
# Now we need to decide which hosts to hit first.
@@ -423,6 +485,11 @@ class FederationHandler:
return False
+ processing_end_time = self.clock.time_msec()
+ backfill_processing_before_timer.observe(
+ (processing_end_time - processing_start_time) / 1000
+ )
+
success = await try_backfill(likely_domains)
if success:
return True
@@ -546,9 +613,9 @@ class FederationHandler:
)
if ret.partial_state:
- # TODO(faster_joins): roll this back if we don't manage to start the
- # background resync (eg process_remote_join fails)
- # https://github.com/matrix-org/synapse/issues/12998
+ # Mark the room as having partial state.
+ # The background process is responsible for unmarking this flag,
+ # even if the join fails.
await self.store.store_partial_state_room(room_id, ret.servers_in_room)
try:
@@ -574,17 +641,21 @@ class FederationHandler:
room_id,
)
raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
-
- if ret.partial_state:
- # Kick off the process of asynchronously fetching the state for this
- # room.
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
- initial_destination=origin,
- other_destinations=ret.servers_in_room,
- room_id=room_id,
- )
+ finally:
+ # Always kick off the background process that asynchronously fetches
+ # state for the room.
+ # If the join failed, the background process is responsible for
+ # cleaning up, including unmarking the room as a partial state room.
+ if ret.partial_state:
+ # Kick off the process of asynchronously fetching the state for this
+ # room.
+ run_as_background_process(
+ desc="sync_partial_state_room",
+ func=self._sync_partial_state_room,
+ initial_destination=origin,
+ other_destinations=ret.servers_in_room,
+ room_id=room_id,
+ )
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
@@ -748,6 +819,23 @@ class FederationHandler:
# (and return a 404 otherwise)
room_version = await self.store.get_room_version(room_id)
+ if await self.store.is_partial_state_room(room_id):
+ # If our server is still only partially joined, we can't give a complete
+ # response to /make_join, so return a 404 as we would if we weren't in the
+ # room at all.
+ # The main reason we can't respond properly is that we need to know about
+ # the auth events for the join event that we would return.
+ # We also should not bother entertaining the /make_join since we cannot
+ # handle the /send_join.
+ logger.info(
+ "Rejecting /make_join to %s because it's a partial state room", room_id
+ )
+ raise SynapseError(
+ 404,
+ "Unable to handle /make_join right now; this server is not fully joined.",
+ errcode=Codes.NOT_FOUND,
+ )
+
# now check that we are *still* in the room
is_in_room = await self._event_auth_handler.check_host_in_room(
room_id, self.server_name
@@ -1071,6 +1159,8 @@ class FederationHandler:
return event
+ @trace
+ @tag_args
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1552,15 +1642,16 @@ class FederationHandler:
# Make an infinite iterator of destinations to try. Once we find a working
# destination, we'll stick with it until it flakes.
+ destinations: Collection[str]
if initial_destination is not None:
# Move `initial_destination` to the front of the list.
destinations = list(other_destinations)
if initial_destination in destinations:
destinations.remove(initial_destination)
destinations = [initial_destination] + destinations
- destination_iter = itertools.cycle(destinations)
else:
- destination_iter = itertools.cycle(other_destinations)
+ destinations = other_destinations
+ destination_iter = itertools.cycle(destinations)
# `destination` is the current remote homeserver we're pulling from.
destination = next(destination_iter)
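The refactor collapses the two `itertools.cycle` calls into one; a standalone sketch of the resulting iteration order:

    import itertools
    from typing import Collection, Optional

    def make_destination_iter(
        initial_destination: Optional[str], other_destinations: Collection[str]
    ):
        if initial_destination is not None:
            # Move `initial_destination` to the front of the list.
            dest_list = list(other_destinations)
            if initial_destination in dest_list:
                dest_list.remove(initial_destination)
            destinations: Collection[str] = [initial_destination] + dest_list
        else:
            destinations = other_destinations
        return itertools.cycle(destinations)

    it = make_destination_iter("a.example", ["b.example", "a.example"])
    assert [next(it) for _ in range(3)] == ["a.example", "b.example", "a.example"]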
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index a5f4ce7c8a..048c4111f6 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -29,7 +29,7 @@ from typing import (
Tuple,
)
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
from synapse import event_auth
from synapse.api.constants import (
@@ -59,6 +59,13 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import (
+ SynapseTags,
+ set_tag,
+ start_active_span,
+ tag_args,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
@@ -91,6 +98,36 @@ soft_failed_event_counter = Counter(
"Events received over federation that we marked as soft_failed",
)
+# Added to debug performance and track progress on optimizations
+backfill_processing_after_timer = Histogram(
+ "synapse_federation_backfill_processing_after_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.1,
+ 0.25,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 15.0,
+ 20.0,
+ 25.0,
+ 30.0,
+ 40.0,
+ 50.0,
+ 60.0,
+ 80.0,
+ 100.0,
+ 120.0,
+ 150.0,
+ 180.0,
+ "+Inf",
+ ),
+)
+
class FederationEventHandler:
"""Handles events that originated from federation.
@@ -278,7 +315,8 @@ class FederationEventHandler:
)
try:
- await self._process_received_pdu(origin, pdu, state_ids=None)
+ context = await self._state_handler.compute_event_context(pdu)
+ await self._process_received_pdu(origin, pdu, context)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
@@ -286,7 +324,8 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
- await self._process_received_pdu(origin, pdu, state_ids=None)
+ context = await self._state_handler.compute_event_context(pdu)
+ await self._process_received_pdu(origin, pdu, context)
async def on_send_membership_event(
self, origin: str, event: EventBase
@@ -316,6 +355,7 @@ class FederationEventHandler:
The event and context of the event after inserting it into the room graph.
Raises:
+ RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should
@@ -376,7 +416,7 @@ class FederationEventHandler:
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
- await self._check_for_soft_fail(event, None, origin=origin)
+ await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
return event, context
@@ -406,6 +446,7 @@ class FederationEventHandler:
prev_member_event,
)
+ @trace
async def process_remote_join(
self,
origin: str,
@@ -534,32 +575,36 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
- state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
- # build a new state group for it if need be
- context = await self._state_handler.compute_event_context(
- event,
- state_ids_before_event=state_ids,
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ destination, event
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
- # partial state - ie, an event has an earlier stream_ordering than one
- # or more of its prev_events, so we de-partial-state it before its
- # prev_events.
+ # partial state. We were careful to only pick events from the db without
+ # partial-state prev events, so that implies that a prev event has
+ # been persisted (with partial state) since we did the query.
#
- # TODO(faster_joins): we probably need to be more intelligent, and
- # exclude partial-state prev_events from consideration
- # https://github.com/matrix-org/synapse/issues/13001
+ # So, let's just ignore `event` for now; when we re-run the db query
+ # we should instead get its partial-state prev event, which we will
+ # de-partial-state, and then come back to `event`.
logger.warning(
- "%s still has partial state: can't de-partial-state it yet",
+ "%s still has prev_events with partial state: can't de-partial-state it yet",
event.event_id,
)
return
+
+ # since the state at this event has changed, we should now re-evaluate
+ # whether it should have been rejected. We must already have all of the
+ # auth events (from last time we went round this path), so there is no
+ # need to pass the origin.
+ await self._check_event_auth(None, event, context)
+
await self._store.update_state_for_partial_state_event(event, context)
self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
)
+ @trace
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None:
@@ -589,21 +634,23 @@ class FederationEventHandler:
if not events:
return
- # if there are any events in the wrong room, the remote server is buggy and
- # should not be trusted.
- for ev in events:
- if ev.room_id != room_id:
- raise InvalidResponseError(
- f"Remote server {dest} returned event {ev.event_id} which is in "
- f"room {ev.room_id}, when we were backfilling in {room_id}"
- )
+ with backfill_processing_after_timer.time():
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
+ )
- await self._process_pulled_events(
- dest,
- events,
- backfilled=True,
- )
+ await self._process_pulled_events(
+ dest,
+ events,
+ backfilled=True,
+ )
+ @trace
async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None:
@@ -704,8 +751,9 @@ class FederationEventHandler:
logger.info("Got %d prev_events", len(missing_events))
await self._process_pulled_events(origin, missing_events, backfilled=False)
+ @trace
async def _process_pulled_events(
- self, origin: str, events: Iterable[EventBase], backfilled: bool
+ self, origin: str, events: Collection[EventBase], backfilled: bool
) -> None:
"""Process a batch of events we have pulled from a remote server
@@ -720,6 +768,15 @@ class FederationEventHandler:
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
"""
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str([event.event_id for event in events]),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(events)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
logger.debug(
"processing pulled backfilled=%s events=%s",
backfilled,
@@ -742,6 +799,8 @@ class FederationEventHandler:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev, backfilled=backfilled)
+ @trace
+ @tag_args
async def _process_pulled_event(
self, origin: str, event: EventBase, backfilled: bool
) -> None:
@@ -793,7 +852,7 @@ class FederationEventHandler:
if existing:
if not existing.internal_metadata.is_outlier():
logger.info(
- "Ignoring received event %s which we have already seen",
+ "_process_pulled_event: Ignoring received event %s which we have already seen",
event_id,
)
return
@@ -806,29 +865,56 @@ class FederationEventHandler:
return
try:
- state_ids = await self._resolve_state_at_missing_prevs(origin, event)
- # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
- # not return partial state
- # https://github.com/matrix-org/synapse/issues/13002
+ try:
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ origin, event
+ )
+ await self._process_received_pdu(
+ origin,
+ event,
+ context,
+ backfilled=backfilled,
+ )
+ except PartialStateConflictError:
+ # The room was un-partial stated while we were processing the event.
+ # Try once more, with full state this time.
+ context = await self._compute_event_context_with_maybe_missing_prevs(
+ origin, event
+ )
- await self._process_received_pdu(
- origin, event, state_ids=state_ids, backfilled=backfilled
- )
+ # We ought to have full state now, barring some unlikely race where we left and
+ # rejoined the room in the background.
+ if context.partial_state:
+ raise AssertionError(
+ f"Event {event.event_id} still has a partial resolved state "
+ f"after room {event.room_id} was un-partial stated"
+ )
+
+ await self._process_received_pdu(
+ origin,
+ event,
+ context,
+ backfilled=backfilled,
+ )
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
else:
raise
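The nested try/except above implements a retry-exactly-once policy: `PartialStateConflictError` means the state basis changed while we worked, so the context must be recomputed before the second (and final) attempt. As a generic sketch:

    class ConflictError(Exception):
        """Stand-in for PartialStateConflictError."""

    async def persist_with_one_retry(compute_context, process) -> None:
        try:
            context = await compute_context()
            await process(context)
        except ConflictError:
            # The room was un-partial stated meanwhile; recompute and try
            # once more. A second conflict is unexpected and propagates.
            context = await compute_context()
            await process(context)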
- async def _resolve_state_at_missing_prevs(
+ @trace
+ async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase
- ) -> Optional[StateMap[str]]:
- """Calculate the state at an event with missing prev_events.
+ ) -> EventContext:
+ """Build an EventContext structure for a non-outlier event whose prev_events may
+ be missing.
- This is used when we have pulled a batch of events from a remote server, and
- still don't have all the prev_events.
+ This is used when we have pulled a batch of events from a remote server, and may
+ not have all the prev_events.
- If we already have all the prev_events for `event`, this method does nothing.
+ To build an EventContext, we need to calculate the state before the event. If we
+ already have all the prev_events for `event`, we can simply use the state after
+ the prev_events to calculate the state before `event`.
Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
@@ -849,8 +935,7 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
- if we already had all the prev events, `None`. Otherwise, returns
- the event ids of the state at `event`.
+ The event context.
Raises:
FederationError if we fail to get the state from the remote server after any
@@ -864,7 +949,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen
if not missing_prevs:
- return None
+ return await self._state_handler.compute_event_context(event)
logger.info(
"Event %s is missing prev_events %s: calculating state for a "
@@ -876,9 +961,15 @@ class FederationEventHandler:
# resolve them to find the correct state at the current event.
try:
+ # Determine whether we may be about to retrieve partial state
+ # Events may be un-partial stated right after we compute the partial state
+ # flag, but that's okay, as long as the flag errs on the conservative side.
+ partial_state_flags = await self._store.get_partial_state_events(seen)
+ partial_state = any(partial_state_flags.values())
+
# Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids(
- room_id, seen
+ room_id, seen, await_full_state=False
)
# state_maps is a list of mappings from (type, state_key) to event_id
@@ -924,8 +1015,12 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
- return state_map
+ return await self._state_handler.compute_event_context(
+ event, state_ids_before_event=state_map, partial_state=partial_state
+ )
+ @trace
+ @tag_args
async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
@@ -965,10 +1060,10 @@ class FederationEventHandler:
logger.debug("Fetching %i events from cache/store", len(desired_events))
have_events = await self._store.have_seen_events(room_id, desired_events)
- missing_desired_events = desired_events - have_events
+ missing_desired_event_ids = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
- len(missing_desired_events),
+ len(missing_desired_event_ids),
len(have_events),
)
@@ -980,13 +1075,30 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
- missing_auth_events = set(auth_event_ids) - have_events
- missing_auth_events.difference_update(
- await self._store.have_seen_events(room_id, missing_auth_events)
+ missing_auth_event_ids = set(auth_event_ids) - have_events
+ missing_auth_event_ids.difference_update(
+ await self._store.have_seen_events(room_id, missing_auth_event_ids)
)
- logger.debug("We are also missing %i auth events", len(missing_auth_events))
+ logger.debug("We are also missing %i auth events", len(missing_auth_event_ids))
+
+ missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
- missing_events = missing_desired_events | missing_auth_events
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+ str(missing_auth_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+ str(len(missing_auth_event_ids)),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+ str(missing_desired_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+ str(len(missing_desired_event_ids)),
+ )
# Making an individual request for each of 1000s of events has a lot of
# overhead. On the other hand, we don't really want to fetch all of the events
@@ -997,13 +1109,13 @@ class FederationEventHandler:
#
# TODO: might it be better to have an API which lets us do an aggregate event
# request
- if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+ if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
logger.debug("Requesting complete state from remote")
await self._get_state_and_persist(destination, room_id, event_id)
else:
- logger.debug("Fetching %i events from remote", len(missing_events))
+ logger.debug("Fetching %i events from remote", len(missing_event_ids))
await self._get_events_and_persist(
- destination=destination, room_id=room_id, event_ids=missing_events
+ destination=destination, room_id=room_id, event_ids=missing_event_ids
)
# We now need to fill out the state map, which involves fetching the
@@ -1060,6 +1172,14 @@ class FederationEventHandler:
event_id,
failed_to_fetch,
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+ str(failed_to_fetch),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+ str(len(failed_to_fetch)),
+ )
if remote_event.is_state() and remote_event.rejected_reason is None:
state_map[
@@ -1068,6 +1188,8 @@ class FederationEventHandler:
return state_map
+ @trace
+ @tag_args
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
) -> None:
@@ -1089,11 +1211,12 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=(event_id,)
)
+ @trace
async def _process_received_pdu(
self,
origin: str,
event: EventBase,
- state_ids: Optional[StateMap[str]],
+ context: EventContext,
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1115,24 +1238,18 @@ class FederationEventHandler:
event: event to be persisted
- state_ids: Normally None, but if we are handling a gap in the graph
- (ie, we are missing one or more prev_events), the resolved state at the
- event. Must not be partial state.
+ context: The `EventContext` to persist the event with.
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between
- computing the state at the event and persisting it. The caller should retry
- exactly once in this case. Will never be raised if `state_ids` is provided.
+ computing the state at the event and persisting it. The caller should
+ recompute `context` and retry exactly once when this happens.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier
- context = await self._state_handler.compute_event_context(
- event,
- state_ids_before_event=state_ids,
- )
try:
await self._check_event_auth(origin, event, context)
except AuthError as e:
@@ -1144,7 +1261,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- await self._check_for_soft_fail(event, state_ids, origin=origin)
+ await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@@ -1245,6 +1362,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
+ @trace
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@@ -1276,7 +1394,7 @@ class FederationEventHandler:
logger.debug("_handle_marker_event: received %s", marker_event)
insertion_event_id = marker_event.content.get(
- EventContentFields.MSC2716_MARKER_INSERTION
+ EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
)
if insertion_event_id is None:
@@ -1329,6 +1447,55 @@ class FederationEventHandler:
marker_event,
)
+ async def backfill_event_id(
+ self, destination: str, room_id: str, event_id: str
+ ) -> EventBase:
+ """Backfill a single event and persist it as a non-outlier which means
+ we also pull in all of the state and auth events necessary for it.
+
+ Args:
+ destination: The homeserver to pull the given event_id from.
+ room_id: The room where the event is from.
+ event_id: The event ID to backfill.
+
+ Raises:
+ FederationError if we are unable to find the event from the destination
+ """
+ logger.info(
+ "backfill_event_id: event_id=%s from destination=%s", event_id, destination
+ )
+
+ room_version = await self._store.get_room_version(room_id)
+
+ event_from_response = await self._federation_client.get_pdu(
+ [destination],
+ event_id,
+ room_version,
+ )
+
+ if not event_from_response:
+ raise FederationError(
+ "ERROR",
+ 404,
+ "Unable to find event_id=%s from destination=%s to backfill."
+ % (event_id, destination),
+ affected=event_id,
+ )
+
+ # Persist the event we just fetched, including pulling all of the state
+ # and auth events to de-outlier it. This also sets up the necessary
+ # `state_groups` for the event.
+ await self._process_pulled_events(
+ destination,
+ [event_from_response],
+ # Prevent notifications going to clients
+ backfilled=True,
+ )
+
+ return event_from_response
+
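A hypothetical call site for the new helper (handler and identifiers assumed for illustration):

    # Fetch one remote event and fully integrate it (state, auth chain),
    # for when a local lookup needs an event we have never seen.
    event = await federation_event_handler.backfill_event_id(
        "other.example.com", "!room:example.com", "$missing_event_id"
    )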
+ @trace
+ @tag_args
async def _get_events_and_persist(
self, destination: str, room_id: str, event_ids: Collection[str]
) -> None:
@@ -1374,6 +1541,7 @@ class FederationEventHandler:
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
await self._auth_and_persist_outliers(room_id, events)
+ @trace
async def _auth_and_persist_outliers(
self, room_id: str, events: Iterable[EventBase]
) -> None:
@@ -1392,6 +1560,16 @@ class FederationEventHandler:
"""
event_map = {event.event_id: event for event in events}
+ event_ids = event_map.keys()
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+
# filter out any events we have already seen. This might happen because
# the events were eagerly pushed to us (eg, during a room join), or because
# another thread has raced against us since we decided to request the event.
@@ -1508,14 +1686,17 @@ class FederationEventHandler:
backfilled=True,
)
+ @trace
async def _check_event_auth(
- self, origin: str, event: EventBase, context: EventContext
+ self, origin: Optional[str], event: EventBase, context: EventContext
) -> None:
"""
Checks whether an event should be rejected (for failing auth checks).
Args:
- origin: The host the event originates from.
+ origin: The host the event originates from. This is used to fetch
+ any missing auth events. It can be set to None, but only if we are
+ sure that we already have all the auth events.
event: The event itself.
context:
The event context.
@@ -1544,6 +1725,14 @@ class FederationEventHandler:
claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
origin, event
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+ str([ev.event_id for ev in claimed_auth_events]),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+ str(len(claimed_auth_events)),
+ )
# ... and check that the event passes auth at those auth events.
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
@@ -1641,6 +1830,7 @@ class FederationEventHandler:
)
context.rejected = RejectedReason.AUTH_ERROR
+ @trace
async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess:
return
@@ -1658,17 +1848,27 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
- state_ids: Optional[StateMap[str]],
+ context: EventContext,
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
such.
+ Does nothing for events in rooms with partial state, since we may not have an
+ accurate membership event for the sender in the current state.
+
Args:
event
- state_ids: The state at the event if we don't have all the event's prev events
+ context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from.
"""
+ if await self._store.is_partial_state_room(event.room_id):
+ # We might not know the sender's membership in the current state, so don't
+ # soft fail anything. Even if we do have a membership for the sender in the
+ # current state, it may have been derived from state resolution between
+ # partial and full state and may not be accurate.
+ return
+
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
@@ -1685,11 +1885,15 @@ class FederationEventHandler:
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
- if state_ids is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
+ seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+ has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+ if has_missing_prevs:
+ # We don't have all the prev_events of this event, which means we have a
+ # gap in the graph, and the new event is going to become a new backwards
+ # extremity.
+ #
+ # In this case we want to be a little careful as we might have been
+ # down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
@@ -1702,6 +1906,7 @@ class FederationEventHandler:
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
+ state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
@@ -1751,7 +1956,7 @@ class FederationEventHandler:
event.internal_metadata.soft_failed = True
async def _load_or_fetch_auth_events_for_event(
- self, destination: str, event: EventBase
+ self, destination: Optional[str], event: EventBase
) -> Collection[EventBase]:
"""Fetch this event's auth_events, from database or remote
@@ -1767,12 +1972,19 @@ class FederationEventHandler:
Args:
destination: where to send the /event_auth request. Typically the server
that sent us `event` in the first place.
+
+ If this is None, no attempt is made to load any missing auth events:
+ rather, an AssertionError is raised if there are any missing events.
+
event: the event whose auth_events we want
Returns:
all of the events listed in `event.auth_events_ids`, after deduplication
Raises:
+ AssertionError if some auth events were missing and no `destination` was
+ supplied.
+
AuthError if we were unable to fetch the auth_events for any reason.
"""
event_auth_event_ids = set(event.auth_event_ids())
@@ -1784,6 +1996,13 @@ class FederationEventHandler:
)
if not missing_auth_event_ids:
return event_auth_events.values()
+ if destination is None:
+ # this shouldn't happen: destination must be set unless we know we have already
+ # persisted the auth events.
+ raise AssertionError(
+ "_load_or_fetch_auth_events_for_event() called with no destination for "
+ "an event with missing auth_events"
+ )
logger.info(
"Event %s refers to unknown auth events %s: fetching auth chain",
@@ -1819,6 +2038,8 @@ class FederationEventHandler:
# instead we raise an AuthError, which will make the caller ignore it.
raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
+ @trace
+ @tag_args
async def _get_remote_auth_chain_for_event(
self, destination: str, room_id: str, event_id: str
) -> None:
@@ -1847,6 +2068,7 @@ class FederationEventHandler:
await self._auth_and_persist_outliers(room_id, remote_auth_events)
+ @trace
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None:
@@ -1955,8 +2177,17 @@ class FederationEventHandler:
self._message_handler.maybe_schedule_expiry(event)
if not backfilled: # Never notify for backfilled events
- for event in events:
- await self._notify_persisted_event(event, max_stream_token)
+ with start_active_span("notify_persisted_events"):
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "event_ids",
+ str([ev.event_id for ev in events]),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "event_ids.length",
+ str(len(events)),
+ )
+ for event in events:
+ await self._notify_persisted_event(event, max_stream_token)
return max_stream_token.stream
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 85b472f250..860c82c110 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -143,8 +143,8 @@ class InitialSyncHandler:
joined_rooms,
to_key=int(now_token.receipt_key),
)
- if self.hs.config.experimental.msc2285_enabled:
- receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
+
+ receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -309,18 +309,18 @@ class InitialSyncHandler:
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
- user_id = requester.user.to_string()
-
(
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
room_id,
- user_id,
+ requester,
allow_departed_users=True,
)
is_peeking = member_event_id is None
+ user_id = requester.user.to_string()
+
if membership == Membership.JOIN:
result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
@@ -456,11 +456,8 @@ class InitialSyncHandler:
)
if not receipts:
return []
- if self.hs.config.experimental.msc2285_enabled:
- receipts = ReceiptEventSource.filter_out_private_receipts(
- receipts, user_id
- )
- return receipts
+
+ return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
presence, receipts, (messages, token) = await make_deferred_yieldable(
gather_results(
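With MSC2285 stabilising, the private-receipt filter now runs unconditionally. Roughly what filtering means, on a simplified data shape (the real helper walks full receipt events; names below are assumptions):

    PRIVATE_TYPES = ("m.read.private", "org.matrix.msc2285.read.private")

    def filter_private(content: dict, me: str) -> dict:
        # Keep public receipts as-is; keep private receipts only if ours.
        out: dict = {}
        for receipt_type, by_user in content.items():
            if receipt_type in PRIVATE_TYPES:
                by_user = {uid: v for uid, v in by_user.items() if uid == me}
                if not by_user:
                    continue
            out[receipt_type] = by_user
        return out

    content = {
        "m.read": {"@bob:example.com": {"ts": 1}},
        "m.read.private": {"@me:example.com": {"ts": 2}, "@bob:example.com": {"ts": 3}},
    }
    assert filter_private(content, "@me:example.com")["m.read.private"] == {
        "@me:example.com": {"ts": 2}
    }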
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bd7baef051..acd3de06f6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -41,6 +41,7 @@ from synapse.api.errors import (
NotFoundError,
ShadowBanError,
SynapseError,
+ UnstableSpecAuthError,
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -51,6 +52,7 @@ from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
+from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -102,7 +104,7 @@ class MessageHandler:
async def get_room_data(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
event_type: str,
state_key: str,
@@ -110,7 +112,7 @@ class MessageHandler:
"""Get data from a room.
Args:
- user_id
+ requester: The user making the request.
room_id
event_type
state_key
@@ -123,7 +125,7 @@ class MessageHandler:
membership,
membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership == Membership.JOIN:
@@ -149,17 +151,20 @@ class MessageHandler:
"Attempted to retrieve data from a room for a user that has never been in it. "
"This should not have happened."
)
- raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
+ raise UnstableSpecAuthError(
+ 403,
+ "User not in room",
+ errcode=Codes.NOT_JOINED,
+ )
return data
async def get_state_events(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None,
- is_guest: bool = False,
) -> List[dict]:
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@@ -168,14 +173,13 @@ class MessageHandler:
visible.
Args:
- user_id: The user requesting state events.
+ requester: The user requesting state events.
room_id: The room ID to get all state events from.
state_filter: The state filter used to fetch state from the database.
at_token: the stream token at which we are requesting the state.
If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
- is_guest: whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
Raises:
@@ -185,6 +189,7 @@ class MessageHandler:
members of this room.
"""
state_filter = state_filter or StateFilter.all()
+ user_id = requester.user.to_string()
if at_token:
last_event_id = (
@@ -217,7 +222,7 @@ class MessageHandler:
membership,
membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership == Membership.JOIN:
@@ -311,30 +316,42 @@ class MessageHandler:
Returns:
A dict of user_id to profile info
"""
- user_id = requester.user.to_string()
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
membership, _ = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership != Membership.JOIN:
- raise NotImplementedError(
- "Getting joined members after leaving is not implemented"
+ raise SynapseError(
+ code=403,
+ errcode=Codes.FORBIDDEN,
+ msg="Getting joined members while not being a current member of the room is forbidden.",
)
- users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
+ users_with_profile = (
+ await self._state_storage_controller.get_users_in_room_with_profiles(
+ room_id
+ )
+ )
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
# is a user in the room that the AS is "interested in"
- if requester.app_service and user_id not in users_with_profile:
+ if (
+ requester.app_service
+ and requester.user.to_string() not in users_with_profile
+ ):
for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid):
break
else:
# Loop fell through, AS has no interested users in room
- raise AuthError(403, "Appservice not in room")
+ raise UnstableSpecAuthError(
+ 403,
+ "Appservice not in room",
+ errcode=Codes.NOT_JOINED,
+ )
return {
user_id: {
@@ -1135,6 +1152,10 @@ class EventCreationHandler:
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ partial_state=False,
)
else:
context = await self.state.compute_event_context(event)
@@ -1359,9 +1380,10 @@ class EventCreationHandler:
# and `state_groups` because they have `prev_events` that aren't persisted yet
# (historical messages persisted in reverse-chronological order).
if not event.internal_metadata.is_historical():
- await self._bulk_push_rule_evaluator.action_for_event_by_user(
- event, context
- )
+ with opentracing.start_active_span("calculate_push_actions"):
+ await self._bulk_push_rule_evaluator.action_for_event_by_user(
+ event, context
+ )
try:
# If we're a worker we need to hit out to the master.
@@ -1448,9 +1470,10 @@ class EventCreationHandler:
state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
- joined_hosts = await self.store.get_joined_hosts(
- event.room_id, state, state_entry
- )
+ with opentracing.start_active_span("get_joined_hosts"):
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
# Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6262a35822..74e944bce7 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse
+from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
@@ -416,6 +417,7 @@ class PaginationHandler:
await self._storage_controllers.purge_events.purge_room(room_id)
+ @trace
async def get_messages(
self,
requester: Requester,
@@ -462,7 +464,7 @@ class PaginationHandler:
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if pagin_config.direction == "b":
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..741504ba9f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
Callable,
Collection,
Dict,
- FrozenSet,
Generator,
Iterable,
List,
@@ -42,7 +41,6 @@ from typing import (
Set,
Tuple,
Type,
- Union,
)
from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# doesn't return. C.f. #5503.
return [], max_token
- # Figure out which other users this user should receive updates for
- users_interested_in = await self._get_interested_in(user, explicit_room_id)
+ # Figure out which other users this user should explicitly receive
+ # updates for
+ additional_users_interested_in = (
+ await self.get_presence_router().get_interested_users(user.to_string())
+ )
# We have a set of users that we're interested in the presence of. We want to
# cross-reference that with the users that have actually changed their presence.
# Check whether this user should see all user updates
- if users_interested_in == PresenceRouter.ALL_USERS:
+ if additional_users_interested_in == PresenceRouter.ALL_USERS:
# Provide presence state for all users
presence_updates = await self._filter_all_presence_updates_for_user(
user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
return presence_updates, max_token
- # Make mypy happy. users_interested_in should now be a set
- assert not isinstance(users_interested_in, str)
+ # Make mypy happy. additional_users_interested_in should now be a set
+ assert not isinstance(additional_users_interested_in, str)
+
+ # We always care about our own presence.
+ additional_users_interested_in.add(user_id)
+
+ if explicit_room_id:
+ user_ids = await self.store.get_users_in_room(explicit_room_id)
+ additional_users_interested_in.update(user_ids)
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+ interested_and_updated_users: Collection[str]
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
- # Use a slightly-optimised method for processing smaller sets of updates.
- if updated_users is not None and len(updated_users) < 500:
- # For small deltas, it's quicker to get all changes and then
- # cross-reference with the users we're interested in
+ if updated_users is not None:
+ # If we have the full list of changes for presence we can
+ # simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
- for other_user_id in updated_users:
- if other_user_id in users_interested_in:
- # mypy thinks this variable could be a FrozenSet as it's possibly set
- # to one in the `get_entities_changed` call below, and `add()` is not
- # method on a FrozenSet. That doesn't affect us here though, as
- # `interested_and_updated_users` is clearly a set() above.
- interested_and_updated_users.add(other_user_id) # type: ignore
+
+ sharing_users = await self.store.do_users_share_a_room(
+ user_id, updated_users
+ )
+
+ interested_and_updated_users = (
+ sharing_users.union(additional_users_interested_in)
+ ).intersection(updated_users)
+
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.labels("full").inc()
+ users_interested_in = (
+ await self.store.get_users_who_share_room_with_user(user_id)
+ )
+ users_interested_in.update(additional_users_interested_in)
+
interested_and_updated_users = (
stream_change_cache.get_entities_changed(
users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
else:
# No from_key has been specified. Return the presence for all users
# this user is interested in
- interested_and_updated_users = users_interested_in
+ interested_and_updated_users = (
+ await self.store.get_users_who_share_room_with_user(user_id)
+ )
+ interested_and_updated_users.update(additional_users_interested_in)
# Retrieve the current presence state for each user
users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
def get_current_key(self) -> int:
return self.store.get_current_presence_token()
- @cached(num_args=2, cache_context=True)
- async def _get_interested_in(
- self,
- user: UserID,
- explicit_room_id: Optional[str] = None,
- cache_context: Optional[_CacheContext] = None,
- ) -> Union[Set[str], str]:
- """Returns the set of users that the given user should see presence
- updates for.
-
- Args:
- user: The user to retrieve presence updates for.
- explicit_room_id: The users that are in the room will be returned.
-
- Returns:
- A set of user IDs to return presence updates for, or "ALL" to return all
- known updates.
- """
- user_id = user.to_string()
- users_interested_in = set()
- users_interested_in.add(user_id) # So that we receive our own presence
-
- # cache_context isn't likely to ever be None due to the @cached decorator,
- # but we can't have a non-optional argument after the optional argument
- # explicit_room_id either. Assert cache_context is not None so we can use it
- # without mypy complaining.
- assert cache_context
-
- # Check with the presence router whether we should poll additional users for
- # their presence information
- additional_users = await self.get_presence_router().get_interested_users(
- user.to_string()
- )
- if additional_users == PresenceRouter.ALL_USERS:
- # If the module requested that this user see the presence updates of *all*
- # users, then simply return that instead of calculating what rooms this
- # user shares
- return PresenceRouter.ALL_USERS
-
- # Add the additional users from the router
- users_interested_in.update(additional_users)
-
- # Find the users who share a room with this user
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id, on_invalidate=cache_context.invalidate
- )
- users_interested_in.update(users_who_share_room)
-
- if explicit_room_id:
- user_ids = await self.store.get_users_in_room(
- explicit_room_id, on_invalidate=cache_context.invalidate
- )
- users_interested_in.update(user_ids)
-
- return users_interested_in
-
def handle_timeouts(
user_states: List[UserPresenceState],
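For intuition, here is a minimal sketch of the set algebra the incremental (`from_key`) path above now performs, with plain sets standing in for the store and stream-cache lookups; all user IDs are hypothetical.

```py
def interested_and_updated(
    sharing_users: set,          # users who share a room with us, per the store
    additional_interested: set,  # router-provided users, ourselves, explicit room members
    updated_users: set,          # everyone whose presence changed since from_key
) -> set:
    # Anyone we care about (room-sharing or explicitly interested) who
    # actually changed presence is returned.
    return (sharing_users | additional_interested) & updated_users

assert interested_and_updated(
    sharing_users={"@a:hs", "@b:hs"},
    additional_interested={"@me:hs", "@router:hs"},
    updated_users={"@b:hs", "@router:hs", "@stranger:hs"},
) == {"@b:hs", "@router:hs"}
```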
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d2882b0a..d4a866b346 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -163,7 +163,10 @@ class ReceiptsHandler:
if not is_new:
return
- if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
+ if self.federation_sender and receipt_type not in (
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ):
await self.federation_sender.send_read_receipt(receipt)
@@ -203,24 +206,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
for event_id, orig_event_content in room.get("content", {}).items():
event_content = orig_event_content
# If there are private read receipts, additional logic is necessary.
- if ReceiptTypes.READ_PRIVATE in event_content:
+ if (
+ ReceiptTypes.READ_PRIVATE in event_content
+ or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content
+ ):
- # Make a copy without private read receipts to avoid leaking
- # other user's private read receipts..
+ # Make a copy without private read receipts to avoid leaking
+ # other users' private read receipts.
event_content = {
receipt_type: receipt_value
for receipt_type, receipt_value in event_content.items()
- if receipt_type != ReceiptTypes.READ_PRIVATE
+ if receipt_type
+ not in (
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ )
}
# Copy the current user's private read receipt from the
# original content, if it exists.
- user_private_read_receipt = orig_event_content[
- ReceiptTypes.READ_PRIVATE
- ].get(user_id, None)
+ user_private_read_receipt = orig_event_content.get(
+ ReceiptTypes.READ_PRIVATE, {}
+ ).get(user_id, None)
if user_private_read_receipt:
event_content[ReceiptTypes.READ_PRIVATE] = {
user_id: user_private_read_receipt
}
+ user_unstable_private_read_receipt = orig_event_content.get(
+ ReceiptTypes.UNSTABLE_READ_PRIVATE, {}
+ ).get(user_id, None)
+ if user_unstable_private_read_receipt:
+ event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = {
+ user_id: user_unstable_private_read_receipt
+ }
# Include the event if there is at least one non-private read
# receipt or the current user has a private read receipt.
@@ -256,10 +273,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
room_ids, from_key=from_key, to_key=to_key
)
- if self.config.experimental.msc2285_enabled:
- events = ReceiptEventSource.filter_out_private_receipts(
- events, user.to_string()
- )
+ events = ReceiptEventSource.filter_out_private_receipts(
+ events, user.to_string()
+ )
return events, to_key
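A simplified, self-contained sketch of the filtering rule applied above, assuming plain dicts for event content; the receipt type strings below are illustrative stand-ins for the stable/unstable pair.

```py
READ_PRIVATE = "m.read.private"
UNSTABLE_READ_PRIVATE = "org.matrix.msc2285.read.private"
PRIVATE_TYPES = (READ_PRIVATE, UNSTABLE_READ_PRIVATE)

def filter_receipt_content(content: dict, user_id: str) -> dict:
    # Strip both private receipt types wholesale...
    filtered = {
        receipt_type: value
        for receipt_type, value in content.items()
        if receipt_type not in PRIVATE_TYPES
    }
    # ...then re-add only the requesting user's own entries.
    for receipt_type in PRIVATE_TYPES:
        own = content.get(receipt_type, {}).get(user_id)
        if own:
            filtered[receipt_type] = {user_id: own}
    return filtered

content = {
    "m.read": {"@a:hs": {"ts": 1}},
    READ_PRIVATE: {"@a:hs": {"ts": 2}, "@b:hs": {"ts": 3}},
}
assert filter_receipt_content(content, "@a:hs") == {
    "m.read": {"@a:hs": {"ts": 1}},
    READ_PRIVATE: {"@a:hs": {"ts": 2}},
}
```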
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c77d181722..20ec22105a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -29,7 +29,13 @@ from synapse.api.constants import (
JoinRules,
LoginType,
)
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ ConsentNotGivenError,
+ InvalidClientTokenError,
+ SynapseError,
+)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
@@ -180,10 +186,7 @@ class RegistrationHandler:
)
if guest_access_token:
user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if (
- not user_data.is_guest
- or UserID.from_string(user_data.user_id).localpart != localpart
- ):
+ if not user_data.is_guest or user_data.user.localpart != localpart:
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@@ -618,7 +621,7 @@ class RegistrationHandler:
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
if not service:
- raise AuthError(403, "Invalid application service token.")
+ raise InvalidClientTokenError()
if not service.is_interested_in_user(user_id):
raise SynapseError(
400,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0b63cd2186..28d7093f08 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -19,6 +19,7 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
+from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -73,7 +74,6 @@ class RelationsHandler:
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -89,7 +89,6 @@ class RelationsHandler:
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -104,7 +103,7 @@ class RelationsHandler:
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
@@ -122,7 +121,6 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
@@ -364,6 +362,7 @@ class RelationsHandler:
return results
+ @trace
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 978d3ee39f..2bf0ebd025 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -721,7 +721,7 @@ class RoomCreationHandler:
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
# Let the third party rules modify the room creation config if needed, or abort
# the room creation entirely with an exception.
@@ -1279,7 +1279,7 @@ class RoomContextHandler:
"""
user = requester.user
if use_admin_priviledge:
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
@@ -1384,6 +1384,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
+ self.federation_event_handler = hs.get_federation_event_handler()
self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
@@ -1479,38 +1480,68 @@ class TimestampLookupHandler:
remote_response,
)
- # TODO: Do we want to persist this as an extremity?
- # TODO: I think ideally, we would try to backfill from
- # this event and run this whole
- # `get_event_for_timestamp` function again to make sure
- # they didn't give us an event from their gappy history.
remote_event_id = remote_response.event_id
- origin_server_ts = remote_response.origin_server_ts
+ remote_origin_server_ts = remote_response.origin_server_ts
+
+ # Backfill this event so we can get a pagination token for
+ # it with `/context` and paginate `/messages` from this
+ # point.
+ #
+ # TODO: The requested timestamp may lie in a part of the
+ # event graph that the remote server *also* didn't have,
+ # in which case they will have returned another event
+ # which may be nowhere near the requested timestamp. In
+ # the future, we may need to reconcile that gap and ask
+ # other homeservers, and/or extend `/timestamp_to_event`
+ # to return events on *both* sides of the timestamp to
+ # help reconcile the gap faster.
+ remote_event = (
+ await self.federation_event_handler.backfill_event_id(
+ domain, room_id, remote_event_id
+ )
+ )
+
+ # XXX: When we see that the remote server is not trustworthy,
+ # maybe we should not ask them first in the future.
+ if remote_origin_server_ts != remote_event.origin_server_ts:
+ logger.info(
+ "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+ domain,
+ remote_event_id,
+ remote_origin_server_ts,
+ remote_event.origin_server_ts,
+ )
# Only return the remote event if it's closer than the local event
if not local_event or (
- abs(origin_server_ts - timestamp)
+ abs(remote_event.origin_server_ts - timestamp)
< abs(local_event.origin_server_ts - timestamp)
):
- return remote_event_id, origin_server_ts
+ logger.info(
+ "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+ remote_event_id,
+ remote_event.origin_server_ts,
+ timestamp,
+ local_event.event_id if local_event else None,
+ local_event.origin_server_ts if local_event else None,
+ )
+ return remote_event_id, remote_origin_server_ts
except (HttpResponseException, InvalidResponseError) as ex:
# Let's not put a high priority on some other homeserver
# failing to respond or giving a random response
logger.debug(
- "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+ "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
domain,
type(ex).__name__,
ex,
ex.args,
)
- except Exception as ex:
+ except Exception:
# But we do want to see some exceptions in our code
logger.warning(
- "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+ "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
domain,
- type(ex).__name__,
- ex,
- ex.args,
+ exc_info=True,
)
# To appease mypy, we have to add both of these conditions to check for
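A hedged sketch of the "closest event wins" comparison above, with `(event_id, origin_server_ts)` tuples standing in for Synapse's event objects.

```py
from typing import Optional, Tuple

def pick_closest(
    timestamp: int,
    local: Optional[Tuple[str, int]],   # (event_id, origin_server_ts)
    remote: Optional[Tuple[str, int]],
) -> Optional[Tuple[str, int]]:
    # Prefer the (backfilled) remote candidate only when its timestamp is
    # strictly nearer to the requested one, or when there is no local event.
    if remote is not None and (
        local is None or abs(remote[1] - timestamp) < abs(local[1] - timestamp)
    ):
        return remote
    return local

assert pick_closest(1000, ("$local", 900), ("$remote", 1050)) == ("$remote", 1050)
assert pick_closest(1000, ("$local", 990), ("$remote", 1050)) == ("$local", 990)
```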
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 29868eb743..bb0bdb8e6f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -182,7 +182,7 @@ class RoomListHandler:
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
- "org.matrix.msc3827.room_type": room["room_type"],
+ "room_type": room["room_type"],
}
# Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ef2fa6bb6f..9c0fdeca15 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -32,6 +32,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -178,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""Try and join a room that this server is not in
Args:
- requester
+ requester: The user making the request, according to the access token.
remote_room_hosts: List of servers that can be used to join via.
room_id: Room that we are trying to join
user: User who is trying to join
@@ -430,14 +431,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._join_rate_per_room_limiter.ratelimit(
requester, key=room_id, update=False
)
-
- result_event = await self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target],
- ratelimit=ratelimit,
- )
+ with opentracing.start_active_span("handle_new_client_event"):
+ result_event = await self.event_creation_handler.handle_new_client_event(
+ requester,
+ event,
+ context,
+ extra_users=[target],
+ ratelimit=ratelimit,
+ )
if event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -566,25 +567,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# by application services), and then by room ID.
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
- result = await self.update_membership_locked(
- requester,
- target,
- room_id,
- action,
- txn_id=txn_id,
- remote_room_hosts=remote_room_hosts,
- third_party_signed=third_party_signed,
- ratelimit=ratelimit,
- content=content,
- new_room=new_room,
- require_consent=require_consent,
- outlier=outlier,
- historical=historical,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- state_event_ids=state_event_ids,
- depth=depth,
- )
+ with opentracing.start_active_span("update_membership_locked"):
+ result = await self.update_membership_locked(
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=txn_id,
+ remote_room_hosts=remote_room_hosts,
+ third_party_signed=third_party_signed,
+ ratelimit=ratelimit,
+ content=content,
+ new_room=new_room,
+ require_consent=require_consent,
+ outlier=outlier,
+ historical=historical,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ depth=depth,
+ )
return result
@@ -651,6 +653,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Returns:
A tuple of the new event ID and stream ID.
"""
+
content_specified = bool(content)
if content is None:
content = {}
@@ -688,7 +691,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
errcode=Codes.BAD_JSON,
)
- if "avatar_url" in content:
+ if "avatar_url" in content and content.get("avatar_url") is not None:
if not await self.profile_handler.check_avatar_size_and_mime_type(
content["avatar_url"],
):
@@ -743,7 +746,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_requester_admin = True
else:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin:
if self.config.server.block_non_admin_invites:
@@ -878,7 +881,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
bypass_spam_checker = True
else:
- bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+ bypass_spam_checker = await self.auth.is_server_admin(requester)
inviter = await self._get_inviter(target.to_string(), room_id)
if (
@@ -1438,7 +1441,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ShadowBanError if the requester has been shadow-banned.
"""
if self.config.server.block_non_admin_invites:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -1710,14 +1713,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise SynapseError(
+ 404,
+ "Can't join remote room because no servers "
+ "that are in the room have been provided.",
+ )
check_complexity = self.hs.config.server.limit_remote_rooms.enabled
if (
check_complexity
and self.hs.config.server.limit_remote_rooms.admins_can_join
):
- check_complexity = not await self.auth.is_server_admin(user)
+ check_complexity = not await self.store.is_server_admin(user)
if check_complexity:
# Fetch the room complexity
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 13098f56ed..732b0310bc 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,11 +28,11 @@ from synapse.api.constants import (
RoomTypes,
)
from synapse.api.errors import (
- AuthError,
Codes,
NotFoundError,
StoreError,
SynapseError,
+ UnstableSpecAuthError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
@@ -175,10 +175,11 @@ class RoomSummaryHandler:
# First of all, check that the room is accessible.
if not await self._is_local_room_accessible(requested_room_id, requester):
- raise AuthError(
+ raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
% (requester, requested_room_id),
+ errcode=Codes.NOT_JOINED,
)
# If this is continuing a previous session, pull the persisted data.
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a305a66860..e2844799e8 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -23,10 +23,12 @@ from pkg_resources import parse_version
import twisted
from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
+from twisted.internet.interfaces import IOpenSSLContextFactory
+from twisted.internet.ssl import optionsForClientTLS
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import ISynapseReactor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
async def _sendmail(
- reactor: IReactorTCP,
+ reactor: ISynapseReactor,
smtphost: str,
smtpport: int,
from_addr: str,
@@ -59,6 +61,7 @@ async def _sendmail(
require_auth: bool = False,
require_tls: bool = False,
enable_tls: bool = True,
+ force_tls: bool = False,
) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
@@ -73,8 +76,9 @@ async def _sendmail(
password: password to give when authenticating
require_auth: if auth is not offered, fail the request
- require_tls: if TLS is not offered, fail the reqest
+ require_tls: if TLS is not offered, fail the request
- enable_tls: True to enable TLS. If this is False and require_tls is True,
+ enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
the request will fail.
+ force_tls: True to enable Implicit TLS.
"""
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
@@ -105,13 +109,23 @@ async def _sendmail(
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
- reactor.connectTCP(
- smtphost,
- smtpport,
- factory,
- timeout=30,
- bindAddress=None,
- )
+ if force_tls:
+ reactor.connectSSL(
+ smtphost,
+ smtpport,
+ factory,
+ optionsForClientTLS(smtphost),
+ timeout=30,
+ bindAddress=None,
+ )
+ else:
+ reactor.connectTCP(
+ smtphost,
+ smtpport,
+ factory,
+ timeout=30,
+ bindAddress=None,
+ )
await make_deferred_yieldable(d)
@@ -132,6 +146,7 @@ class SendEmailHandler:
self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
self._require_transport_security = hs.config.email.require_transport_security
self._enable_tls = hs.config.email.enable_smtp_tls
+ self._force_tls = hs.config.email.force_tls
self._sendmail = _sendmail
@@ -189,4 +204,5 @@ class SendEmailHandler:
require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security,
enable_tls=self._enable_tls,
+ force_tls=self._force_tls,
)
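For reference, a minimal sketch of the connection decision introduced above; the `factory` argument is a stand-in for the ESMTP sender factory built earlier in the function.

```py
from twisted.internet import reactor
from twisted.internet.ssl import optionsForClientTLS

def connect_smtp(factory, smtphost: str, smtpport: int, force_tls: bool) -> None:
    if force_tls:
        # Implicit TLS: the connection is encrypted from the first byte
        # (e.g. SMTPS on port 465).
        reactor.connectSSL(
            smtphost, smtpport, factory, optionsForClientTLS(smtphost), timeout=30
        )
    else:
        # Plain TCP; STARTTLS may still upgrade the session later if the
        # factory was built with TLS enabled.
        reactor.connectTCP(smtphost, smtpport, factory, timeout=30)
```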
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d42a414c90..2d95b1fa24 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,19 @@
# limitations under the License.
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ FrozenSet,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -89,7 +101,7 @@ class SyncConfig:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
prev_batch: StreamToken
- events: List[EventBase]
+ events: Sequence[EventBase]
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
@@ -507,10 +519,17 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
+ # FIXME(faster_joins): We use the partial state here as
+ # we don't want to block `/sync` on finishing a lazy join.
+ # Which should be fine once
+ # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+ # since we shouldn't reach here anymore?
+ # Note that we use the current state as a whitelist for filtering
+ # `recents`, so partial state is only a problem when a membership
+ # event turns up in `recents` but has not made it into the current
+ # state.
current_state_ids_map = (
- await self._state_storage_controller.get_current_state_ids(
- room_id
- )
+ await self.store.get_partial_current_state_ids(room_id)
)
current_state_ids = frozenset(current_state_ids_map.values())
@@ -579,7 +598,13 @@ class SyncHandler:
if any(e.is_state() for e in loaded_recents):
# FIXME(faster_joins): We use the partial state here as
# we don't want to block `/sync` on finishing a lazy join.
- # Is this the correct way of doing it?
+ # Which should be fine once
+ # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+ # since we shouldn't reach here anymore?
+ # Note that we use the current state as a whitelist for filtering
+ # `loaded_recents`, so partial state is only a problem when a
+ # membership event turns up in `loaded_recents` but has not made it
+ # into the current state.
current_state_ids_map = (
await self.store.get_partial_current_state_ids(room_id)
)
@@ -627,7 +652,10 @@ class SyncHandler:
)
async def get_state_after_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
@@ -635,9 +663,14 @@ class SyncHandler:
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
"""
state_ids = await self._state_storage_controller.get_state_ids_for_event(
- event_id, state_filter=state_filter or StateFilter.all()
+ event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -660,6 +693,7 @@ class SyncHandler:
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
@@ -667,6 +701,9 @@ class SyncHandler:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the last event in the room before `stream_position` and
+ `state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
@@ -678,7 +715,9 @@ class SyncHandler:
if last_event_id:
state = await self.get_state_after_event(
- last_event_id, state_filter=state_filter or StateFilter.all()
+ last_event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
else:
@@ -852,16 +891,26 @@ class SyncHandler:
now_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
- """Works out the difference in state between the start of the timeline
- and the previous sync.
+ """Works out the difference in state between the end of the previous sync and
+ the start of the timeline.
Args:
room_id:
batch: The timeline batch for the room that will be sent to the user.
sync_config:
- since_token: Token of the end of the previous batch. May be None.
+ since_token: Token of the end of the previous batch. May be `None`.
now_token: Token of the end of the current batch.
full_state: Whether to force returning the full state.
+ `lazy_load_members` still applies when `full_state` is `True`.
+
+ Returns:
+ The state to return in the sync response for the room.
+
+ Clients will overlay this onto the state at the end of the previous sync to
+ arrive at the state at the start of the timeline.
+
+ Clients will then overlay state events in the timeline to arrive at the
+ state at the end of the timeline, in preparation for the next sync.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -869,8 +918,17 @@ class SyncHandler:
# TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"):
+ # The memberships needed for events in the timeline.
+ # Only calculated when `lazy_load_members` is on.
+ members_to_fetch: Optional[Set[str]] = None
+
+ # A dictionary mapping user IDs to the first event in the timeline sent by
+ # them. Only calculated when `lazy_load_members` is on.
+ first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
- members_to_fetch = None
+ # The contribution to the room state from state events in the timeline.
+ # Only contains the last event for any given state key.
+ timeline_state: StateMap[str]
lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = (
@@ -881,10 +939,23 @@ class SyncHandler:
# We only request state for the members needed to display the
# timeline:
- members_to_fetch = {
- event.sender # FIXME: we also care about invite targets etc.
- for event in batch.events
- }
+ timeline_state = {}
+
+ members_to_fetch = set()
+ first_event_by_sender_map = {}
+ for event in batch.events:
+ # Build the map from user IDs to the first timeline event they sent.
+ if event.sender not in first_event_by_sender_map:
+ first_event_by_sender_map[event.sender] = event
+
+ # We need the event's sender, unless their membership was in a
+ # previous timeline event.
+ if (EventTypes.Member, event.sender) not in timeline_state:
+ members_to_fetch.add(event.sender)
+ # FIXME: we also care about invite targets etc.
+
+ if event.is_state():
+ timeline_state[(event.type, event.state_key)] = event.event_id
if full_state:
# always make sure we LL ourselves so we know we're in the room
@@ -894,55 +965,80 @@ class SyncHandler:
members_to_fetch.add(sync_config.user.to_string())
state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+ # We are happy to use partial state to compute the `/sync` response.
+ # Since partial state may not include the lazy-loaded memberships we
+ # require, we fix up the state response afterwards with memberships from
+ # auth events.
+ await_full_state = False
else:
+ timeline_state = {
+ (event.type, event.state_key): event.event_id
+ for event in batch.events
+ if event.is_state()
+ }
+
state_filter = StateFilter.all()
+ await_full_state = True
- timeline_state = {
- (event.type, event.state_key): event.event_id
- for event in batch.events
- if event.is_state()
- }
+ # Now calculate the state to return in the sync response for the room.
+ # This is more or less the change in state between the end of the previous
+ # sync's timeline and the start of the current sync's timeline.
+ # See the docstring above for details.
+ state_ids: StateMap[str]
if full_state:
if batch:
- current_state_ids = (
+ state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ batch.events[-1].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
- state_ids = (
+ state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ batch.events[0].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
- current_state_ids = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ state_at_timeline_end = await self.get_state_at(
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
- state_ids = current_state_ids
+ state_at_timeline_start = state_at_timeline_end
state_ids = _calculate_state(
timeline_contains=timeline_state,
- timeline_start=state_ids,
- previous={},
- current=current_state_ids,
+ timeline_start=state_at_timeline_start,
+ timeline_end=state_at_timeline_end,
+ previous_timeline_end={},
lazy_load_members=lazy_load_members,
)
elif batch.limited:
if batch:
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ batch.events[0].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
# for now, we disable LL for gappy syncs - see
@@ -964,28 +1060,35 @@ class SyncHandler:
# is indeed the case.
assert since_token is not None
state_at_previous_sync = await self.get_state_at(
- room_id, stream_position=since_token, state_filter=state_filter
+ room_id,
+ stream_position=since_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
if batch:
- current_state_ids = (
+ state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ batch.events[-1].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
- # Its not clear how we get here, but empirically we do
- # (#5407). Logging has been added elsewhere to try and
- # figure out where this state comes from.
- current_state_ids = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ # We can get here if the user has ignored the senders of all
+ # the recent events.
+ state_at_timeline_end = await self.get_state_at(
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
- previous=state_at_previous_sync,
- current=current_state_ids,
+ timeline_end=state_at_timeline_end,
+ previous_timeline_end=state_at_previous_sync,
# we have to include LL members in case LL initial sync missed them
lazy_load_members=lazy_load_members,
)
@@ -1008,8 +1111,30 @@ class SyncHandler:
(EventTypes.Member, member)
for member in members_to_fetch
),
+ await_full_state=False,
)
+ # If we only have partial state for the room, `state_ids` may be missing the
+ # memberships we wanted. We attempt to find some by digging through the auth
+ # events of timeline events.
+ if lazy_load_members and await self.store.is_partial_state_room(room_id):
+ assert members_to_fetch is not None
+ assert first_event_by_sender_map is not None
+
+ additional_state_ids = (
+ await self._find_missing_partial_state_memberships(
+ room_id, members_to_fetch, first_event_by_sender_map, state_ids
+ )
+ )
+ state_ids = {**state_ids, **additional_state_ids}
+
+ # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+ # the memberships of all event senders in the timeline. This is because we
+ # may not have sent the memberships in a previous sync.
+
+ # When `include_redundant_members` is on, we send all the lazy-loaded
+ # memberships of event senders. Otherwise we make an effort to limit the set
+ # of memberships we send to those that we have not already sent to this client.
if lazy_load_members and not include_redundant_members:
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1051,6 +1176,99 @@ class SyncHandler:
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
}
+ async def _find_missing_partial_state_memberships(
+ self,
+ room_id: str,
+ members_to_fetch: Collection[str],
+ events_with_membership_auth: Mapping[str, EventBase],
+ found_state_ids: StateMap[str],
+ ) -> StateMap[str]:
+ """Finds missing memberships from a set of auth events and returns them as a
+ state map.
+
+ Args:
+ room_id: The partial state room to find the remaining memberships for.
+ members_to_fetch: The memberships to find.
+ events_with_membership_auth: A mapping from user IDs to events whose auth
+ events are known to contain their membership.
+ found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+ memberships that have been previously found. Entries in
+ `members_to_fetch` that have a membership in `found_state_ids` are
+ ignored.
+
+ Returns:
+ A dict from ("m.room.member", state_key) -> state_event_id, containing the
+ memberships missing from `found_state_ids`.
+
+ Raises:
+ KeyError: if `events_with_membership_auth` does not have an entry for a
+ missing membership. Memberships in `found_state_ids` do not need an
+ entry in `events_with_membership_auth`.
+ """
+ additional_state_ids: MutableStateMap[str] = {}
+
+ # Tracks the missing members for logging purposes.
+ missing_members = set()
+
+ # Identify memberships missing from `found_state_ids` and pick out the auth
+ # events in which to look for them.
+ auth_event_ids: Set[str] = set()
+ for member in members_to_fetch:
+ if (EventTypes.Member, member) in found_state_ids:
+ continue
+
+ missing_members.add(member)
+ event_with_membership_auth = events_with_membership_auth[member]
+ auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+ auth_events = await self.store.get_events(auth_event_ids)
+
+ # Run through the missing memberships once more, picking out the memberships
+ # from the pile of auth events we have just fetched.
+ for member in members_to_fetch:
+ if (EventTypes.Member, member) in found_state_ids:
+ continue
+
+ event_with_membership_auth = events_with_membership_auth[member]
+
+ # Dig through the auth events to find the desired membership.
+ for auth_event_id in event_with_membership_auth.auth_event_ids():
+ # We only store events once we have all their auth events,
+ # so the auth event must be in the pile we have just
+ # fetched.
+ auth_event = auth_events[auth_event_id]
+
+ if (
+ auth_event.type == EventTypes.Member
+ and auth_event.state_key == member
+ ):
+ missing_members.remove(member)
+ additional_state_ids[
+ (EventTypes.Member, member)
+ ] = auth_event.event_id
+ break
+
+ if missing_members:
+ # There really shouldn't be any missing memberships now. Either:
+ # * we couldn't find an auth event, which shouldn't happen because we do
+ # not persist events without persisting their auth events first, or
+ # * the set of auth events did not contain a membership we wanted, which
+ # means our caller didn't compute the events in `members_to_fetch`
+ # correctly, or we somehow accepted an event whose auth events were
+ # dodgy.
+ logger.error(
+ "Failed to find memberships for %s in partial state room "
+ "%s in the auth events of %s.",
+ missing_members,
+ room_id,
+ [
+ events_with_membership_auth[member].event_id
+ for member in missing_members
+ ],
+ )
+
+ return additional_state_ids
+
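A toy reduction of the auth-event dig above, assuming the auth events have already been fetched into simple tuples; all IDs are hypothetical.

```py
from typing import Dict, Set, Tuple

StateKey = Tuple[str, str]  # (event type, state key)

def find_memberships(
    members_to_fetch: Set[str],
    found_state: Dict[StateKey, str],
    auth_events_by_member: Dict[str, list],  # member -> [(event_id, type, state_key)]
) -> Dict[StateKey, str]:
    additional: Dict[StateKey, str] = {}
    for member in members_to_fetch:
        # Memberships we already have need no digging.
        if ("m.room.member", member) in found_state:
            continue
        # Scan the auth events of an event this member sent, looking for
        # their own membership event.
        for event_id, etype, state_key in auth_events_by_member[member]:
            if etype == "m.room.member" and state_key == member:
                additional[("m.room.member", member)] = event_id
                break
    return additional

assert find_memberships(
    {"@a:hs"},
    {},
    {"@a:hs": [("$create", "m.room.create", ""), ("$join", "m.room.member", "@a:hs")]},
) == {("m.room.member", "@a:hs"): "$join"}
```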
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> NotifCounts:
@@ -1536,15 +1754,13 @@ class SyncHandler:
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
- sync_result_builder, ignored_users, self.rooms_to_exclude
+ sync_result_builder, ignored_users
)
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
- room_changes = await self._get_all_rooms(
- sync_result_builder, ignored_users, self.rooms_to_exclude
- )
+ room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1623,13 +1839,14 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
- excluded_rooms: List[str],
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
This function is a first pass at generating the rooms part of the sync response.
It determines which rooms have changed during the sync period, and categorises
- them into four buckets: "knock", "invite", "join" and "leave".
+ them into four buckets: "knock", "invite", "join" and "leave". It also excludes
+ from that list any room that appears in the list of rooms to exclude from sync
+ results in the server configuration.
1. Finds all membership changes for the user in the sync period (from
`since_token` up to `now_token`).
@@ -1655,7 +1872,7 @@ class SyncHandler:
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key, excluded_rooms
+ user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1696,7 +1913,11 @@ class SyncHandler:
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
- old_state_ids = await self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+ )
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
@@ -1722,7 +1943,13 @@ class SyncHandler:
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
- old_state_ids = await self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Member, user_id)]
+ ),
+ )
old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None
)
@@ -1862,7 +2089,6 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
- ignored_rooms: List[str],
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@@ -1884,7 +2110,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
membership_list=Membership.LIST,
- excluded_rooms=ignored_rooms,
+ excluded_rooms=self.rooms_to_exclude,
)
room_entries = []
@@ -2150,7 +2376,9 @@ class SyncHandler:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at(
- self, user_id: str, room_key: RoomStreamToken
+ self,
+ user_id: str,
+ room_key: RoomStreamToken,
) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
@@ -2176,7 +2404,12 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
+ # We also need to check whether the room should be excluded from sync
+ # responses as per the homeserver config.
for joined_room in joined_rooms:
+ if joined_room.room_id in self.rooms_to_exclude:
+ continue
+
if not joined_room.event_pos.persisted_after(room_key):
joined_room_ids.add(joined_room.room_id)
continue
@@ -2188,10 +2421,10 @@ class SyncHandler:
joined_room.room_id, joined_room.event_pos.stream
)
)
- users_in_room = await self.state.get_current_users_in_room(
+ user_ids_in_room = await self.state.get_current_user_ids_in_room(
joined_room.room_id, extrems
)
- if user_id in users_in_room:
+ if user_id in user_ids_in_room:
joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids)
@@ -2211,8 +2444,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
def _calculate_state(
timeline_contains: StateMap[str],
timeline_start: StateMap[str],
- previous: StateMap[str],
- current: StateMap[str],
+ timeline_end: StateMap[str],
+ previous_timeline_end: StateMap[str],
lazy_load_members: bool,
) -> StateMap[str]:
"""Works out what state to include in a sync response.
@@ -2220,45 +2453,50 @@ def _calculate_state(
Args:
timeline_contains: state in the timeline
timeline_start: state at the start of the timeline
- previous: state at the end of the previous sync (or empty dict
+ timeline_end: state at the end of the timeline
+ previous_timeline_end: state at the end of the previous sync (or empty dict
if this is an initial sync)
- current: state at the end of the timeline
lazy_load_members: whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
"""
- event_id_to_key = {
- e: key
- for key, e in itertools.chain(
+ event_id_to_state_key = {
+ event_id: state_key
+ for state_key, event_id in itertools.chain(
timeline_contains.items(),
- previous.items(),
timeline_start.items(),
- current.items(),
+ timeline_end.items(),
+ previous_timeline_end.items(),
)
}
- c_ids = set(current.values())
- ts_ids = set(timeline_start.values())
- p_ids = set(previous.values())
- tc_ids = set(timeline_contains.values())
+ timeline_end_ids = set(timeline_end.values())
+ timeline_start_ids = set(timeline_start.values())
+ previous_timeline_end_ids = set(previous_timeline_end.values())
+ timeline_contains_ids = set(timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
# as we may not have sent them to the client before. We find these membership
# events by filtering them out of timeline_start, which has already been filtered
# to only include membership events for the senders in the timeline.
- # In practice, we can do this by removing them from the p_ids list,
- # which is the list of relevant state we know we have already sent to the client.
+ # In practice, we can do this by removing them from the previous_timeline_end_ids
+ # list, which is the list of relevant state we know we have already sent to the
+ # client.
# see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
- p_ids.difference_update(
+ previous_timeline_end_ids.difference_update(
e for t, e in timeline_start.items() if t[0] == EventTypes.Member
)
- state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+ state_ids = (
+ (timeline_end_ids | timeline_start_ids)
+ - previous_timeline_end_ids
+ - timeline_contains_ids
+ )
- return {event_id_to_key[e]: e for e in state_ids}
+ return {event_id_to_state_key[e]: e for e in state_ids}
@attr.s(slots=True, auto_attribs=True)
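A worked toy example of the renamed set algebra in `_calculate_state`, with hypothetical event IDs: include state that is new at the start or end of the timeline, drop what the previous sync already delivered, and drop anything the timeline itself contains.

```py
timeline_contains_ids = {"$topic2"}              # state events inside the timeline
timeline_start_ids = {"$join_a", "$topic1"}      # state at the timeline's start
timeline_end_ids = {"$join_a", "$topic2"}        # state at the timeline's end
previous_timeline_end_ids = {"$join_a"}          # state already sent last sync

state_ids = (
    (timeline_end_ids | timeline_start_ids)
    - previous_timeline_end_ids
    - timeline_contains_ids
)

# Only "$topic1" is returned: "$join_a" was sent in the previous sync, and
# "$topic2" is already visible in the timeline itself.
assert state_ids == {"$topic1"}
```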
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d104ea07fe..bcac3372a2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
) -> None:
target_user_id = target_user.to_string()
- auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
- if target_user_id != auth_user_id:
+ if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned:
@@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
- await self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str
) -> None:
target_user_id = target_user.to_string()
- auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
- if target_user_id != auth_user_id:
+ if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned:
@@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
- await self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has stopped typing in %s", target_user_id, room_id)
@@ -489,8 +487,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
handler = self.get_typing_handler()
events = []
- for room_id in handler._room_serials.keys():
- if handler._room_serials[room_id] <= from_key:
+
+ # Work on a copy of things here as these may change in the handler while
+ # waiting for the AS `is_interested_in_room` call to complete.
+ # Shallow copy is safe as no nested data is present.
+ latest_room_serial = handler._latest_room_serial
+ room_serials = handler._room_serials.copy()
+
+ for room_id, serial in room_serials.items():
+ if serial <= from_key:
continue
if not await service.is_interested_in_room(room_id, self._main_store):
@@ -498,7 +503,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
events.append(self._make_event_for(room_id))
- return events, handler._latest_room_serial
+ return events, latest_room_serial
async def get_new_events(
self,
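A toy illustration, under hypothetical names, of why the snapshot above is taken before the first `await`: if the handler mutates `_room_serials` while the appservice interest check is in flight, iterating the live dict could raise or pair the events with an inconsistent token. A shallow copy pins both to one moment.

```py
import asyncio

class Handler:
    def __init__(self) -> None:
        self._latest_room_serial = 5
        self._room_serials = {"!a:hs": 3, "!b:hs": 5}

async def events_since(handler: Handler, from_key: int) -> tuple:
    # Snapshot both values before any await point.
    latest = handler._latest_room_serial
    serials = handler._room_serials.copy()  # safe to iterate across awaits
    events = []
    for room_id, serial in serials.items():
        if serial <= from_key:
            continue
        await asyncio.sleep(0)  # stands in for the AS interest check
        events.append(room_id)
    return events, latest

print(asyncio.run(events_since(Handler(), 3)))  # (['!b:hs'], 5)
```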
diff --git a/synapse/http/server.py b/synapse/http/server.py
index cf2d6f904b..19f42159b8 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -58,6 +58,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
@@ -155,15 +156,16 @@ def is_method_cancellable(method: Callable[..., Any]) -> bool:
return getattr(method, "cancellable", False)
-def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+def return_json_error(
+ f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
+) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
# mypy doesn't understand that f.check asserts the type.
exc: SynapseError = f.value # type: ignore
error_code = exc.code
- error_dict = exc.error_dict()
-
+ error_dict = exc.error_dict(config)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
elif f.check(CancelledError):
error_code = HTTP_STATUS_REQUEST_CANCELLED
@@ -450,7 +452,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response"""
- return_json_error(f, request)
+ return_json_error(f, request, None)
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -575,6 +577,14 @@ class JsonResource(DirectServeJsonResource):
return callback_return
+ def _send_error_response(
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
+ ) -> None:
+ """Implements _AsyncResource._send_error_response"""
+ return_json_error(f, request, self.hs.config)
+
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..26aaabfb34 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,9 +23,12 @@ from typing import (
Optional,
Sequence,
Tuple,
+ Type,
+ TypeVar,
overload,
)
+from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from twisted.web.server import Request
@@ -694,6 +697,28 @@ def parse_json_object_from_request(
return content
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+ request: Request, model_type: Type[Model]
+) -> Model:
+ """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+ validate using the given pydantic model.
+
+ Raises:
+ SynapseError if the request body couldn't be decoded as JSON or
+ if it wasn't a JSON object.
+ """
+ content = parse_json_object_from_request(request, allow_empty_body=False)
+ try:
+ instance = model_type.parse_obj(content)
+ except ValidationError as e:
+ raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
+
+ return instance
+
+
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent = []
for k in required:
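A hedged usage sketch for the new helper; `RenameBody` and the servlet context below are illustrative, not part of Synapse.

```py
from pydantic import BaseModel, StrictStr

class RenameBody(BaseModel):
    displayname: StrictStr

# Inside a servlet's on_PUT, one would then write something like:
#
#     body = parse_and_validate_json_object_from_request(request, RenameBody)
#     new_name = body.displayname
#
# A malformed body such as {"displayname": 42} raises SynapseError with
# errcode M_BAD_JSON instead of failing deep inside the handler.
```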
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..1155f3f610 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -226,7 +226,7 @@ class SynapseRequest(Request):
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we return both.
- if self._requester.user.to_string() != authenticated_entity:
+ if requester != authenticated_entity:
return requester, authenticated_entity
return requester, None
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 17e729f0c7..482316a1ff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import (
Any,
Callable,
Collection,
+ ContextManager,
Dict,
Generator,
Iterable,
@@ -182,6 +183,8 @@ from typing import (
Type,
TypeVar,
Union,
+ cast,
+ overload,
)
import attr
@@ -307,6 +310,19 @@ class SynapseTags:
# The name of the external cache
CACHE_NAME = "cache.name"
+ # Used to tag function arguments
+ #
+ # Tag a named arg. The name of the argument should be appended to this prefix.
+ FUNC_ARG_PREFIX = "ARG."
+ # Tag the variadic positional arguments (`def foo(first, second, *extras)`)
+ FUNC_ARGS = "args"
+ # Tag keyword args
+ FUNC_KWARGS = "kwargs"
+
+ # Some intermediate result that's interesting to the function. The label for
+ # the result should be appended to this prefix.
+ RESULT_PREFIX = "RESULT."
+
class SynapseBaggage:
FORCE_TRACING = "synapse-force-tracing"
@@ -328,6 +344,7 @@ class _Sentinel(enum.Enum):
P = ParamSpec("P")
R = TypeVar("R")
+T = TypeVar("T")
def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +360,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
return _only_if_tracing_inner
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+ message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+ ...
+
+
+@overload
+def ensure_active_span(
+ message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+ ...
+
+
+def ensure_active_span(
+ message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
Args:
message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
- ret (object): return value if opentracing is None or there is no active span.
+ ret: return value if opentracing is None or there is no active span.
- Returns (object): The result of the func or ret if opentracing is disabled or there
+ Returns:
+ The result of the func, falling back to ret if opentracing is disabled or there
was no active span.
"""
- def ensure_active_span_inner_1(func):
+ def ensure_active_span_inner_1(
+ func: Callable[P, R]
+ ) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
- def ensure_active_span_inner_2(*args, **kwargs):
+ def ensure_active_span_inner_2(
+ *args: P.args, **kwargs: P.kwargs
+ ) -> Union[Optional[T], R]:
if not opentracing:
return ret
@@ -464,7 +502,7 @@ def start_active_span(
finish_on_close: bool = True,
*,
tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
"""Starts an active opentracing span.
Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +540,7 @@ def start_active_span_follows_from(
*,
inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans
Args:
@@ -717,7 +755,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+ "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
@@ -797,75 +837,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def _custom_sync_async_decorator(
+ func: Callable[P, R],
+ wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
"""
- Decorator to trace a function with a custom opname.
-
- See the module's doc string for usage examples.
+ Decorates a function that is sync, async (a coroutine), or one that returns a Twisted
+ `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+ Example usage:
+ ```py
+ # Decorator to time the function and log it out
+ def duration(func: Callable[P, R]) -> Callable[P, R]:
+ @contextlib.contextmanager
+ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+ start_ts = time.time()
+ try:
+ yield
+ finally:
+ end_ts = time.time()
+ duration = end_ts - start_ts
+ logger.info("%s took %s seconds", func.__name__, duration)
+ return _custom_sync_async_decorator(func, _wrapping_logic)
+ ```
+ Args:
+ func: The function to be decorated
+ wrapping_logic: The business logic of your custom decorator.
+ This should be a ContextManager so you are able to run your logic
+ before/after the function as desired.
"""
- def decorator(func: Callable[P, R]) -> Callable[P, R]:
- if opentracing is None:
- return func # type: ignore[unreachable]
+ if inspect.iscoroutinefunction(func):
- if inspect.iscoroutinefunction(func):
+ @wraps(func)
+ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ with wrapping_logic(func, *args, **kwargs):
+ return await func(*args, **kwargs) # type: ignore[misc]
- @wraps(func)
- async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- with start_active_span(opname):
- return await func(*args, **kwargs) # type: ignore[misc]
+ else:
+ # The other case here handles both sync functions and those
+ # decorated with `inlineCallbacks`.
+ @wraps(func)
+ def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ scope = wrapping_logic(func, *args, **kwargs)
+ scope.__enter__()
- else:
- # The other case here handles both sync functions and those
- # decorated with inlineDeferred.
- @wraps(func)
- def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- scope = start_active_span(opname)
- scope.__enter__()
-
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
-
- def call_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- def err_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- result.addCallbacks(call_back, err_back)
-
- else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+ def err_back(result: R) -> R:
scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
+
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _wrapper # type: ignore[return-value]
+
+
+def trace_with_opname(
+ opname: str,
+ *,
+ tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+ """
+ Decorator to trace a function with a custom opname.
+ See the module's doc string for usage examples.
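+
+    Example (a minimal sketch; `do_work` is a hypothetical function):
+
+        @trace_with_opname("do_work")
+        async def do_work() -> None:
+            ...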
+ """
- return result
+ # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+ @contextlib.contextmanager # type: ignore[arg-type]
+ def _wrapping_logic(
+ func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+ ) -> Generator[None, None, None]:
+ with start_active_span(opname, tracer=tracer):
+ yield
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
+ def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+ if not opentracing:
+ return func
- return _trace_inner # type: ignore[return-value]
+ return _custom_sync_async_decorator(func, _wrapping_logic)
- return decorator
+ return _decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
-
Sets the operation name to that of the function's name.
-
See the module's doc string for usage examples.
"""
@@ -874,22 +956,36 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
- Tags all of the args to the active span.
+ Decorator to tag all of the args to the active span.
+
+ Args:
+ func: `func` is assumed to be a method taking a `self` parameter, or a
+ `classmethod` taking a `cls` parameter. In either case, a tag is not
+ created for this parameter.
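+
+    Example (a minimal sketch; `Handler.fetch` is hypothetical):
+
+        class Handler:
+            @tag_args
+            async def fetch(self, room_id: str) -> None:
+                ...
+
+    Each named argument (here `room_id`) is tagged onto the active span; no
+    tag is created for `self`.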
"""
if not opentracing:
return func
- @wraps(func)
- def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+ @contextlib.contextmanager # type: ignore[arg-type]
+ def _wrapping_logic(
+ func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+ ) -> Generator[None, None, None]:
argspec = inspect.getfullargspec(func)
- for i, arg in enumerate(argspec.args[1:]):
- set_tag("ARG_" + arg, args[i]) # type: ignore[index]
- set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
- set_tag("kwargs", kwargs)
- return func(*args, **kwargs)
-
- return _tag_args_inner
+ # We use `[1:]` to skip the `self` object reference and `start=1` to
+ # make the index line up with `argspec.args`.
+ #
+ # FIXME: We could update this to handle any type of function by ignoring the
+ # first argument only if it's named `self` or `cls`. This isn't fool-proof
+ # but handles the idiomatic cases.
+ for i, arg in enumerate(args[1:], start=1): # type: ignore[index]
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+ set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index]
+ set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
+ yield
+
+ return _custom_sync_async_decorator(func, _wrapping_logic)
@contextlib.contextmanager
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index eef3462e10..7a1516d3a8 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -235,7 +235,7 @@ def run_as_background_process(
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
- ctx = nullcontext()
+ ctx = nullcontext() # type: ignore[assignment]
with ctx:
return await func(*args, **kwargs)
except Exception:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 6d8bf54083..87ba154cb7 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -929,10 +929,12 @@ class ModuleApi:
room_id: str,
new_membership: str,
content: Optional[JsonDict] = None,
+ remote_room_hosts: Optional[List[str]] = None,
) -> EventBase:
"""Updates the membership of a user to the given value.
Added in Synapse v1.46.0.
+ Changed in Synapse v1.65.0: Added the 'remote_room_hosts' parameter.
Args:
sender: The user performing the membership change. Must be a user local to
@@ -946,6 +948,7 @@ class ModuleApi:
https://spec.matrix.org/unstable/client-server-api/#mroommember for the
list of allowed values.
content: Additional values to include in the resulting event's content.
+ remote_room_hosts: Remote servers to use for remote joins/knocks/etc.
Returns:
The newly created membership event.
@@ -1005,15 +1008,12 @@ class ModuleApi:
room_id=room_id,
action=new_membership,
content=content,
+ remote_room_hosts=remote_room_hosts,
)
# Try to retrieve the resulting event.
event = await self._hs.get_datastores().main.get_event(event_id)
- # update_membership is supposed to always return after the event has been
- # successfully persisted.
- assert event is not None
-
return event
async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
@@ -1452,6 +1452,81 @@ class ModuleApi:
start_timestamp, end_timestamp
)
+ async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]:
+ """
+ Get the room ID associated with a room alias.
+
+ Added in Synapse v1.65.0.
+
+ Args:
+ room_alias: The alias to look up.
+
+ Returns:
+ A tuple of:
+ The room ID (str).
+ Hosts likely to be participating in the room ([str]).
+
+ Raises:
+ SynapseError if the room alias is invalid or could not be found.
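+
+        Example (a minimal sketch; `module_api` is an instance of this class):
+
+            room_id, hosts = await module_api.lookup_room_alias("#matrix:example.com")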
+ """
+ alias = RoomAlias.from_string(room_alias)
+ (room_id, hosts) = await self._hs.get_room_member_handler().lookup_room_alias(
+ alias
+ )
+
+ return room_id.to_string(), hosts
+
+ async def create_room(
+ self,
+ user_id: str,
+ config: JsonDict,
+ ratelimit: bool = True,
+ creator_join_profile: Optional[JsonDict] = None,
+ ) -> Tuple[str, Optional[str]]:
+ """Creates a new room.
+
+ Added in Synapse v1.65.0.
+
+ Args:
+ user_id:
+ The user who requested the room creation.
+ config: A dict of configuration options. See "Request body" of:
+ https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+ ratelimit: set to False to disable the rate limiter for this specific operation.
+
+ creator_join_profile:
+ Set to override the displayname and avatar for the creating
+ user in this room. If unset, displayname and avatar will be
+ derived from the user's profile. If set, should contain the
+ values to go in the body of the 'join' event (typically
+ `avatar_url` and/or `displayname`).
+
+ Returns:
+ A tuple containing: 1) the room ID (str), and 2) the room alias (str) if
+ one was requested, otherwise None.
+
+ Raises:
+ ResourceLimitError if the server is blocked due to a resource limit
+ being exceeded.
+ RuntimeError if the user_id does not refer to a local user.
+ SynapseError if the user_id is invalid, room ID couldn't be stored, or
+ something went horribly wrong.
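+
+        Example (a minimal sketch; the user ID and config shown are illustrative):
+
+            room_id, alias = await module_api.create_room(
+                user_id="@alice:example.com",
+                config={"preset": "private_chat", "name": "Example room"},
+            )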
+ """
+ if not self.is_mine(user_id):
+ raise RuntimeError(
+ "Tried to create a room as a user that isn't local to this homeserver",
+ )
+
+ requester = create_requester(user_id)
+ room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
+ requester=requester,
+ config=config,
+ ratelimit=ratelimit,
+ creator_join_profile=creator_join_profile,
+ )
+
+ return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6c0cc5a6ce..440205e80c 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -14,128 +14,235 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-from typing import Any, Dict, List
-
-from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+"""
+Push rules is the system used to determine which events trigger a push (and a
+bump in notification counts).
+
+This consists of a list of "push rules" for each user, where a push rule is a
+pair of "conditions" and "actions". When a user receives an event Synapse
+iterates over the list of push rules until it finds one where all the conditions
+match the event, at which point "actions" describe the outcome (e.g. notify,
+highlight, etc).
+
+Push rules are split up into 5 different "kinds" (aka "priority classes"), which
+are run in order:
+ 1. Override — highest priority rules, e.g. always ignore notices
+ 2. Content — content specific rules, e.g. @ notifications
+ 3. Room — per room rules, e.g. enable/disable notifications for all messages
+ in a room
+ 4. Sender — per sender rules, e.g. never notify for messages from a given
+ user
+ 5. Underride — the lowest priority "default" rules, e.g. notify for every
+ message.
+
+The set of "base rules" are the list of rules that every user has by default. A
+user can modify their copy of the push rules in one of three ways:
+
+ 1. Adding a new push rule of a certain kind
+ 2. Changing the actions of a base rule
+ 3. Enabling/disabling a base rule.
+
+The base rules are split according to whether they come before or after a
+particular kind, so the order of push rule evaluation is: base rules before the
+"override" kind, user-defined "override" rules, base rules after the "override"
+kind, and so on.
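+
+For example, a (hypothetical) user-defined "override" rule that silences a noisy
+user could be expressed as the condition/action pair:
+
+    conditions: [{"kind": "event_match", "key": "sender", "pattern": "@noisy:example.com"}]
+    actions: ["dont_notify"]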
+"""
+
+import itertools
+import logging
+from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union
+
+import attr
+
+from synapse.config.experimental import ExperimentalConfig
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class PushRule:
+ """A push rule
+
+ Attributes:
+ rule_id: a unique ID for this rule
+ priority_class: what "kind" of push rule this is (see
+ `PRIORITY_CLASS_MAP` for mapping between int and kind)
+ conditions: the sequence of conditions that all need to match
+ actions: the actions to apply if all conditions are met
+ default: is this a base rule?
+ default_enabled: is this enabled by default?
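+
+    For example (a hypothetical rule):
+
+        PushRule(
+            rule_id="global/override/.example.rule.silence_noisy_user",
+            priority_class=PRIORITY_CLASS_MAP["override"],
+            conditions=[{"kind": "event_match", "key": "sender", "pattern": "@noisy:example.com"}],
+            actions=["dont_notify"],
+        )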
+ """
+ rule_id: str
+ priority_class: int
+ conditions: Sequence[Mapping[str, str]]
+ actions: Sequence[Union[str, Mapping]]
+ default: bool = False
+ default_enabled: bool = True
-def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Combine the list of rules set by the user with the default push rules
- Args:
- rawrules: The rules the user has modified or set.
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class PushRules:
+ """A collection of push rules for an account.
- Returns:
- A new list with the rules set by the user combined with the defaults.
+ Can be iterated over, producing push rules in priority order.
"""
- ruleslist = []
- # Grab the base rules that the user has modified.
- # The modified base rules have a priority_class of -1.
- modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
+ # A mapping from rule ID to push rule that overrides a base rule. These will
+ # be returned instead of the base rule.
+ overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict)
+
+ # The following stores the custom push rules at each priority class.
+ #
+ # We keep these separate (rather than combining into one big list) to avoid
+ # copying the base rules around all the time.
+ override: List[PushRule] = attr.Factory(list)
+ content: List[PushRule] = attr.Factory(list)
+ room: List[PushRule] = attr.Factory(list)
+ sender: List[PushRule] = attr.Factory(list)
+ underride: List[PushRule] = attr.Factory(list)
+
+ def __iter__(self) -> Iterator[PushRule]:
+ # When iterating over the push rules we need to return the base rules
+ # interspersed at the correct spots.
+ for rule in itertools.chain(
+ BASE_PREPEND_OVERRIDE_RULES,
+ self.override,
+ BASE_APPEND_OVERRIDE_RULES,
+ self.content,
+ BASE_APPEND_CONTENT_RULES,
+ self.room,
+ self.sender,
+ self.underride,
+ BASE_APPEND_UNDERRIDE_RULES,
+ ):
+ # Check if a base rule has been overridden by a custom rule. If so,
+ # return that instead.
+ override_rule = self.overriden_base_rules.get(rule.rule_id)
+ if override_rule:
+ yield override_rule
+ else:
+ yield rule
+
+ def __len__(self) -> int:
+ # The length is mostly used by caches to get a sense of "size" / amount
+ # of memory this object is using, so we only count the number of custom
+ # rules.
+ return (
+ len(self.overriden_base_rules)
+ + len(self.override)
+ + len(self.content)
+ + len(self.room)
+ + len(self.sender)
+ + len(self.underride)
+ )
- # Remove the modified base rules from the list, They'll be added back
- # in the default positions in the list.
- rawrules = [r for r in rawrules if r["priority_class"] >= 0]
- # shove the server default rules for each kind onto the end of each
- current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class FilteredPushRules:
+ """A wrapper around `PushRules` that filters out disabled experimental push
+ rules, and includes the "enabled" state for each rule when iterated over.
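+
+    Iterating yields `(rule, enabled)` pairs, e.g. (a minimal sketch;
+    `filtered_rules` is an instance of this class):
+
+        for rule, enabled in filtered_rules:
+            if enabled:
+                ...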
+ """
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
+ push_rules: PushRules
+ enabled_map: Dict[str, bool]
+ experimental_config: ExperimentalConfig
- for r in rawrules:
- if r["priority_class"] < current_prio_class:
- while r["priority_class"] < current_prio_class:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
-
- ruleslist.append(r)
-
- while current_prio_class > 0:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
+ def __iter__(self) -> Iterator[Tuple[PushRule, bool]]:
+ for rule in self.push_rules:
+ if not _is_experimental_rule_enabled(
+ rule.rule_id, self.experimental_config
+ ):
+ continue
- return ruleslist
+ enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled)
+ yield rule, enabled
-def make_base_append_rules(
- kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
-) -> List[Dict[str, Any]]:
- rules = []
+ def __len__(self) -> int:
+ return len(self.push_rules)
- if kind == "override":
- rules = BASE_APPEND_OVERRIDE_RULES
- elif kind == "underride":
- rules = BASE_APPEND_UNDERRIDE_RULES
- elif kind == "content":
- rules = BASE_APPEND_CONTENT_RULES
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
+DEFAULT_EMPTY_PUSH_RULES = PushRules()
- return rules
+def compile_push_rules(rawrules: List[PushRule]) -> PushRules:
+ """Given a set of custom push rules return a `PushRules` instance (which
+ includes the base rules).
+ """
+
+ if not rawrules:
+ # Fast path to avoid allocating empty lists when there are no custom
+ # rules for the user.
+ return DEFAULT_EMPTY_PUSH_RULES
+
+ rules = PushRules()
-def make_base_prepend_rules(
- kind: str,
- modified_base_rules: Dict[str, Dict[str, Any]],
-) -> List[Dict[str, Any]]:
- rules = []
+ for rule in rawrules:
+ # We need to decide which bucket each custom push rule goes into.
- if kind == "override":
- rules = BASE_PREPEND_OVERRIDE_RULES
+ # If it has the same ID as a base rule then it overrides that...
+ overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id)
+ if overriden_base_rule:
+ rules.overriden_base_rules[rule.rule_id] = attr.evolve(
+ overriden_base_rule, actions=rule.actions
+ )
+ continue
+
+ # ... otherwise it gets added to the appropriate priority class bucket
+ collection: List[PushRule]
+ if rule.priority_class == 5:
+ collection = rules.override
+ elif rule.priority_class == 4:
+ collection = rules.content
+ elif rule.priority_class == 3:
+ collection = rules.room
+ elif rule.priority_class == 2:
+ collection = rules.sender
+ elif rule.priority_class == 1:
+ collection = rules.underride
+ elif rule.priority_class <= 0:
+ logger.info(
+ "Got rule with priority class less than zero, but doesn't override a base rule: %s",
+ rule,
+ )
+ continue
+ else:
+ # We log and continue here so as not to break event sending
+ logger.error("Unknown priority class: %s", rule.priority_class)
+ continue
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
+ collection.append(rule)
return rules
-# We have to annotate these types, otherwise mypy infers them as
-# `List[Dict[str, Sequence[Collection[str]]]]`.
-BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/content/.m.rule.contains_user_name",
- "conditions": [
+def _is_experimental_rule_enabled(
+ rule_id: str, experimental_config: ExperimentalConfig
+) -> bool:
+ """Used by `FilteredPushRules` to filter out experimental rules when they
+ have not been enabled.
+ """
+ if (
+ rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
+ and not experimental_config.msc3786_enabled
+ ):
+ return False
+ if (
+ rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+ and not experimental_config.msc3772_enabled
+ ):
+ return False
+ return True
+
+
+BASE_APPEND_CONTENT_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["content"],
+ rule_id="global/content/.m.rule.contains_user_name",
+ conditions=[
{
"kind": "event_match",
"key": "content.body",
@@ -143,29 +250,33 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
"pattern_type": "user_localpart",
}
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
- }
+ )
]
-BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.master",
- "enabled": False,
- "conditions": [],
- "actions": ["dont_notify"],
- }
+BASE_PREPEND_OVERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.master",
+ default_enabled=False,
+ conditions=[],
+ actions=["dont_notify"],
+ )
]
-BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.suppress_notices",
- "conditions": [
+BASE_APPEND_OVERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.suppress_notices",
+ conditions=[
{
"kind": "event_match",
"key": "content.msgtype",
@@ -173,13 +284,15 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_suppress_notices",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
- {
- "rule_id": "global/override/.m.rule.invite_for_me",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.invite_for_me",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -195,21 +308,23 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
# Match the requester's MXID.
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
# is that we don't though. We add this override rule so that even if
# the room rule is set to notify, we don't get notifications about
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
- {
- "rule_id": "global/override/.m.rule.member_event",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.member_event",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -217,24 +332,28 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_member",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
- {
- "rule_id": "global/override/.m.rule.contains_display_name",
- "conditions": [{"kind": "contains_display_name"}],
- "actions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.contains_display_name",
+ conditions=[{"kind": "contains_display_name"}],
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
- },
- {
- "rule_id": "global/override/.m.rule.roomnotif",
- "conditions": [
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.roomnotif",
+ conditions=[
{
"kind": "event_match",
"key": "content.body",
@@ -247,11 +366,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_roomnotif_pl",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.tombstone",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": True}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.tombstone",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -265,11 +386,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_tombstone_statekey",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.reaction",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": True}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.reaction",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -277,14 +400,16 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_reaction",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# XXX: This is an experimental rule that is only enabled if msc3786_enabled
# is enabled; if it is not, the rule gets filtered out in _load_rules() in
# PushRulesWorkerStore.
- {
- "rule_id": "global/override/.org.matrix.msc3786.rule.room.server_acl",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -298,15 +423,17 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_room_server_acl_state_key",
},
],
- "actions": [],
- },
+ actions=[],
+ ),
]
-BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/underride/.m.rule.call",
- "conditions": [
+BASE_APPEND_UNDERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.call",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -314,17 +441,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_call",
}
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "ring"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
- {
- "rule_id": "global/underride/.m.rule.room_one_to_one",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.room_one_to_one",
+ conditions=[
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
@@ -333,17 +462,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_message",
},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.encrypted_room_one_to_one",
+ conditions=[
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
@@ -352,15 +483,17 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_encrypted",
},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
- {
- "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
- "conditions": [
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.org.matrix.msc3772.thread_reply",
+ conditions=[
{
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.thread",
@@ -368,11 +501,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"sender_type": "user_id",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.m.rule.message",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.message",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -380,13 +515,15 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_message",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.encrypted",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -394,11 +531,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_encrypted",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.im.vector.jitsi",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.im.vector.jitsi",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -418,29 +557,27 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_is_state_event",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
]
BASE_RULE_IDS = set()
+BASE_RULES_BY_ID: Dict[str, PushRule] = {}
+
for r in BASE_APPEND_CONTENT_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["content"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_PREPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_APPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_APPEND_UNDERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 713dcf6950..ccd512be54 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,7 +15,18 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
@@ -30,6 +41,7 @@ from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
+from .baserules import FilteredPushRules, PushRule
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
@@ -112,7 +124,7 @@ class BulkPushRuleEvaluator:
async def _get_rules_for_event(
self,
event: EventBase,
- ) -> Dict[str, List[Dict[str, Any]]]:
+ ) -> Dict[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
@@ -186,7 +198,7 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations(
- self, event: EventBase, rules: Iterable[Dict[str, Any]]
+ self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
@@ -216,12 +228,11 @@ class BulkPushRuleEvaluator:
# Pre-filter to figure out which relation types are interesting.
rel_types = set()
- for rule in rules:
- # Skip disabled rules.
- if "enabled" in rule and not rule["enabled"]:
+ for rule, enabled in rules:
+ if not enabled:
continue
- for condition in rule["conditions"]:
+ for condition in rule.conditions:
if condition["kind"] != "org.matrix.msc3772.relation_match":
continue
@@ -254,7 +265,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event)
- actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+ actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
@@ -317,15 +328,13 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
- for rule in rules:
- if "enabled" in rule and not rule["enabled"]:
+ for rule, enabled in rules:
+ if not enabled:
continue
- matches = evaluator.check_conditions(
- rule["conditions"], uid, display_name
- )
+ matches = evaluator.check_conditions(rule.conditions, uid, display_name)
if matches:
- actions = [x for x in rule["actions"] if x != "dont_notify"]
+ actions = [x for x in rule.actions if x != "dont_notify"]
if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 5117ef6854..73618d9234 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -18,16 +18,15 @@ from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.types import UserID
+from .baserules import FilteredPushRules, PushRule
+
def format_push_rules_for_user(
- user: UserID, ruleslist: List
+ user: UserID, ruleslist: FilteredPushRules
) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
- # We're going to be mutating this a lot, so do a deep copy
- ruleslist = copy.deepcopy(ruleslist)
-
rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
"global": {},
"device": {},
@@ -35,11 +34,30 @@ def format_push_rules_for_user(
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
- for r in ruleslist:
- template_name = _priority_class_to_template_name(r["priority_class"])
+ for r, enabled in ruleslist:
+ template_name = _priority_class_to_template_name(r.priority_class)
+
+ rulearray = rules["global"][template_name]
+
+ template_rule = _rule_to_template(r)
+ if not template_rule:
+ continue
+
+ rulearray.append(template_rule)
+
+ template_rule["enabled"] = enabled
+
+ if "conditions" not in template_rule:
+ # Not all formatted rules have explicit conditions, e.g. "room"
+ # rules omit them as they can be derived from the kind and rule ID.
+ #
+ # If the formatted rule has no conditions then we can skip the
+ # formatting of conditions.
+ continue
# Remove internal stuff.
- for c in r["conditions"]:
+ template_rule["conditions"] = copy.deepcopy(template_rule["conditions"])
+ for c in template_rule["conditions"]:
c.pop("_cache_key", None)
pattern_type = c.pop("pattern_type", None)
@@ -52,16 +70,6 @@ def format_push_rules_for_user(
if sender_type == "user_id":
c["sender"] = user.to_string()
- rulearray = rules["global"][template_name]
-
- template_rule = _rule_to_template(r)
- if template_rule:
- if "enabled" in r:
- template_rule["enabled"] = r["enabled"]
- else:
- template_rule["enabled"] = True
- rulearray.append(template_rule)
-
return rules
@@ -71,24 +79,24 @@ def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
return d
-def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- unscoped_rule_id = None
- if "rule_id" in rule:
- unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
+def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
+ templaterule: Dict[str, Any]
+
+ unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
- template_name = _priority_class_to_template_name(rule["priority_class"])
+ template_name = _priority_class_to_template_name(rule.priority_class)
if template_name in ["override", "underride"]:
- templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+ templaterule = {"conditions": rule.conditions, "actions": rule.actions}
elif template_name in ["sender", "room"]:
- templaterule = {"actions": rule["actions"]}
- unscoped_rule_id = rule["conditions"][0]["pattern"]
+ templaterule = {"actions": rule.actions}
+ unscoped_rule_id = rule.conditions[0]["pattern"]
elif template_name == "content":
- if len(rule["conditions"]) != 1:
+ if len(rule.conditions) != 1:
return None
- thecond = rule["conditions"][0]
+ thecond = rule.conditions[0]
if "pattern" not in thecond:
return None
- templaterule = {"actions": rule["actions"]}
+ templaterule = {"actions": rule.actions}
templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
@@ -97,8 +105,8 @@ def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
- if "default" in rule:
- templaterule["default"] = rule["default"]
+ if rule.default:
+ templaterule["default"] = True
return templaterule
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2e8a017add..3c5632cd91 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,18 @@
import logging
import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
+from typing import (
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Pattern,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
from matrix_common.regex import glob_to_regex, to_word_pattern
@@ -32,14 +43,14 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(
- ev: EventBase, condition: Dict[str, Any], room_member_count: int
+ ev: EventBase, condition: Mapping[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(
ev: EventBase,
- condition: Dict[str, Any],
+ condition: Mapping[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
@@ -54,7 +65,7 @@ def _sender_notification_permission(
return sender_power_level >= room_notif_level
-def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
+def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -137,7 +148,7 @@ class PushRuleEvaluatorForEvent:
self._condition_cache: Dict[str, bool] = {}
def check_conditions(
- self, conditions: List[dict], uid: str, display_name: Optional[str]
+ self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's conditions/user ID/display name match the event.
@@ -169,7 +180,7 @@ class PushRuleEvaluatorForEvent:
return True
def matches(
- self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
+ self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's condition/user ID/display name match the event.
@@ -204,7 +215,7 @@ class PushRuleEvaluatorForEvent:
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True
- def _event_match(self, condition: dict, user_id: str) -> bool:
+ def _event_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "event_match" push rule condition.
@@ -269,7 +280,7 @@ class PushRuleEvaluatorForEvent:
return bool(r.search(body))
- def _relation_match(self, condition: dict, user_id: str) -> bool:
+ def _relation_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "relation_match" push rule condition.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
deleted file mode 100644
index 7644146dba..0000000000
--- a/synapse/replication/slave/storage/_base.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class BaseSlavedStore(CacheInvalidationWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
- if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen: Optional[
- MultiWriterIdGenerator
- ] = MultiWriterIdGenerator(
- db_conn,
- database,
- stream_name="caches",
- instance_name=hs.get_instance_name(),
- tables=[
- (
- "cache_invalidation_stream_by_instance",
- "instance_name",
- "stream_id",
- )
- ],
- sequence_name="cache_invalidation_stream_seq",
- writers=[],
- )
- else:
- self._cache_id_gen = None
-
- self.hs = hs
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
deleted file mode 100644
index ee74ee7d85..0000000000
--- a/synapse/replication/slave/storage/account_data.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.account_data import AccountDataWorkerStore
-from synapse.storage.databases.main.tags import TagsWorkerStore
-
-
-class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
deleted file mode 100644
index 29f50c0add..0000000000
--- a/synapse/replication/slave/storage/appservice.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.appservice import (
- ApplicationServiceTransactionWorkerStore,
- ApplicationServiceWorkerStore,
-)
-
-
-class SlavedApplicationServiceStore(
- ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
-):
- pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
deleted file mode 100644
index e940751084..0000000000
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-
-
-class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index a48cc02069..6fcade510a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -14,7 +14,6 @@
from typing import TYPE_CHECKING, Any, Iterable
-from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -24,7 +23,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore):
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
deleted file mode 100644
index 71fde0c96c..0000000000
--- a/synapse/replication/slave/storage/directory.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.directory import DirectoryWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a72dad7464..fe47778cb1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -29,8 +29,6 @@ from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -56,7 +54,6 @@ class SlavedEventStore(
EventsWorkerStore,
UserErasureWorkerStore,
RelationsWorkerStore,
- BaseSlavedStore,
):
def __init__(
self,
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 4d185e2b56..c52679cd60 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,16 +14,15 @@
from typing import TYPE_CHECKING
+from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore
-from ._base import BaseSlavedStore
-
if TYPE_CHECKING:
from synapse.server import HomeServer
-class SlavedFilteringStore(BaseSlavedStore):
+class SlavedFilteringStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
deleted file mode 100644
index 99f4a22642..0000000000
--- a/synapse/replication/slave/storage/profile.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.profile import ProfileWorkerStore
-
-
-class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 52ee3f7e58..5e65eaf1e0 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -31,6 +31,5 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
- self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index de642bba71..44ed20e424 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -18,14 +18,13 @@ from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.pusher import PusherWorkerStore
-from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
if TYPE_CHECKING:
from synapse.server import HomeServer
-class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore):
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
deleted file mode 100644
index 3826b87dec..0000000000
--- a/synapse/replication/slave/storage/receipts.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
deleted file mode 100644
index 5dae35a960..0000000000
--- a/synapse/replication/slave/storage/registration.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.registration import RegistrationWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
- pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e4f2201c92..1ed7230e32 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -416,7 +416,10 @@ class FederationSenderHandler:
if not self._is_mine_id(receipt.user_id):
continue
# Private read receipts never get sent over federation.
- if receipt.receipt_type == ReceiptTypes.READ_PRIVATE:
+ if receipt.receipt_type in (
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ):
continue
receipt_info = ReadReceipt(
receipt.room_id,
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
index b751359bdf..bd4f7cea97 100644
--- a/synapse/res/templates/account_previously_renewed.html
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+ Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index e8c0f52f05..57b319f375 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+ Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
index cc4ab07e09..71f2215b7a 100644
--- a/synapse/res/templates/add_threepid.html
+++ b/synapse/res/templates/add_threepid.html
@@ -1,9 +1,14 @@
-<html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Request to add an email address to your Matrix account</title>
+</head>
<body>
<p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
-
<a href="{{ link }}">{{ link }}</a>
-
<p>If this was not you, you can safely ignore this email. Thank you.</p>
</body>
</html>
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
index 441d11c846..bd627ee9ce 100644
--- a/synapse/res/templates/add_threepid_failure.html
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -1,8 +1,13 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Request failed</title>
+</head>
<body>
-<p>The request failed for the following reason: {{ failure_reason }}.</p>
-
-<p>No changes have been made to your account.</p>
+ <p>The request failed for the following reason: {{ failure_reason }}.</p>
+ <p>No changes have been made to your account.</p>
</body>
</html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
index fbd6e4018f..49170c138e 100644
--- a/synapse/res/templates/add_threepid_success.html
+++ b/synapse/res/templates/add_threepid_success.html
@@ -1,6 +1,12 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Your email has now been validated</title>
+</head>
<body>
-<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+ <p>Your email has now been validated, please return to your client. You may now close this window.</p>
</body>
-</html>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html
index baf4633142..2d6ac44a0e 100644
--- a/synapse/res/templates/auth_success.html
+++ b/synapse/res/templates/auth_success.html
@@ -1,8 +1,8 @@
<html>
<head>
<title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone) {
diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html
index 6bd2b98364..2c7c384fe3 100644
--- a/synapse/res/templates/invalid_token.html
+++ b/synapse/res/templates/invalid_token.html
@@ -1 +1,12 @@
-<html><body>Invalid renewal token.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Invalid renewal token.</title>
+</head>
+<body>
+ Invalid renewal token.
+</body>
+</html>
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index d87311f659..865f9f7ada 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -1,6 +1,8 @@
<!doctype html>
<html lang="en">
<head>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 27d4182790..9dba0c0253 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -1,6 +1,8 @@
<!doctype html>
<html lang="en">
<head>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{%- include 'mail.css' without context %}
{%- include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html
index a197bf872c..a8bdce357b 100644
--- a/synapse/res/templates/password_reset.html
+++ b/synapse/res/templates/password_reset.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+ <head>
+ <title>Password reset</title>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ </head>
<body>
<p>A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:</p>
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
index def4b5162b..2e3fd2ec1e 100644
--- a/synapse/res/templates/password_reset_confirmation.html
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+ <title>Password reset confirmation</title>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<!--Use a hidden form to resubmit the information necessary to reset the password-->
<form method="post">
diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html
index 9e3c4446e3..2d59c463f0 100644
--- a/synapse/res/templates/password_reset_failure.html
+++ b/synapse/res/templates/password_reset_failure.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+ <title>Password reset failure</title>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<p>The request failed for the following reason: {{ failure_reason }}.</p>
diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html
index 7324d66d1e..5165bd1fa2 100644
--- a/synapse/res/templates/password_reset_success.html
+++ b/synapse/res/templates/password_reset_success.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<p>Your email has now been validated, please return to your client to reset your password. You may now close this window.</p>
</body>
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index b3db06ef97..615d3239c6 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -1,8 +1,8 @@
<html>
<head>
<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://www.recaptcha.net/recaptcha/api.js"
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html
index 16730a527f..20e831ff4a 100644
--- a/synapse/res/templates/registration.html
+++ b/synapse/res/templates/registration.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+<head>
+ <title>Registration</title>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<p>You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:</p>
diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html
index 2833d79c37..a6ed22bc90 100644
--- a/synapse/res/templates/registration_failure.html
+++ b/synapse/res/templates/registration_failure.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<p>Validation failed for the following reason: {{ failure_reason }}.</p>
</body>
diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html
index fbd6e4018f..d51d5549d8 100644
--- a/synapse/res/templates/registration_success.html
+++ b/synapse/res/templates/registration_success.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+ <title>Your email has now been validated</title>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
<body>
<p>Your email has now been validated, please return to your client. You may now close this window.</p>
</body>
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
index 4577ce1702..59a98f564c 100644
--- a/synapse/res/templates/registration_token.html
+++ b/synapse/res/templates/registration_token.html
@@ -1,8 +1,8 @@
-<html>
+<html lang="en">
<head>
<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
</head>
<body>
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index c3e4deed93..075f801cec 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -3,8 +3,9 @@
<head>
<meta charset="UTF-8">
<title>SSO account deactivated</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
- <style type="text/css">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <style type="text/css">
{% include "sso.css" without context %}
</style>
</head>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 1ba850369a..2d1db386e1 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -3,7 +3,8 @@
<head>
<title>Create your account</title>
<meta charset="utf-8">
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<script type="text/javascript">
let wasKeyboard = false;
document.addEventListener("mousedown", function() { wasKeyboard = false; });
@@ -138,7 +139,7 @@
<div class="username_input" id="username_input">
<label for="field-username">Username (required)</label>
<div class="prefix">@</div>
- <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
+ <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus autocorrect="off" autocapitalize="none">
<div class="postfix">:{{ server_name }}</div>
</div>
<output for="username_input" id="field-username-output"></output>
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index da579ffe69..94403fc3ce 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Authentication failed</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
</style>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index f9d0456f0a..aa1c974a6b 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Confirm it's you</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
</style>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 1ed3967e87..4898af6011 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Authentication successful</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
</style>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 472309c350..19992ff2ad 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Authentication failed</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 53b82db84e..56fabfa3d2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -1,6 +1,8 @@
<!DOCTYPE html>
<html lang="en">
<head>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta charset="UTF-8">
<title>Choose identity provider</title>
<style type="text/css">
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
index 68c8b9f33a..523f64c4fc 100644
--- a/synapse/res/templates/sso_new_user_consent.html
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Agree to terms and conditions</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 1b01471ac8..1049a9bd92 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -3,7 +3,8 @@
<head>
<meta charset="UTF-8">
<title>Continue to your account</title>
- <meta name="viewport" content="width=device-width, user-scalable=no">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<style type="text/css">
{% include "sso.css" without context %}
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index 369ff446d2..2081d990ab 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -1,8 +1,8 @@
<html>
<head>
<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
</head>
<body>
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 399b205aaf..b467a61dfb 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -19,7 +19,7 @@ from typing import Iterable, Pattern
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
-from synapse.types import UserID
+from synapse.types import Requester
def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
@@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None
AuthError if the requester is not a server admin
"""
requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
+ await assert_user_is_admin(auth, requester)
-async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, requester: Requester) -> None:
"""Verify that the given user is an admin user
Args:
auth: Auth singleton
- user_id: user to check
+ requester: The user making the request, according to the access token.
Raises:
AuthError if the user is not a server admin
"""
- is_admin = await auth.is_server_admin(user_id)
+ is_admin = await auth.is_server_admin(requester)
if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 19d4a008e8..73470f09ae 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining room: %s", room_id)
@@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by user: %s", user_id)
@@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet):
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by ID: %s/%s", server_name, media_id)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 9d953d58de..3d870629c4 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -75,7 +75,7 @@ class RoomRestV2Servlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request)
@@ -303,6 +303,7 @@ class RoomRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+ ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
return HTTPStatus.OK, ret
@@ -326,7 +327,7 @@ class RoomRestServlet(RestServlet):
pagination_handler: "PaginationHandler",
) -> Tuple[int, JsonDict]:
requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
+ await assert_user_is_admin(auth, requester)
content = parse_json_object_from_request(request)
@@ -460,7 +461,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert request.args is not None
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request)
@@ -550,7 +551,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request, allow_empty_body=True)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -741,7 +742,7 @@ class RoomEventContextServlet(RestServlet):
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
limit = parse_integer(request, "limit", default=10)
@@ -833,7 +834,7 @@ class BlockRoomRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index ba2f7fa6d8..78ee9b6532 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
target_user = UserID.from_string(user_id)
body = parse_json_object_from_request(request)
@@ -575,10 +575,9 @@ class WhoisRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
- auth_user = requester.user
- if target_user != auth_user:
- await assert_user_is_admin(self.auth, auth_user)
+ if target_user != requester.user:
+ await assert_user_is_admin(self.auth, requester)
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
@@ -601,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet):
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError(
@@ -693,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet):
This needs user to have administrator access in Synapse.
"""
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
UserID.from_string(target_user_id)
@@ -807,7 +806,7 @@ class UserAdminServlet(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
auth_user = requester.user
target_user = UserID.from_string(user_id)
@@ -921,7 +920,7 @@ class UserTokenRestServlet(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
auth_user = requester.user
if not self.is_mine_id(user_id):
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 0cc87a4001..1f9a8ccc23 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,10 +15,11 @@
# limitations under the License.
import logging
import random
-from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse
+from pydantic import StrictBool, StrictStr, constr
+
from twisted.web.server import Request
from synapse.api.constants import LoginType
@@ -33,12 +34,15 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_and_validate_json_object_from_request,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -80,32 +84,16 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
400, "Email-based password resets have been disabled on this server"
)
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-
- # Extract params from body
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
-
- # Canonicalise the email address. The addresses are all stored canonicalised
- # in the database. This allows the user to reset his password without having to
- # know the exact spelling (eg. upper and lower case) of address in the database.
- # Stored in the database "foo@bar.com"
- # User requests with "FOO@bar.com" would raise a Not Found error
- try:
- email = validate_email(body["email"])
- except ValueError as e:
- raise SynapseError(400, str(e))
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
+ body = parse_and_validate_json_object_from_request(
+ request, EmailRequestTokenBody
+ )
- if next_link:
+ if body.next_link:
# Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ assert_valid_next_link(self.hs, body.next_link)
await self.identity_handler.ratelimit_request_token_requests(
- request, "email", email
+ request, "email", body.email
)
# The email will be sent to the stored address.
@@ -113,7 +101,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# an email address which is controlled by the attacker but which, after
# canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
- "email", email
+ "email", body.email
)
if existing_user_id is None:
@@ -129,15 +117,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Send password reset emails from Synapse
sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
+ body.email,
+ body.client_secret,
+ body.send_attempt,
self.mailer.send_password_reset_mail,
- next_link,
+ body.next_link,
)
-
threepid_send_requests.labels(type="email", reason="password_reset").observe(
- send_attempt
+ body.send_attempt
)
# Wrap the session id in a JSON object
@@ -156,16 +143,23 @@ class PasswordRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData] = None
+ logout_devices: StrictBool = True
+ if TYPE_CHECKING:
+ # workaround for https://github.com/samuelcolvin/pydantic/issues/156
+ new_password: Optional[StrictStr] = None
+ else:
+ new_password: Optional[constr(max_length=512, strict=True)] = None
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- new_password = body.pop("new_password", None)
+ new_password = body.new_password
if new_password is not None:
- if not isinstance(new_password, str) or len(new_password) > 512:
- raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
# there are two possibilities here. Either the user does not have an
@@ -185,7 +179,7 @@ class PasswordRestServlet(RestServlet):
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -208,7 +202,7 @@ class PasswordRestServlet(RestServlet):
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
request,
- body,
+ body.dict(exclude_unset=True),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -283,37 +277,33 @@ class DeactivateAccountRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData] = None
+ id_server: Optional[StrictStr] = None
+ # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297
+ erase: StrictBool = False
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase, requester
+ requester.user.to_string(), body.erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(),
- erase,
- requester,
- id_server=body.get("id_server"),
+ requester.user.to_string(), body.erase, requester, id_server=body.id_server
)
if result:
id_server_unbind_result = "success"
@@ -347,28 +337,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
- 400, "Adding an email to your account is disabled on this server"
+ 400,
+ "Adding an email to your account is disabled on this server",
)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
-
- # Canonicalise the email address. The addresses are all stored canonicalised
- # in the database.
- # This ensures that the validation email is sent to the canonicalised address
- # as it will later be entered into the database.
- # Otherwise the email will be sent to "FOO@bar.com" and stored as
- # "foo@bar.com" in database.
- try:
- email = validate_email(body["email"])
- except ValueError as e:
- raise SynapseError(400, str(e))
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
+ body = parse_and_validate_json_object_from_request(
+ request, EmailRequestTokenBody
+ )
- if not await check_3pid_allowed(self.hs, "email", email):
+ if not await check_3pid_allowed(self.hs, "email", body.email):
raise SynapseError(
403,
"Your email domain is not authorized on this server",
@@ -376,14 +353,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
await self.identity_handler.ratelimit_request_token_requests(
- request, "email", email
+ request, "email", body.email
)
- if next_link:
+ if body.next_link:
# Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ assert_valid_next_link(self.hs, body.next_link)
- existing_user_id = await self.store.get_user_id_by_threepid("email", email)
+ existing_user_id = await self.store.get_user_id_by_threepid("email", body.email)
if existing_user_id is not None:
if self.config.server.request_token_inhibit_3pid_errors:
@@ -396,16 +373,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+ # Send threepid validation emails from Synapse
sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
+ body.email,
+ body.client_secret,
+ body.send_attempt,
self.mailer.send_add_threepid_mail,
- next_link,
+ body.next_link,
)
threepid_send_requests.labels(type="email", reason="add_threepid").observe(
- send_attempt
+ body.send_attempt
)
# Wrap the session id in a JSON object
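
A note on the `PasswordRestServlet.PostBody` model added above: mypy rejects `constr(...)` used directly in an annotation, hence the `TYPE_CHECKING` branch. A minimal standalone sketch of the same pattern follows; the model name is hypothetical and the pydantic v1 API is assumed.

```python
# Sketch of the TYPE_CHECKING workaround used by PostBody above (pydantic v1).
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel, StrictStr, constr


class ExamplePasswordBody(BaseModel):
    if TYPE_CHECKING:
        # mypy chokes on constr() in annotations, so give it a plain StrictStr.
        new_password: Optional[StrictStr] = None
    else:
        # At runtime the constrained type is what actually gets enforced.
        new_password: Optional[constr(max_length=512, strict=True)] = None


ExamplePasswordBody(new_password="s3cret")  # ok
# ExamplePasswordBody(new_password="x" * 513)  # would raise pydantic.ValidationError
```
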
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 6fab102437..ed6ce78d47 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
+
+ # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+ # removed from each device. If it is enabled, then the field name will
+ # be replaced by the unstable identifier.
+ #
+ # When MSC3852 is accepted, this block of code can just be removed to
+ # expose "last_seen_user_agent" to clients.
+ for device in devices:
+ last_seen_user_agent = device["last_seen_user_agent"]
+ del device["last_seen_user_agent"]
+ if self._msc3852_enabled:
+ device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
return 200, {"devices": devices}
@@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(
self, request: SynapseRequest, device_id: str
@@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet):
)
if device is None:
raise NotFoundError("No device found")
+
+ # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+ # removed from the device. If it is enabled, then the field name will
+ # be replaced by the unstable identifier.
+ #
+ # When MSC3852 is accepted, this block of code can just be removed to
+ # expose "last_seen_user_agent" to clients.
+ last_seen_user_agent = device["last_seen_user_agent"]
+ del device["last_seen_user_agent"]
+ if self._msc3852_enabled:
+ device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
return 200, device
@interactive_auth_handler
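
The gating described in the comments above reduces to a small remove/rename step on each device dict. A standalone sketch, assuming only a boolean flag and a plain dict; the helper name is made up for illustration:

```python
# Sketch of the MSC3852 field-gating behaviour (not the actual servlet code):
# the stable name is always stripped, and re-added under the unstable
# identifier only when the experimental flag is on.
from typing import Any, Dict


def gate_last_seen_user_agent(device: Dict[str, Any], msc3852_enabled: bool) -> None:
    """Mutate `device` in place, hiding the stable field behind the unstable name."""
    last_seen_user_agent = device.pop("last_seen_user_agent", None)
    if msc3852_enabled and last_seen_user_agent is not None:
        device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent


device = {"device_id": "ABCDEF", "last_seen_user_agent": "Mozilla/5.0"}
gate_last_seen_user_agent(device, msc3852_enabled=True)
assert "last_seen_user_agent" not in device
assert device["org.matrix.msc3852.last_seen_user_agent"] == "Mozilla/5.0"
```
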
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index eb1b85721f..a395694fa5 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
+from synapse.logging.opentracing import log_kv, set_tag
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,6 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
- @trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
@@ -208,7 +207,9 @@ class KeyChangesServlet(RestServlet):
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
- set_tag("to", parse_string(request, "to"))
+ #
+ # XXX This does not enforce that "to" is passed.
+ set_tag("to", str(parse_string(request, "to")))
from_token = await StreamToken.from_string(self.store, from_token_string)
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
new file mode 100644
index 0000000000..3150602997
--- /dev/null
+++ b/synapse/rest/client/models.py
@@ -0,0 +1,69 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Dict, Optional
+
+from pydantic import Extra, StrictInt, StrictStr, constr, validator
+
+from synapse.rest.models import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+ """
+ Data used during user-interactive authentication.
+
+ (The name "Authentication Data" is taken directly from the spec.)
+
+ Additional keys will be present, depending on the `type` field. Use `.dict()` to
+ access them.
+ """
+
+ class Config:
+ extra = Extra.allow
+
+ session: Optional[StrictStr] = None
+ type: Optional[StrictStr] = None
+
+
+class EmailRequestTokenBody(RequestBodyModel):
+ if TYPE_CHECKING:
+ client_secret: StrictStr
+ else:
+ # See also assert_valid_client_secret()
+ client_secret: constr(
+ regex="[0-9a-zA-Z.=_-]", # noqa: F722
+ min_length=0,
+ max_length=255,
+ strict=True,
+ )
+ email: StrictStr
+ id_server: Optional[StrictStr]
+ id_access_token: Optional[StrictStr]
+ next_link: Optional[StrictStr]
+ send_attempt: StrictInt
+
+ @validator("id_access_token", always=True)
+ def token_required_for_identity_server(
+ cls, token: Optional[str], values: Dict[str, object]
+ ) -> Optional[str]:
+ if values.get("id_server") is not None and token is None:
+ raise ValueError("id_access_token is required if an id_server is supplied.")
+ return token
+
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database. This allows the user to reset their password without having to
+ # know the exact spelling (e.g. upper and lower case) of the address in the database.
+ # Without this, an email stored in the database as "foo@bar.com" would cause
+ # user requests for "FOO@bar.com" to raise a Not Found error.
+ _email_validator = validator("email", allow_reuse=True)(validate_email)
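
For reference, here is how the `id_access_token` validator above behaves in isolation. This is a stripped-down illustration using the pydantic v1 API; the model name is hypothetical:

```python
from typing import Dict, Optional

from pydantic import BaseModel, StrictStr, ValidationError, validator


class ExampleTokenBody(BaseModel):
    id_server: Optional[StrictStr]
    id_access_token: Optional[StrictStr]

    @validator("id_access_token", always=True)
    def token_required_for_identity_server(
        cls, token: Optional[str], values: Dict[str, object]
    ) -> Optional[str]:
        # `values` only holds fields declared *before* this one, which is why
        # id_server must be declared first in the model.
        if values.get("id_server") is not None and token is None:
            raise ValueError("id_access_token is required if an id_server is supplied.")
        return token


ExampleTokenBody(id_server=None, id_access_token=None)  # ok: neither supplied
try:
    ExampleTokenBody(id_server="id.example.com", id_access_token=None)
except ValidationError as e:
    print(e)  # id_access_token is required if an id_server is supplied.
```
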
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 24bc7c9095..a73322a6a4 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet):
)
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
- user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ user_id,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ],
)
notif_event_ids = [pa.event_id for pa in push_actions]
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index c16d707909..e69fa0829d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request)
@@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request)
try:
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 8896f2df50..aaad8b233f 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+ self._known_receipt_types = {
+ ReceiptTypes.READ,
+ ReceiptTypes.FULLY_READ,
+ ReceiptTypes.READ_PRIVATE,
+ }
if hs.config.experimental.msc2285_enabled:
- self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+ self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
async def on_POST(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 409bfd43c1..c6108fc5eb 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- self._known_receipt_types = {ReceiptTypes.READ}
+ self._known_receipt_types = {
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.FULLY_READ,
+ }
if hs.config.experimental.msc2285_enabled:
- self._known_receipt_types.update(
- (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
- )
+ self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index a8402cdb3a..20bab20c8f 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -32,7 +32,7 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
@@ -306,7 +306,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
hs.get_clock(),
- FederationRateLimitConfig(
+ FederationRatelimitSettings(
# Time window of 2s
window_size=2000,
# Artificially delay requests if rate > sleep_limit/window_size
@@ -465,9 +465,6 @@ class RegisterRestServlet(RestServlet):
"Appservice token must be provided when using a type of m.login.application_service",
)
- # Verify the AS
- self.auth.get_appservice_by_req(request)
-
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 2f513164cb..3259de4802 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,12 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
+from enum import Enum
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from prometheus_client.core import Histogram
+
from twisted.web.server import Request
from synapse import event_auth
@@ -46,6 +49,7 @@ from synapse.http.servlet import (
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
@@ -61,6 +65,70 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class _RoomSize(Enum):
+ """
+ Enum to differentiate sizes of rooms. This is a pretty good approximation
+ of how hard it will be to get events in the room. We could also look at
+ room "complexity".
+ """
+
+ # This doesn't necessarily mean the room is a DM, just that there is a
+ # DM-sized number of people there.
+ DM_SIZE = "direct_message_size"
+ SMALL = "small"
+ SUBSTANTIAL = "substantial"
+ LARGE = "large"
+
+ @staticmethod
+ def from_member_count(member_count: int) -> "_RoomSize":
+ if member_count <= 2:
+ return _RoomSize.DM_SIZE
+ elif member_count < 100:
+ return _RoomSize.SMALL
+ elif member_count < 1000:
+ return _RoomSize.SUBSTANTIAL
+ else:
+ return _RoomSize.LARGE
+
+
+# This is an extra metric on top of `synapse_http_server_response_time_seconds`
+# which times the same sort of thing but this one allows us to see values
+# greater than 10s. We use a separate dedicated histogram with its own buckets
+# so that we don't increase the cardinality of the general one because it's
+# multiplied across hundreds of servlets.
+messsages_response_timer = Histogram(
+ "synapse_room_message_list_rest_servlet_response_time_seconds",
+ "sec",
+ # We have a label for room size so we can try to see a more realistic
+ # picture of /messages response time for bigger rooms. We don't want the
+ # tiny rooms that can always respond fast skewing our results when we're trying
+ # to optimize the bigger cases.
+ ["room_size"],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ 30.0,
+ 60.0,
+ 80.0,
+ 100.0,
+ 120.0,
+ 150.0,
+ 180.0,
+ "+Inf",
+ ),
+)
+
+
class TransactionRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -165,7 +233,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
msg_handler = self.message_handler
data = await msg_handler.get_room_data(
- user_id=requester.user.to_string(),
+ requester=requester,
room_id=room_id,
event_type=event_type,
state_key=state_key,
@@ -510,7 +578,7 @@ class RoomMemberListRestServlet(RestServlet):
events = await handler.get_state_events(
room_id=room_id,
- user_id=requester.user.to_string(),
+ requester=requester,
at_token=at_token,
state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
)
@@ -556,6 +624,7 @@ class RoomMessageListRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
+ self.clock = hs.get_clock()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -563,6 +632,18 @@ class RoomMessageListRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
+ processing_start_time = self.clock.time_msec()
+ # Fire off and hope that we get a result by the end.
+ #
+ # We're using the mypy type ignore comment because the `@cached`
+ # decorator on `get_number_joined_users_in_room` doesn't play well with
+ # the type system. Maybe in the future, it can use some ParamSpec
+ # wizardry to fix it up.
+ room_member_count_deferred = run_in_background( # type: ignore[call-arg]
+ self.store.get_number_joined_users_in_room,
+ room_id, # type: ignore[arg-type]
+ )
+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
@@ -593,6 +674,12 @@ class RoomMessageListRestServlet(RestServlet):
event_filter=event_filter,
)
+ processing_end_time = self.clock.time_msec()
+ room_member_count = await make_deferred_yieldable(room_member_count_deferred)
+ messsages_response_timer.labels(
+ room_size=_RoomSize.from_member_count(room_member_count)
+ ).observe((processing_end_time - processing_start_time) / 1000)
+
return 200, msgs
@@ -613,8 +700,7 @@ class RoomStateRestServlet(RestServlet):
# Get all the current state for this room
events = await self.message_handler.get_state_events(
room_id=room_id,
- user_id=requester.user.to_string(),
- is_guest=requester.is_guest,
+ requester=requester,
)
return 200, events
@@ -672,7 +758,7 @@ class RoomEventServlet(RestServlet):
== "true"
)
if include_unredacted_content and not await self.auth.is_server_admin(
- requester.user
+ requester
):
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
@@ -1177,9 +1263,7 @@ class TimestampLookupRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await self._auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string()
- )
+ await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
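
The `/messages` instrumentation added above pairs a dedicated histogram with a room-size label so that slow responses in big rooms are visible separately. A self-contained sketch of the same measurement pattern using `prometheus_client`; the metric and helper names here are illustrative, not Synapse's:

```python
import time

from prometheus_client import Histogram

# A dedicated histogram with its own (coarse) buckets, labelled by room size.
example_response_timer = Histogram(
    "example_message_list_response_time_seconds",
    "sec",
    ["room_size"],
    buckets=(0.1, 1.0, 10.0, 60.0, float("inf")),
)


def bucket_room_size(member_count: int) -> str:
    # Mirrors the thresholds used by _RoomSize.from_member_count above.
    if member_count <= 2:
        return "direct_message_size"
    elif member_count < 100:
        return "small"
    elif member_count < 1000:
        return "substantial"
    return "large"


start = time.monotonic()
time.sleep(0.01)  # stand-in for the actual /messages work
example_response_timer.labels(room_size=bucket_room_size(member_count=250)).observe(
    time.monotonic() - start
)
```
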
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 1a8e9a96d4..46a8b03829 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace_with_opname
+from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -43,7 +43,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
- @trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index f4f06563dd..c9a830cbac 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -94,9 +94,10 @@ class VersionsRestServlet(RestServlet):
# Supports the busy presence state described in MSC3026.
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285
+ "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
- # Supports filtering of /publicRooms by room type MSC3827
- "org.matrix.msc3827": self.config.experimental.msc3827_enabled,
+ # Supports filtering of /publicRooms by room type as per MSC3827
+ "org.matrix.msc3827.stable": True,
# Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
new file mode 100644
index 0000000000..ac39cda8e5
--- /dev/null
+++ b/synapse/rest/models.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Extra
+
+
+class RequestBodyModel(BaseModel):
+ """A custom version of Pydantic's BaseModel which
+
+ - ignores unknown fields and
+ - does not allow fields to be overwritten after construction,
+
+ but otherwise uses Pydantic's default behaviour.
+
+ Ignoring unknown fields is a useful default. It means that clients can provide
+ unstable fields not known to the server without the request being refused outright.
+
+ Subclassing in this way is recommended by
+ https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+ """
+
+ class Config:
+ # By default, ignore fields that we don't recognise.
+ extra = Extra.ignore
+ # By default, don't allow fields to be reassigned after parsing.
+ allow_mutation = False
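
A quick illustration of the two `Config` behaviours above, assuming pydantic v1 (`ExampleBody` is hypothetical): unknown fields are dropped rather than rejected, and parsed models are frozen.

```python
from pydantic import BaseModel, Extra


class ExampleBody(BaseModel):
    class Config:
        extra = Extra.ignore
        allow_mutation = False

    device_id: str


body = ExampleBody.parse_obj(
    {"device_id": "ABCDEF", "org.example.unstable_field": "ignored"}
)
assert body.dict() == {"device_id": "ABCDEF"}  # unknown field silently dropped

try:
    body.device_id = "GHIJKL"
except TypeError as e:
    print(e)  # "ExampleBody" is immutable and does not support item assignment
```
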
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 8ecab86ec7..70d054a8f4 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -244,7 +244,7 @@ class ServerNoticesManager:
assert self.server_notices_mxid is not None
notice_user_data_in_room = await self._message_handler.get_room_data(
- self.server_notices_mxid,
+ create_requester(self.server_notices_mxid),
room_id,
EventTypes.Member,
self.server_notices_mxid,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index e3faa52cd6..3047e1b1ad 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
from typing import (
TYPE_CHECKING,
Any,
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
@@ -92,8 +91,11 @@ class _StateCacheEntry:
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
- if state is None and state_group is None:
- raise Exception("Either state or state group must be not None")
+ if state is None and state_group is None and prev_group is None:
+ raise Exception("One of state, state_group or prev_group must be not None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("If prev_group is set so must delta_ids")
# A map from (type, state_key) to event_id.
#
@@ -120,18 +122,48 @@ class _StateCacheEntry:
if self._state is not None:
return self._state
- assert self.state_group is not None
+ if self.state_group is not None:
+ return await state_storage.get_state_ids_for_group(
+ self.state_group, state_filter
+ )
+
+ assert self.prev_group is not None and self.delta_ids is not None
- return await state_storage.get_state_ids_for_group(
- self.state_group, state_filter
+ prev_state = await state_storage.get_state_ids_for_group(
+ self.prev_group, state_filter
)
+ # ChainMap expects MutableMapping, but since we're using it immutably
+ # it's safe to give it immutable maps.
+ return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
+
+ def set_state_group(self, state_group: int) -> None:
+ """Update the state group assigned to this state (e.g. after we've
+ persisted it).
+
+ Note: this will cause the cache entry to drop any stored state.
+ """
+
+ self.state_group = state_group
+
+ # We clear out the state as we no longer need to explicitly keep it in
+ # the `state_cache` (as the store state group cache will do that).
+ self._state = None
+
def __len__(self) -> int:
- # The len should is used to estimate how large this cache entry is, for
- # cache eviction purposes. This is why if `self.state` is None it's fine
- # to return 1.
+ # The len should be used to estimate how large this cache entry is, for
+ # cache eviction purposes. This is why it's fine to return 1 if we're
+ # not storing any state.
+
+ length = 0
+
+ if self._state:
+ length += len(self._state)
- return len(self._state) if self._state else 1
+ if self.delta_ids:
+ length += len(self.delta_ids)
+
+ return length or 1 # Make sure it's not 0.
class StateHandler:
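
The `ChainMap` return above layers the delta on top of the previous group's state without copying either map. A standalone illustration (the event IDs are made up):

```python
from collections import ChainMap

prev_state = {
    ("m.room.name", ""): "$event_a",
    ("m.room.member", "@alice:example.com"): "$event_b",
}
delta_ids = {("m.room.name", ""): "$event_c"}  # the room name changed

# Lookups consult delta_ids first, then fall through to prev_state.
state = ChainMap(delta_ids, prev_state)
assert state[("m.room.name", "")] == "$event_c"  # delta shadows prev
assert state[("m.room.member", "@alice:example.com")] == "$event_b"  # falls through
```
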
@@ -177,11 +209,11 @@ class StateHandler:
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, state_filter)
- async def get_current_users_in_room(
+ async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
"""
- Get the users who are currently in a room.
+ Get the user IDs who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -192,15 +224,15 @@ class StateHandler:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Dictionary of user IDs to their profileinfo.
+ Set of user IDs in the room.
"""
assert latest_event_ids is not None
- logger.debug("calling resolve_state_groups from get_current_users_in_room")
+ logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
- return await self.store.get_joined_users_from_state(room_id, state, entry)
+ return await self.store.get_joined_user_ids_from_state(room_id, state, entry)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -222,7 +254,7 @@ class StateHandler:
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
- partial_state: bool = False,
+ partial_state: Optional[bool] = None,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -237,10 +269,18 @@ class StateHandler:
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
- partial_state: True if `state_ids_before_event` is partial and omits
- non-critical membership events
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
Returns:
The event context.
+
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
@@ -265,12 +305,14 @@ class StateHandler:
)
)
+ # the partial_state flag must be provided
+ assert partial_state is not None
else:
# otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case:
# we work it out dynamically
- assert not partial_state
+ assert partial_state is None
# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
@@ -280,13 +322,13 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
- if any(incomplete_prev_events.values()):
+ partial_state = any(incomplete_prev_events.values())
+ if partial_state:
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
- partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
@@ -320,7 +362,7 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
)
- entry.state_group = state_group_before_event
+ entry.set_state_group(state_group_before_event)
else:
state_group_before_event = entry.state_group
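The tri-state `partial_state` contract above is easy to get wrong from the call site. A standalone sketch of the decision rules (hypothetical helper name; the real logic lives inline in `compute_event_context`):

```python
from typing import Optional

def resolve_partial_state(
    partial_state: Optional[bool],
    state_ids_provided: bool,
    any_prev_event_partial: bool,
) -> bool:
    """Collapse the tri-state `partial_state` flag down to a boolean."""
    if state_ids_provided:
        # The caller supplied the state, so it must also say whether that
        # state is partial: the flag may not be left as None.
        assert partial_state is not None
        return partial_state
    # Otherwise the flag must not be set explicitly; we derive it from
    # whether any prev event still has partial state.
    assert partial_state is None
    return any_prev_event_partial
```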
@@ -393,6 +435,10 @@ class StateHandler:
Returns:
The resolved state
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -747,7 +793,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(state=new_state, state_group=sg)
+ return _StateCacheEntry(state=None, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +815,14 @@ def _make_state_cache_entry(
prev_group = old_group
delta_ids = n_delta_ids
- return _StateCacheEntry(
- state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
- )
+ if prev_group is not None:
+ # If we have a prev group and deltas then we can drop the new state from
+ # the cache (to reduce memory usage).
+ return _StateCacheEntry(
+ state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+ )
+ else:
+ return _StateCacheEntry(state=new_state, state_group=None)
@attr.s(slots=True, auto_attribs=True)
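The delta-based cache entries above lean on `ChainMap` to present `delta_ids` layered over the previous group's state without copying either map. A self-contained illustration (toy data, not Synapse's real classes):

```python
from collections import ChainMap

# State of the previous group, as would be returned by
# get_state_ids_for_group(prev_group, ...).
prev_state = {
    ("m.room.member", "@alice:example.com"): "$join_event",
    ("m.room.topic", ""): "$old_topic_event",
}
# Deltas from prev_group to this entry's state.
delta_ids = {("m.room.topic", ""): "$new_topic_event"}

# Lookups consult delta_ids first and fall back to prev_state, so the
# combined view reflects the deltas without materialising a merged map.
state = ChainMap(delta_ids, prev_state)
assert state[("m.room.topic", "")] == "$new_topic_event"
assert state[("m.room.member", "@alice:example.com")] == "$join_event"
```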
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7db032203b..cf3045f82e 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -434,7 +434,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -445,7 +445,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: Event to add to the graph
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
"""
state = [event_id]
@@ -455,7 +455,7 @@ async def _add_event_and_auth_chain_to_graph(
event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
- if aid in auth_diff:
+ if aid in full_conflicted_set:
if aid not in graph:
state.append(aid)
@@ -468,7 +468,7 @@ async def _reverse_topological_power_sort(
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
@@ -479,7 +479,7 @@ async def _reverse_topological_power_sort(
event_ids: The events to sort
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
Returns:
The sorted list
@@ -488,7 +488,7 @@ async def _reverse_topological_power_sort(
graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
- graph, room_id, event_id, event_map, state_res_store, auth_diff
+ graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
)
# We await occasionally when we're working with large data sets to
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index 9e6daf38ac..40510889ac 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -3,7 +3,8 @@
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<title> Login </title>
- <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="style.css">
<script src="js/jquery-3.4.1.min.js"></script>
<script src="js/login.js"></script>
diff --git a/synapse/static/client/register/index.html b/synapse/static/client/register/index.html
index 140653574d..27bbd76f51 100644
--- a/synapse/static/client/register/index.html
+++ b/synapse/static/client/register/index.html
@@ -2,7 +2,8 @@
<html>
<head>
<title> Registration </title>
-<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="style.css">
<script src="js/jquery-3.4.1.min.js"></script>
<script src="https://www.recaptcha.net/recaptcha/api/js/recaptcha_ajax.js"></script>
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2f8310388..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
)
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
+ # There's no easy way of invalidating this cache for just the users
+ # that have changed, so we just clear the entire thing.
+ self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
for user_id in members_changed:
self._attempt_to_invalidate_cache(
"get_user_in_room_with_profile", (room_id, user_id)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+ SynapseTags,
+ active_span,
+ set_tag,
+ start_active_span_follows_from,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue.append(end_item)
# also add our active opentracing span to the item so that we get a link back
- span = opentracing.active_span()
+ span = active_span()
if span:
end_item.parent_opentracing_span_contexts.append(span.context)
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
res = await make_deferred_yieldable(end_item.deferred.observe())
# add another opentracing span which links to the persist trace.
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
f"Found an unexpected task type in event persistence queue: {task}"
)
- @opentracing.trace
+ @trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ event_ids: List[str] = []
partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
+ event_ids.append(event.event_id)
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
self.main_store.get_room_max_token(),
)
- @opentracing.trace
+ @trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index e08f956e6e..f9ffd0e29e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,6 +29,8 @@ from typing import (
from synapse.api.constants import EventTypes
from synapse.events import EventBase
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -82,13 +84,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
+ self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
+ await_full_state: if `True`, will block if we do not yet have complete
+ state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +104,9 @@ class StateStorageController:
if not event_ids:
return {}
- event_to_groups = await self.get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -175,6 +181,7 @@ class StateStorageController:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+ @trace
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
@@ -221,10 +228,13 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids}
+ @trace
+ @tag_args
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -233,6 +243,9 @@ class StateStorageController:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -241,8 +254,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
+ # Full state is not required if the state filter is restrictive enough.
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -283,8 +300,12 @@ class StateStorageController:
)
return state_map[event_id]
+ @trace
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -292,6 +313,9 @@ class StateStorageController:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -301,7 +325,9 @@ class StateStorageController:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -323,6 +349,8 @@ class StateStorageController:
groups, state_filter or StateFilter.all()
)
+ @trace
+ @tag_args
async def get_state_group_for_events(
self,
event_ids: Collection[str],
@@ -334,6 +362,10 @@ class StateStorageController:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
@@ -460,6 +492,7 @@ class StateStorageController:
prev_stream_id, max_stream_id
)
+ @trace
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
@@ -493,3 +526,15 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_users_in_room_with_profiles(
+ self, room_id: str
+ ) -> Dict[str, ProfileInfo]:
+ """
+ Get the current users in the room with their profiles.
+ If the room is currently partial-stated, this will block until the room has
+ full state.
+ """
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_users_in_room_with_profiles(room_id)
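A hedged usage sketch of the new `await_full_state` parameter: a caller whose filter partial state can already satisfy (the create event is always present) can opt out of blocking on a partial-state room. The imported names are Synapse's, but the caller itself is hypothetical:

```python
from typing import Optional

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter

async def get_create_event_id(controller, event_id: str) -> Optional[str]:
    # Restrictive filter plus await_full_state=False: don't block waiting
    # for the room to be fully un-partial-stated.
    state_ids = await controller.get_state_ids_for_event(
        event_id,
        state_filter=StateFilter.from_types([(EventTypes.Create, "")]),
        await_full_state=False,
    )
    return state_ids.get((EventTypes.Create, ""))
```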
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ea672ff89e..b394a6658b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -39,7 +39,7 @@ from typing import (
)
import attr
-from prometheus_client import Histogram
+from prometheus_client import Counter, Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
@@ -76,7 +76,8 @@ perf_logger = logging.getLogger("synapse.storage.TIME")
sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+sql_txn_count = Counter("synapse_storage_transaction_time_count", "sec", ["desc"])
+sql_txn_duration = Counter("synapse_storage_transaction_time_sum", "sec", ["desc"])
# Unique indexes which have been added in background updates. Maps from table name
@@ -795,7 +796,8 @@ class DatabasePool:
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, duration)
- sql_txn_timer.labels(desc).observe(duration)
+ sql_txn_count.labels(desc).inc(1)
+ sql_txn_duration.labels(desc).inc(duration)
async def runInteraction(
self,
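Replacing the histogram with two plain counters keeps the `_count`/`_sum` style series that dashboards actually use while dropping the per-bucket series. Mean duration is then derived at query time. A sketch under hypothetical metric names (exact exposed names depend on the prometheus_client version, which may append `_total` to counters):

```python
from prometheus_client import Counter

txn_count = Counter("example_txn_time_count", "number of transactions", ["desc"])
txn_duration = Counter("example_txn_time_sum", "total seconds spent", ["desc"])

def record(desc: str, duration: float) -> None:
    # One observation: bump the call count and the running total of seconds.
    txn_count.labels(desc).inc(1)
    txn_duration.labels(desc).inc(duration)

# Mean duration over a window is recoverable at query time, e.g. in PromQL:
#   rate(example_txn_time_sum[1m]) / rate(example_txn_time_count[1m])
```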
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3d31d3737..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -149,31 +149,6 @@ class DataStore(
],
)
- self._cache_id_gen: Optional[MultiWriterIdGenerator]
- if isinstance(self.database_engine, PostgresEngine):
- # We set the `writers` to an empty list here as we don't care about
- # missing updates over restarts, as we'll not have anything in our
- # caches to invalidate. (This reduces the amount of writes to the DB
- # that happen).
- self._cache_id_gen = MultiWriterIdGenerator(
- db_conn,
- database,
- stream_name="caches",
- instance_name=hs.get_instance_name(),
- tables=[
- (
- "cache_invalidation_stream_by_instance",
- "instance_name",
- "stream_id",
- )
- ],
- sequence_name="cache_invalidation_stream_seq",
- writers=[],
- )
-
- else:
- self._cache_id_gen = None
-
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_enabled_for_user, (user_id,)
- )
# This user might be contained in the ignored_by cache for other users,
# so we have to invalidate it all.
self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2367ddeea3..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
psql_only=True, # The table is only on postgres DBs.
)
+ self._cache_id_gen: Optional[MultiWriterIdGenerator]
+ if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+ # caches to invalidate. (This reduces the amount of writes to the DB
+ # that happen).
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
+ sequence_name="cache_invalidation_stream_seq",
+ writers=[],
+ )
+
+ else:
+ self._cache_id_gen = None
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 422e0e65ca..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None
)
- set_tag("last_deleted_stream_id", last_deleted_stream_id)
+ set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a6ed332aa..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
- set_tag("in_cache", results)
- set_tag("not_in_cache", user_ids_not_in_cache)
+ set_tag("in_cache", str(results))
+ set_tag("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results
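The `str(...)` wrapping in these hunks follows from OpenTracing tags being intended for primitive values; handing the tracer rich objects (dicts, sets) is backend-dependent. The convention, sketched with hypothetical values:

```python
from synapse.logging.opentracing import set_tag

results = {"@alice:example.com": {"DEVICE1": {}}}  # illustrative only
# Coerce to a string up front so the tracer never sees arbitrary objects.
set_tag("in_cache", str(results))
```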
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 60f622ad71..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
- set_tag("query_list", query_list)
+ set_tag("query_list", str(query_list))
if not query_list:
return {}
@@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
- set_tag("new_keys", new_keys)
+ set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
- set_tag("device_keys", device_keys)
+ set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
+ @trace
+ @tag_args
async def get_auth_chain_ids(
self,
room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
+ @trace
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse()
return event_results
+ @trace
+ @tag_args
async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given event as a prev event
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ @trace
async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
await self.db_pool.simple_upsert(
table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index dd2627037c..eabf9c9739 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,14 +12,85 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+ 1. Sending out push to a user's device; and
+ 2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in `event_push_actions`
+table for a given room/user, but in practice this is *very* heavyweight when
+there are a large number of notifications (due to e.g. the user never reading a
+room). Plus, keeping all push actions indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We also need to handle a user sending a read receipt to the room. Again, this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary` (If the receipt happened *after* S then we
+simply clear the `event_push_summary`.)
+
+Note that it's possible that, if the read receipt is for an old event, the relevant
+`event_push_actions` rows will have been rotated and we get the wrong count
+(it'll be too low). We accept this as a rare edge case that is unlikely to
+impact the user much (since the vast majority of read receipts will be for the
+latest event).
+
+The last complication is handling the race where we request the notification
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by including in `event_push_summary` the read
+receipt we used when updating `event_push_summary`, and every time we query the
+table we check if that matches the most recent read receipt in the room. If it
+does, we continue as above; if not, we simply query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -93,7 +164,9 @@ class NotifCounts:
highlight_count: int = 0
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+ actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -166,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
- for a given user in a given room after the given read receipt.
+ for a given user in a given room after their latest read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
@@ -177,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A dict containing the counts mentioned earlier in this docstring,
- respectively under the keys "notify_count", "highlight_count" and
- "unread_count".
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -194,20 +266,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
) -> NotifCounts:
+ # Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
txn,
user_id,
room_id,
- receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ receipt_types=(
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
)
- stream_ordering = None
if result:
_, stream_ordering = result
- if stream_ordering is None:
- # Either last_read_event_id is None, or it's an event we don't have (e.g.
- # because it's been purged), in which case retrieve the stream ordering for
+ else:
+ # If the user has no receipts in the room, retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -224,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
def _get_unread_counts_by_pos_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ receipt_stream_ordering: int,
) -> NotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ receipt_stream_ordering: The stream ordering of the user's latest
+ receipt in the room. If there are no receipts, the stream ordering
+ of the user's join event.
+
+ Returns
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
counts = NotifCounts()
@@ -255,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
OR last_receipt_stream_ordering = ?
)
""",
- (room_id, user_id, stream_ordering, stream_ordering),
+ (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
@@ -265,7 +356,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
counts.notify_count += row[1]
counts.unread_count += row[2]
- # Next we need to count highlights, which aren't summarized
+ # Next we need to count highlights, which aren't summarised
sql = """
SELECT COUNT(*) FROM event_push_actions
WHERE user_id = ?
@@ -273,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND stream_ordering > ?
AND highlight = 1
"""
- txn.execute(sql, (user_id, room_id, stream_ordering))
+ txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
# Finally we need to count push actions that aren't included in the
- # summary returned above, e.g. recent events that haven't been
- # summarized yet, or the summary is empty due to a recent read receipt.
- stream_ordering = max(stream_ordering, summary_stream_ordering)
+ # summary returned above. This might be due to recent events that haven't
+ # been summarised yet or the summary is out of date due to a recent read
+ # receipt.
+ start_unread_stream_ordering = max(
+ receipt_stream_ordering, summary_stream_ordering
+ )
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering
+ txn, room_id, user_id, start_unread_stream_ordering
)
counts.notify_count += notify_count
@@ -304,6 +398,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Does not consult `event_push_summary` table, which may include push
actions that have been deleted from `event_push_actions` table.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ stream_ordering: The (exclusive) minimum stream ordering to consider.
+ max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+ If this is not given, then no maximum is applied.
+
+ Returns:
+ A tuple of the notif count and unread count in the given range.
"""
# If there have been no events in the room since the stream ordering,
@@ -376,6 +481,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
+
# find rooms that have a read receipt in them and return the next
# push actions
def get_after_receipt(
@@ -383,28 +489,41 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next
# push actions
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
+
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight
+ FROM (
+ SELECT room_id,
+ MAX(stream_ordering) as stream_ordering
+ FROM events
+ INNER JOIN receipts_linearized USING (room_id, event_id)
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ ) AS rl,
+ event_push_actions AS ep
+ WHERE
+ ep.room_id = rl.room_id
+ AND ep.stream_ordering > rl.stream_ordering
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering ASC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
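`make_in_list_sql_clause` is what lets the receipt types be parameterised here: it returns an engine-appropriate fragment (an `IN (...)` list, or `= ANY(?)` on postgres) together with its bind arguments. Usage in the shape of the hunks above (`database_engine`, `txn` and `user_id` are assumed from the surrounding store code):

```python
from synapse.storage._base import make_in_list_sql_clause

clause, args = make_in_list_sql_clause(
    database_engine, "receipt_type", ("m.read", "m.read.private")
)
sql = f"""
    SELECT room_id FROM receipts_linearized
    WHERE {clause} AND user_id = ?
    GROUP BY room_id
"""
args.append(user_id)
txn.execute(sql, args)
```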
@@ -418,24 +537,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight
+ FROM event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id NOT IN (
+ SELECT room_id FROM receipts_linearized
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ )
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering ASC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@@ -485,34 +616,47 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
+
# find rooms that have a read receipt in them and return the most recent
# push actions
def get_after_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight, e.received_ts
+ FROM (
+ SELECT room_id,
+ MAX(stream_ordering) as stream_ordering
+ FROM events
+ INNER JOIN receipts_linearized USING (room_id, event_id)
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ ) AS rl,
+ event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id = rl.room_id
+ AND ep.stream_ordering > rl.stream_ordering
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering DESC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@@ -526,24 +670,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight, e.received_ts
+ FROM event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id NOT IN (
+ SELECT room_id FROM receipts_linearized
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ )
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering DESC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@@ -606,7 +762,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def add_push_actions_to_staging(
self,
event_id: str,
- user_id_actions: Dict[str, List[Union[dict, str]]],
+ user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -623,7 +779,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
- user_id: str, actions: List[Union[dict, str]]
+ user_id: str, actions: Collection[Union[Mapping, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
@@ -769,12 +925,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# [10, <none>, 20], we should treat this as being equivalent to
# [10, 10, 20].
#
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering <= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering <= ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
while range_end - range_start > 0:
middle = (range_end + range_start) // 2
@@ -802,14 +958,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self, stream_ordering: int
) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ? AND notif = 1"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT e.received_ts
+ FROM event_push_actions AS ep
+ JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+ WHERE ep.stream_ordering > ? AND notif = 1
+ ORDER BY ep.stream_ordering ASC
+ LIMIT 1
+ """
txn.execute(sql, (stream_ordering,))
return cast(Optional[Tuple[int]], txn.fetchone())
@@ -858,10 +1014,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Any push actions which predate the user's most recent read receipt are
now redundant, so we can remove them from `event_push_actions` and
update `event_push_summary`.
+
+ Returns true if all new receipts have been processed.
"""
limit = 100
+ # The (inclusive) receipt stream ID that was previously processed..
min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_last_receipt_stream_id",
@@ -871,6 +1030,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
max_receipts_stream_id = self._receipts_id_gen.get_current_token()
+ # The (inclusive) event stream ordering that was previously summarised.
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
sql = """
SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
FROM receipts_linearized AS r
@@ -895,13 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
rows = txn.fetchall()
- old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_push_summary_stream_ordering",
- keyvalues={},
- retcol="stream_ordering",
- )
-
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
for _, room_id, user_id, stream_ordering in rows:
@@ -920,10 +1080,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering),
)
+ # Fetch the notification counts between the stream ordering of the
+ # latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
)
+ # Replace the previous summary with the new counts.
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
@@ -956,10 +1119,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return len(rows) < limit
def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
- """Archives older notifications into event_push_summary. Returns whether
- the archiving process has caught up or not.
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Returns whether the archiving process has caught up or not.
"""
+ # The (inclusive) event stream ordering that was previously summarised.
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -974,7 +1139,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """,
+ """,
(old_rotate_stream_ordering, self._rotate_count),
)
stream_row = txn.fetchone()
@@ -993,19 +1158,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
- self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+ self._rotate_notifs_before_txn(
+ txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+ )
return caught_up
def _rotate_notifs_before_txn(
- self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ old_rotate_stream_ordering: int,
+ rotate_to_stream_ordering: int,
) -> None:
- old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_push_summary_stream_ordering",
- keyvalues={},
- retcol="stream_ordering",
- )
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+ rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+ table.
+
+ Args:
+ txn: The database transaction.
+ old_rotate_stream_ordering: The previous maximum event stream ordering.
+ rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+ """
# Calculate the new counts that should be upserted into event_push_summary
sql = """
@@ -1090,12 +1265,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(rotate_to_stream_ordering,),
)
- async def _remove_old_push_actions_that_have_rotated(
- self,
- ) -> None:
- """Clear out old push actions that have been summarized."""
+ async def _remove_old_push_actions_that_have_rotated(self) -> None:
+ """Clear out old push actions that have been summarised."""
- # We want to clear out anything that older than a day that *has* already
+ # We want to clear out anything that is older than a day that *has* already
# been rotated.
rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
table="event_push_summary_stream_ordering",
@@ -1119,7 +1292,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering <= ? AND highlight = 0
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """,
+ """,
(
max_stream_ordering_to_delete,
batch_size,
@@ -1215,16 +1388,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# NB. This assumes event_ids are globally unique since
# it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
+ sql = """
+ SELECT epa.event_id, epa.room_id,
+ epa.stream_ordering, epa.topological_ordering,
+ epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+ FROM event_push_actions epa, events e
+ WHERE epa.event_id = e.event_id
+ AND epa.user_id = ? %s
+ AND epa.notif = 1
+ ORDER BY epa.stream_ordering DESC
+ LIMIT ?
+ """ % (
+ before_clause,
)
txn.execute(sql, args)
return cast(
@@ -1247,7 +1422,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
]
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
for action in actions:
if not isinstance(action, dict):
continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f600f1190..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ @trace
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1490,7 +1492,7 @@ class PersistEventsStore:
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
event.get_state_key(),
- context.rejected or None,
+ context.rejected,
)
for event, context in events_and_contexts
),
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5914a35420..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.logging.opentracing import start_active_span, tag_args, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
+ @trace
+ @tag_args
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -600,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
map from event id to result
"""
- event_entry_map = await self._get_events_from_cache(
+ # Shortcut: check if we have any events in the *in-memory* cache. This function
+ # may be called repeatedly for the same event, so for performance reasons we
+ # don't reach out to any external cache at this point. The external cache is
+ # checked later on in the `get_missing_events_from_cache_or_db` function below.
+ event_entry_map = self._get_events_from_local_cache(
event_ids,
)
@@ -632,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
if missing_events_ids:
- async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+ async def get_missing_events_from_cache_or_db() -> Dict[
+ str, EventCacheEntry
+ ]:
"""Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow
@@ -657,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
+ missing_events = {}
try:
- missing_events = await self._get_events_from_db(
+ # Try to fetch from any external cache. We already checked the
+ # in-memory cache above.
+ missing_events = await self._get_events_from_external_cache(
missing_events_ids,
)
+ # Now actually fetch any remaining events from the DB
+ db_missing_events = await self._get_events_from_db(
+ missing_events_ids - missing_events.keys(),
+ )
+ missing_events.update(db_missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
@@ -679,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
- get_missing_events_from_db()
+ get_missing_events_from_cache_or_db()
)
event_entry_map.update(missing_events)
@@ -754,7 +771,54 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
- """Fetch events from the caches.
+ """Fetch events from the caches, both in memory and any external.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = self._get_events_from_local_cache(
+ events, update_metrics=update_metrics
+ )
+
+ missing_event_ids = (e for e in events if e not in event_map)
+ event_map.update(
+ await self._get_events_from_external_cache(
+ events=missing_event_ids,
+ update_metrics=update_metrics,
+ )
+ )
+
+ return event_map
+
+ async def _get_events_from_external_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from any configured external cache.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = await self._get_event_cache.get_external(
+ (event_id,), None, update_metrics=update_metrics
+ )
+ if ret:
+ event_map[event_id] = ret
+
+ return event_map
+
+ def _get_events_from_local_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from the local, in memory, caches.
May return rejected events.
@@ -766,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
# First check if it's in the event cache
- ret = await self._get_event_cache.get(
+ ret = self._get_event_cache.get_local(
(event_id,), None, update_metrics=update_metrics
)
if ret:
@@ -788,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
# We add the entry back into the cache as we want to keep
# recently queried events in the cache.
- await self._get_event_cache.set((event_id,), cache_entry)
+ self._get_event_cache.set_local((event_id,), cache_entry)
return event_map
@@ -1029,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
"""
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
- events_to_fetch = event_ids
- while events_to_fetch:
- row_map = await self._enqueue_events(events_to_fetch)
+ async def _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch: Collection[str],
+ ) -> Collection[str]:
+ """
+ Fetch all of the given event_ids and return any associated redaction event_ids
+ that we still need to fetch in the next iteration.
+ """
+ row_map = await self._enqueue_events(event_ids_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids: Set[str] = set()
- for event_id in events_to_fetch:
+ for event_id in event_ids_to_fetch:
row = row_map.get(event_id)
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
- if events_to_fetch:
- logger.debug("Also fetching redaction events %s", events_to_fetch)
+ event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+ return event_ids_to_fetch
+
+ # Grab the initial list of events requested
+ event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids
+ )
+ # Then go and recursively find all of the associated redactions
+ with start_active_span("recursively fetching redactions"):
+ while event_ids_to_fetch:
+ logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+ event_ids_to_fetch = (
+ await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch
+ )
+ )
# build a map from event_id to EventBase
event_map: Dict[str, EventBase] = {}
@@ -1363,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
+ @trace
+ @tag_args
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
@@ -2110,14 +2195,92 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str
) -> List[str]:
+ # we want to work through the events from oldest to newest, so
+ # we only want events whose prev_events do *not* have partial state - hence
+            # the 'NOT EXISTS' clause in the query below.
+ #
+ # This is necessary because ordering by stream ordering isn't quite enough
+ # to ensure that we work from oldest to newest event (in particular,
+ # if an event is initially persisted as an outlier and later de-outliered,
+ # it can end up with a lower stream_ordering than its prev_events).
+ #
+            # Typically this means we'll only return one event per batch, but
+            # there is not much we can do about that.
+ #
+ # See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute(
"""
SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id)
- WHERE pse.room_id = ?
+ WHERE pse.room_id = ? AND
+ NOT EXISTS(
+ SELECT 1 FROM event_edges AS ee
+ JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+ WHERE ee.event_id=pse.event_id
+ )
ORDER BY events.stream_ordering
LIMIT 100
""",
(room_id,),
)
return [row[0] for row in txn]
+
+ def mark_event_rejected_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ rejection_reason: Optional[str],
+ ) -> None:
+ """Mark an event that was previously accepted as rejected, or vice versa
+
+ This can happen, for example, when resyncing state during a faster join.
+
+ Args:
+ txn:
+ event_id: ID of event to update
+ rejection_reason: reason it has been rejected, or None if it is now accepted
+ """
+ if rejection_reason is None:
+ logger.info(
+ "Marking previously-processed event %s as accepted",
+ event_id,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ "rejections",
+ keyvalues={"event_id": event_id},
+ )
+ else:
+ logger.info(
+ "Marking previously-processed event %s as rejected(%s)",
+ event_id,
+ rejection_reason,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": event_id},
+ values={
+ "reason": rejection_reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
+ self.db_pool.simple_update_txn(
+ txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"rejection_reason": rejection_reason},
+ )
+
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+ # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+ # call '_send_invalidation_to_replication', but we actually need the other
+ # end to call _invalidate_local_get_event_cache() rather than (just)
+ # _get_event_cache.invalidate().
+ #
+ # One solution might be to (somehow) get the workers to call
+ # _invalidate_caches_for_event() (though that will invalidate more than
+ # strictly necessary).
+ #
+ # https://github.com/matrix-org/synapse/issues/12994
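
The tiered lookup introduced above (local in-memory cache, then any external cache, then the database) is the key structural change in this file. Below is a minimal, self-contained sketch of the pattern; `local_cache`, `external_cache` and `fetch_from_db` are hypothetical stand-ins, not Synapse's real classes:

    from typing import Collection, Dict

    # Hypothetical stand-ins for the three tiers; not Synapse's actual APIs.
    local_cache: Dict[str, str] = {}
    external_cache: Dict[str, str] = {}
    database: Dict[str, str] = {"$event1": "one", "$event2": "two"}

    def fetch_from_db(event_ids: Collection[str]) -> Dict[str, str]:
        return {eid: database[eid] for eid in event_ids if eid in database}

    def get_events(event_ids: Collection[str]) -> Dict[str, str]:
        # Tier 1: the local, in-memory cache (cheap; always checked first).
        found = {eid: local_cache[eid] for eid in event_ids if eid in local_cache}
        missing = {eid for eid in event_ids if eid not in found}

        # Tier 2: any external cache (e.g. Redis), consulted only for misses.
        for eid in set(missing):
            if eid in external_cache:
                found[eid] = external_cache[eid]
                missing.discard(eid)

        # Tier 3: the database, for whatever is still missing.
        db_events = fetch_from_db(missing)
        found.update(db_events)

        # Keep recently fetched events warm for repeated calls.
        local_cache.update(db_events)
        return found
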
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
- ]
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ ruleslist = [
+ PushRule(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=db_to_json(rawrule["conditions"]),
+ actions=db_to_json(rawrule["actions"]),
+ )
+ for rawrule in rawrules
+ ]
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ push_rules = compile_push_rules(ruleslist)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
+ filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
- return rules
+ return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
- @cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -234,20 +215,19 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
return results
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
- )
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
@@ -262,6 +242,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
)
for row in rows:
enabled = bool(row["enabled"])
@@ -345,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -808,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -817,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -827,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules:
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
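
With `PushRule` now a shared, cached object, `copy_push_rule_from_room_to_room` must not mutate conditions in place; it builds a fresh list and copies only the condition it changes. A rough sketch of that copy-on-write step, using a simplified stand-in type rather than Synapse's real `PushRule`:

    from typing import Any, Dict, List, Sequence

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class MiniPushRule:  # simplified stand-in, not Synapse's PushRule
        rule_id: str
        conditions: Sequence[Dict[str, Any]]

    def copy_rule_to_room(rule: MiniPushRule, new_room_id: str) -> MiniPushRule:
        # Never mutate the cached rule: shallow-copy each condition we edit.
        new_conditions: List[Dict[str, Any]] = []
        for condition in rule.conditions:
            if condition.get("key") == "room_id":
                condition = dict(condition)
                condition["pattern"] = new_room_id
            new_conditions.append(condition)

        rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
        return MiniPushRule(
            rule_id=rule_id_scope + "/" + new_room_id,
            conditions=new_conditions,
        )
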
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: The receipt types to fetch.
Returns:
- The latest receipt, if one exists.
+ The event ID and stream ordering of the latest receipt, if one exists.
"""
clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
"""
user_id: str
+ token_id: int
is_guest: bool = False
shadow_banned: bool = False
- token_id: Optional[int] = None
device_id: Optional[str] = None
valid_until_ms: Optional[int] = None
token_owner: str = attr.ib()
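
Moving `token_id` above the defaulted fields makes it a required attribute: attrs insists that mandatory attributes precede any with defaults. A tiny stand-alone illustration of that constraint (illustrative code, not Synapse's):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class TokenInfo:  # illustrative stand-in
        user_id: str
        token_id: int          # mandatory: must precede defaulted fields
        is_guest: bool = False

    # TokenInfo("@alice:test", 7) works; omitting token_id is now a TypeError,
    # whereas declaring token_id *after* is_guest would fail at class definition.
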
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d6d485507b..b7d4baa6bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]:
- if not room_types or not self.config.experimental.msc3827_enabled:
+ if not room_types:
return None, []
else:
# We use None when we want get rooms without a type
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+        # We join on room_stats_state despite not using any columns from it
+        # because the join can influence the number of rows returned;
+        # e.g. a room that doesn't have state (perhaps because it was deleted)
+        # is excluded by the join.
+        # The query returning the total count should be consistent with
+        # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
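
The comment above states the general rule: a paginated admin API should derive its COUNT query and its results query from the same FROM/JOIN/WHERE tail, otherwise the reported total can disagree with the rows actually returned. A sketch of that shape (column names are illustrative, not the exact Synapse queries):

    # Build both queries from one shared tail so the total and the page agree.
    common_tail = """
        FROM event_reports AS er
        JOIN room_stats_state ON room_stats_state.room_id = er.room_id
        WHERE er.room_id = ?
    """

    count_sql = "SELECT COUNT(*) AS total_event_reports" + common_tail
    results_sql = (
        "SELECT er.event_id, er.received_ts"
        + common_tail
        + "ORDER BY er.received_ts DESC LIMIT ? OFFSET ?"
    )
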
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..046ad3a11c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
@@ -281,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -561,7 +566,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +737,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(
- max_entries=500000,
- cache_context=True,
- iterable=True,
- prune_unread_entries=False,
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
)
- async def get_users_who_share_room_with_user(
- self, user_id: str, cache_context: _CacheContext
+ async def _do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+            # user and the rooms of the other users, and then returning the
+            # other users whose rooms overlap.
+ sql = f"""
+ SELECT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+                ) AS b using (room_id)
+            """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for u, in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
) -> Set[str]:
+ """Return the set of users who share a room with the first users"""
+
+ user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
+ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
- room_ids = await self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
+ room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = await self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
return user_who_share_room
@@ -779,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_state(
+ async def get_joined_user_ids_from_state(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -792,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
- return await self._get_joined_users_from_context(
+ return await self._get_joined_user_ids_from_context(
room_id, state_group, state, context=state_entry
)
@cached(num_args=2, iterable=True, max_entries=100000)
- async def _get_joined_users_from_context(
+ async def _get_joined_user_ids_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
event: Optional[EventBase] = None,
context: Optional["_StateCacheEntry"] = None,
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
assert state_group is not None
- users_in_room = {}
+ users_in_room = set()
member_event_ids = [
e_id
for key, e_id in current_state_ids.items()
@@ -823,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get_immediate(
+ prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
+ if prev_res and isinstance(prev_res, set):
+                    # Copy the cached set: it is mutated below (via `discard`
+                    # and `add`), and we must not corrupt the entry in the cache.
+                    users_in_room = set(prev_res)
member_event_ids = [
e_id
for key, e_id in context.delta_ids.items()
@@ -835,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
]
for etype, state_key in context.delta_ids:
if etype == EventTypes.Member:
- users_in_room.pop(state_key, None)
+ users_in_room.discard(state_key)
# We check if we have any of the member event ids in the event cache
# before we ask the DB
@@ -843,7 +899,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
- event_map = await self._get_events_from_cache(
+ event_map = self._get_events_from_local_cache(
member_event_ids, update_metrics=False
)
@@ -852,71 +908,64 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
- users_in_room[ev_entry.event.state_key] = ProfileInfo(
- display_name=ev_entry.event.content.get("displayname", None),
- avatar_url=ev_entry.event.content.get("avatar_url", None),
- )
+ users_in_room.add(ev_entry.event.state_key)
else:
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- event_to_memberships = await self._get_joined_profiles_from_event_ids(
+ event_to_memberships = await self._get_user_ids_from_membership_event_ids(
missing_member_event_ids
)
- users_in_room.update(row for row in event_to_memberships.values() if row)
+ users_in_room.update(
+ user_id for user_id in event_to_memberships.values() if user_id
+ )
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
- users_in_room[event.state_key] = ProfileInfo(
- display_name=event.content.get("displayname", None),
- avatar_url=event.content.get("avatar_url", None),
- )
+ users_in_room.add(event.state_key)
return users_in_room
- @cached(max_entries=10000)
- def _get_joined_profile_from_event_id(
+ @cached(
+ max_entries=10000,
+ # This name matches the old function that has been replaced - the cache name
+ # is kept here to maintain backwards compatibility.
+ name="_get_joined_profile_from_event_id",
+ )
+ def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
+ cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(
+ async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+ ) -> Dict[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
- event and if so return the associated user and profile info.
+ event.
Args:
event_ids: The member event IDs to lookup
Returns:
- Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
- retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
- desc="_get_joined_profiles_from_event_ids",
+ desc="_get_user_ids_from_membership_event_ids",
)
- return {
- row["event_id"]: (
- row["user_id"],
- ProfileInfo(
- avatar_url=row["avatar_url"], display_name=row["display_name"]
- ),
- )
- for row in rows
- }
+ return {row["event_id"]: row["user_id"] for row in rows}
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1075,12 +1124,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
- joined_users = await self.get_joined_users_from_state(
+ joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, state, state_entry
)
cache.hosts_to_joined_users = {}
- for user_id in joined_users:
+ for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
@@ -1159,6 +1208,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+        # `count(*)` always returns an integer. If any rows still exist, it
+        # means someone has not yet forgotten this room.
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its
# predecessor.
if context.rejected:
- assert context.state_group == context.state_group_before_event
+ state_group = context.state_group_before_event
+ else:
+ state_group = context.state_group
self.db_pool.simple_update_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event.event_id},
- updatevalues={"state_group": context.state_group},
+ updatevalues={"state_group": state_group},
)
+ # the event may now be rejected where it was not before, or vice versa,
+ # in which case we need to update the rejected flags.
+ if bool(context.rejected) != (event.rejected_reason is not None):
+ self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
self.db_pool.simple_delete_one_txn(
txn,
table="partial_state_events",
@@ -440,7 +447,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
- context.state_group,
+ state_group,
)
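
The guard `bool(context.rejected) != (event.rejected_reason is not None)` fires only when the resync actually flipped the event's rejection status. A plain-Python sketch of the condition and its four cases:

    def rejection_status_changed(context_rejected: object, rejected_reason: object) -> bool:
        # `context.rejected` is falsy (accepted) or a truthy rejection reason;
        # `event.rejected_reason` is None (accepted) or a reason string.
        return bool(context_rejected) != (rejected_reason is not None)

    assert rejection_status_changed(None, None) is False      # accepted -> accepted
    assert rejection_status_changed("spam", None) is True     # accepted -> rejected
    assert rejection_status_changed(None, "spam") is True     # rejected -> accepted
    assert rejection_status_changed("spam", "spam") is False  # rejected -> rejected
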
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2590b52f73..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -58,6 +58,7 @@ from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -1346,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, next_token
+ @trace
async def paginate_room_events(
self,
room_id: str,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index afbc85ad0c..bb64543c1f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- cache_entry = cache.get(group)
+ # If we are asked explicitly for a subset of keys, we only ask for those
+ # from the cache. This ensures that the `DictionaryCache` can make
+ # better decisions about what to cache and what to expire.
+ dict_keys = None
+ if not state_filter.has_wildcards():
+ dict_keys = state_filter.concrete_types()
+
+ cache_entry = cache.get(group, dict_keys=dict_keys)
state_dict_ids = cache_entry.value
if cache_entry.full or state_filter.is_full():
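
Supplying `dict_keys` means a concrete state filter only touches (and therefore only keeps warm) the specific `(event type, state key)` pairs it asked for, instead of pinning the whole state group. A toy illustration of keyed lookups against such a per-key store (stand-in code, not the real `DictionaryCache`):

    from typing import Dict, Optional, Set, Tuple

    StateKey = Tuple[str, str]  # (event_type, state_key)

    # Toy per-key store: (group, state_key) -> event_id. Stand-in only.
    per_key_cache: Dict[Tuple[int, StateKey], str] = {
        (1, ("m.room.member", "@alice:test")): "$member_alice",
        (1, ("m.room.name", "")): "$name",
    }

    def get_partial(group: int, dict_keys: Optional[Set[StateKey]]) -> Dict[StateKey, str]:
        if dict_keys is None:
            # A wildcard filter: would fall back to fetching the full dict.
            return {k[1]: v for k, v in per_key_cache.items() if k[0] == group}
        # Concrete keys: only those entries are consulted (and kept "warm").
        return {
            key: per_key_cache[(group, key)]
            for key in dict_keys
            if (group, key) in per_key_cache
        }
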
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index af3bab2c15..0004d955b4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,15 +539,6 @@ class StateFilter:
is_mine_id: a callable which confirms if a given state_key matches a mxid
of a local user
"""
-
- # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
- # there may be circumstances in which we return a piece of state that, once we
- # resync the state, we discover is invalid. For example: if it turns out that
- # the sender of a piece of state wasn't actually in the room, then clearly that
- # state shouldn't have been returned.
- # We should at least add some tests around this to see what happens.
- # https://github.com/matrix-org/synapse/issues/13006
-
# if we haven't requested membership events, then it depends on the value of
# 'include_others'
if EventTypes.Member not in self.types:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f2..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialStateEventsTracker.await_full_state")
async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialCurrentStateTracker.await_full_state")
async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room.
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 54e0b1a23b..bcd840bd88 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import StreamToken
@@ -69,6 +70,7 @@ class EventSources:
)
return token
+ @trace
async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the current token for a given room to be used to paginate
events.
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import enum
import threading
from typing import (
Callable,
+ Collection,
+ Dict,
Generic,
- Iterable,
MutableMapping,
Optional,
+ Set,
Sized,
+ Tuple,
TypeVar,
Union,
cast,
@@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
-from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
- TreeCache, "MutableMapping[KT, CacheEntry]"
+ TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
- callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
- val.callbacks.update(callbacks)
+ val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
- return val.deferred.observe()
+ return val.deferred(key)
+
+ callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
+ def get_bulk(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+ """Bulk lookup of items in the cache.
+
+ Returns:
+ A 3-tuple of:
+ 1. a dict of key/value of items already cached;
+ 2. a deferred that resolves to a dict of key/value of items
+ we're already fetching; and
+ 3. a collection of keys that don't appear in the previous two.
+ """
+
+ # The cached results
+ cached = {}
+
+ # List of pending deferreds
+ pending = []
+
+ # Dict that gets filled out when the pending deferreds complete
+ pending_results = {}
+
+ # List of keys that aren't in either cache
+ missing = []
+
+ callbacks = (callback,) if callback else ()
+
+ for key in keys:
+            # Check if it's in the main cache.
+ immediate_value = self.cache.get(
+ key,
+ _Sentinel.sentinel,
+ callbacks=callbacks,
+ )
+ if immediate_value is not _Sentinel.sentinel:
+ cached[key] = immediate_value
+ continue
+
+            # Check if it's in the pending cache.
+ pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if pending_value is not _Sentinel.sentinel:
+ pending_value.add_invalidation_callback(key, callback)
+
+ def completed_cb(value: VT, key: KT) -> VT:
+ pending_results[key] = value
+ return value
+
+ # Add a callback to fill out `pending_results` when that completes
+ d = pending_value.deferred(key).addCallback(completed_cb, key)
+ pending.append(d)
+ continue
+
+ # Not in either cache
+ missing.append(key)
+
+ # If we've got pending deferreds, squash them into a single one that
+ # returns `pending_results`.
+ pending_deferred = None
+ if pending:
+ pending_deferred = defer.gatherResults(
+ pending, consumeErrors=True
+ ).addCallback(lambda _: pending_results)
+
+ return (cached, pending_deferred, missing)
+
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
self.check_thread()
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
+ self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
- # we can save a whole load of effort if the deferred is ready.
- if value.called:
- result = value.result
- if not isinstance(result, failure.Failure):
- self.cache.set(key, cast(VT, result), callbacks)
- return value
-
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
+ entry = CacheEntrySingle[KT, VT](value)
+ entry.add_invalidation_callback(key, callback)
+ self._pending_deferred_cache[key] = entry
+ deferred = entry.deferred(key).addCallbacks(
+ self._completed_callback,
+ self._error_callback,
+ callbackArgs=(entry, key),
+ errbackArgs=(entry, key),
+ )
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
+ # we return a new Deferred which will be called before any subsequent observers.
+ return deferred
- self._pending_deferred_cache[key] = entry
+ def start_bulk_input(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> "CacheMultipleEntries[KT, VT]":
+ """Bulk set API for use when fetching multiple keys at once from the DB.
- def compare_and_pop() -> bool:
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result: VT) -> None:
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail: Failure) -> None:
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
+ Called *before* starting the fetch from the DB, and the caller *must*
+ call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+ """
- # we return a new Deferred which will be called before any subsequent observers.
- return observable.observe()
+ entry = CacheMultipleEntries[KT, VT]()
+ entry.add_global_invalidation_callback(callback)
+
+ for key in keys:
+ self._pending_deferred_cache[key] = entry
+
+ return entry
+
+ def _completed_callback(
+ self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+ ) -> VT:
+ """Called when a deferred is completed."""
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return value
+
+ self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+ return value
+
+ def _error_callback(
+ self,
+ failure: Failure,
+ entry: "CacheEntry[KT, VT]",
+ key: KT,
+ ) -> Failure:
+ """Called when a deferred errors."""
+
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return failure
+
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
+ return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
- callbacks = [callback] if callback else []
+ callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
+ self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
+ # _pending_deferred_cache, which will (a) stop it being returned for
+ # future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
- entry.invalidate()
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
+ for key, entry in self._pending_deferred_cache.items():
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
self._pending_deferred_cache.clear()
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+ """Abstract class for entries in `DeferredCache[KT, VT]`"""
- def __init__(
- self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
- ):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self) -> None:
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
+ @abc.abstractmethod
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ """Get a deferred that a caller can wait on to get the value at the
+ given key"""
+ ...
+
+ @abc.abstractmethod
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add an invalidation callback"""
+ ...
+
+ @abc.abstractmethod
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ """Get all invalidation callbacks"""
+ ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+ """An implementation of `CacheEntry` wrapping a deferred that results in a
+ single cache entry.
+ """
+
+ __slots__ = ["_deferred", "_callbacks"]
+
+ def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+ self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+ self._callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ return self._deferred.observe()
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+ """Cache entry that is used for bulk lookups and insertions."""
+
+ __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+ def __init__(self) -> None:
+ self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+ self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+ self._global_callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ if not self._deferred:
+ self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.setdefault(key, set()).add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks.get(key, set()) | self._global_callbacks
+
+ def add_global_invalidation_callback(
+ self, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add a callback for when any keys get invalidated."""
+ if callback is None:
+ return
+
+ self._global_callbacks.add(callback)
+
+ def complete_bulk(
+ self,
+ cache: DeferredCache[KT, VT],
+ result: Dict[KT, VT],
+ ) -> None:
+ """Called when there is a result"""
+ for key, value in result.items():
+ cache._completed_callback(value, self, key)
+
+ if self._deferred:
+ self._deferred.callback(result)
+
+ def error_bulk(
+ self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+ ) -> None:
+ """Called when bulk lookup failed."""
+ for key in keys:
+ cache._error_callback(failure, self, key)
+
+ if self._deferred:
+ self._deferred.errback(failure)
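
Taken together, `get_bulk` and `start_bulk_input` define a contract for callers: look everything up first, register the still-missing keys *before* querying the DB, and then always resolve the entry with `complete_bulk(..)` or `error_bulk(..)`. A sketch of a caller honouring that contract; `fetch_from_db` is a hypothetical stand-in, and `complete_bulk` should be given a value for every missing key:

    from typing import Collection, Dict

    from twisted.python.failure import Failure

    from synapse.util.caches.deferred_cache import DeferredCache

    async def fetch_from_db(keys: Collection[str]) -> Dict[str, str]:
        # Hypothetical stand-in for a real database query.
        return {k: "value-for-" + k for k in keys}

    async def bulk_fetch(
        cache: DeferredCache[str, str], keys: Collection[str]
    ) -> Dict[str, str]:
        cached, pending, missing = cache.get_bulk(keys)
        results = dict(cached)

        if pending:
            # Entries some other caller is already fetching.
            results.update(await pending)

        if missing:
            entry = cache.start_bulk_input(missing)
            try:
                fetched = await fetch_from_db(missing)
            except Exception:
                # The contract requires resolving the entry either way.
                entry.error_bulk(cache, missing, Failure())
                raise
            entry.complete_bulk(cache, fetched)
            results.update(fetched)

        return results
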
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ List,
Mapping,
Optional,
Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
+ cache_name=self.name,
max_size=self.max_entries,
)
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- results = {}
-
- def update_results_dict(res: Any, arg: Hashable) -> None:
- results[arg] = res
-
- # list of deferreds to wait for
- cached_defers = []
-
- missing = set()
-
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key
+
else:
keylist = list(keyargs)
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
- for arg in list_args:
- try:
- res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not res.called:
- res.addCallback(update_results_dict, arg)
- cached_defers.append(res)
- else:
- results[arg] = res.result
- except KeyError:
- missing.add(arg)
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key[self.list_pos]
+
+ cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+ immediate_results, pending_deferred, missing = cache.get_bulk(
+ cache_keys, callback=invalidate_callback
+ )
+
+ results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+ cached_defers: List["defer.Deferred[Any]"] = []
+ if pending_deferred:
+
+ def update_results(r: Dict) -> None:
+ for k, v in r.items():
+ results[cache_key_to_arg(k)] = v
+
+ pending_deferred.addCallback(update_results)
+ cached_defers.append(pending_deferred)
if missing:
- # we need a deferred for each entry in the list,
- # which we put in the cache. Each deferred resolves with the
- # relevant result for that key.
- deferreds_map = {}
- for arg in missing:
- deferred: "defer.Deferred[Any]" = defer.Deferred()
- deferreds_map[arg] = deferred
- key = arg_to_cache_key(arg)
- cached_defers.append(
- cache.set(key, deferred, callback=invalidate_callback)
- )
+ cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a dict.
- # We can now update our own result map, and then resolve the
- # observable deferreds in the cache.
- for e, d1 in deferreds_map.items():
- val = res.get(e, None)
- # make sure we update the results map before running the
- # deferreds, because as soon as we run the last deferred, the
- # gatherResults() below will complete and return the result
- # dict to our caller.
- results[e] = val
- d1.callback(val)
+ missing_results = {}
+ for key in missing:
+ arg = cache_key_to_arg(key)
+ val = res.get(arg, None)
+
+ results[arg] = val
+ missing_results[key] = val
+
+ cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
- # the wrapped function has failed. Propagate the failure into
- # the cache, which will invalidate the entry, and cause the
- # relevant cached_deferreds to fail, which will propagate the
- # failure to our caller.
- for d1 in deferreds_map.values():
- d1.errback(f)
+ cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = {
+ cache_key_to_arg(key) for key in missing
+ }
# dispatch the call, and attach the two handlers
- defer.maybeDeferred(
+ missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
+ cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,6 +571,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
@@ -587,13 +582,18 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
@@ -628,6 +628,7 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
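
The new `name` argument exists so a method can be renamed without renaming its cache, since metrics and cross-worker invalidation address caches by name; the `_get_user_id_from_membership_event_id` rename earlier in this diff is the motivating case. A minimal sketch:

    from typing import Optional

    from synapse.util.caches.descriptors import cached

    class ExampleStore:  # illustrative only
        @cached(
            max_entries=10000,
            # Keep the pre-rename cache name: dashboards and invalidation
            # streams address caches by name, so renaming the method must
            # not rename the cache.
            name="_get_joined_profile_from_event_id",
        )
        def _get_user_id_from_membership_event_id(self, event_id: str) -> Optional[str]:
            raise NotImplementedError()
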
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
import enum
import logging
import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
import attr
+from typing_extensions import Literal
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
logger = logging.getLogger(__name__)
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
+ If `full` is true then `known_absent` will be the empty set.
+
Attributes:
full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
return len(self.value)
+class _FullCacheKey(enum.Enum):
+ """The key we use to cache the full dict."""
+
+ KEY = object()
+
+
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
+class _PerKeyValue(Generic[DV]):
+ """The cached value of a dictionary key. If `value` is the sentinel,
+ indicates that the requested key is known to *not* be in the full dict.
+ """
+
+ __slots__ = ["value"]
+
+ def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+ self.value = value
+
+ def __len__(self) -> int:
+ # We add a `__len__` implementation as we use this class in a cache
+ # where the values are variable length.
+ return 1
+
+
class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
+
+ This cache has two levels of key. First there is the "cache key" (of type
+ `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+ type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+ example, it might look like:
+
+ {
+ 1: { 1: "a", 2: "b" },
+ 2: { 1: "c" },
+ }
+
+ It is possible to look up either individual dict keys, or the *complete*
+ dict for a given cache key.
+
+ Each dict item, and the complete dict itself, is treated as a separate
+ LRU entry for the purpose of cache expiry. For example, given:
+ dict_cache.get(1, None) -> DictionaryEntry({1: "a", 2: "b"})
+ dict_cache.get(1, [1]) -> DictionaryEntry({1: "a"})
+ dict_cache.get(1, [2]) -> DictionaryEntry({2: "b"})
+
+ ... then the cache entry for the complete dict will expire first,
+ followed by the cache entry for the '1' dict key, and finally that
+ for the '2' dict key.
"""
def __init__(self, name: str, max_entries: int = 1000):
- self.cache: LruCache[KT, DictionaryEntry] = LruCache(
- max_size=max_entries, cache_name=name, size_callback=len
+ # We use a single LruCache to store two different types of entries:
+ # 1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+ # the key doesn't exist in the dict); and
+ # 2. Map from (key, _FullCacheKey.KEY) -> full dict.
+ #
+ # The former is used when explicit keys of the dictionary are looked up,
+ # and the latter when the full dictionary is requested.
+ #
+ # When explicit keys are requested but are not in the cache, we then
+ # check whether we have the full dict and use that if we do. Each key
+ # found in the full dict is then added into the per-key cache.
+ #
+ # This setup allows the `LruCache` to prune the full dict entries if
+ # they haven't been used in a while, even when there have been recent
+ # queries for subsets of the dict.
+ #
+ # Typing:
+ # * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+ # * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+ self.cache: LruCache[
+ Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+ Union[_PerKeyValue, Dict[DKT, DV]],
+ ] = LruCache(
+ max_size=max_entries,
+ cache_name=name,
+ cache_type=TreeCache,
+ size_callback=len,
)
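+
+ # For example (sketch), after caching the full dict for key `1` and then
+ # looking up dict key `2`, the LruCache could contain:
+ #
+ # (1, _FullCacheKey.KEY) -> {1: "a", 2: "b"}
+ # (1, 2) -> _PerKeyValue("b")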
self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
Args:
key
dict_keys: If given a set of keys then return only those keys
- that exist in the cache.
+ that exist in the cache. If None then returns the full dict
+ if it is in the cache.
Returns:
- DictionaryEntry
+ DictionaryEntry: If `dict_keys` is not None then the `DictionaryEntry`
+ will include only those requested keys that are in the cache. If None
+ then it will either contain the full dict if it is in the cache, or
+ an empty dict (with `full` set to False) if it isn't.
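+
+ For example (illustrative sketch, assuming a fresh cache):
+
+ cache.update(seq, 1, {1: "a"}, fetched_keys={1, 2})
+ cache.get(1, [1, 2]) -> DictionaryEntry(False, {2}, {1: "a"})
+ cache.get(1) -> DictionaryEntry(False, set(), {})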
"""
- entry = self.cache.get(key, _Sentinel.sentinel)
- if entry is not _Sentinel.sentinel:
- if dict_keys is None:
- return DictionaryEntry(
- entry.full, entry.known_absent, dict(entry.value)
- )
+ if dict_keys is None:
+ # The caller wants the full set of dictionary keys for this cache key
+ return self._get_full_dict(key)
+
+ # We are being asked for a subset of keys.
+
+ # First go and check for each requested dict key in the cache, tracking
+ # which we couldn't find.
+ values = {}
+ known_absent = set()
+ missing = []
+ for dict_key in dict_keys:
+ entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+ if entry is _Sentinel.sentinel:
+ missing.append(dict_key)
+ continue
+
+ assert isinstance(entry, _PerKeyValue)
+
+ if entry.value is _Sentinel.sentinel:
+ known_absent.add(dict_key)
else:
- return DictionaryEntry(
- entry.full,
- entry.known_absent,
- {k: entry.value[k] for k in dict_keys if k in entry.value},
- )
+ values[dict_key] = entry.value
+
+ # If we found everything we can return immediately.
+ if not missing:
+ return DictionaryEntry(False, known_absent, values)
+
+ # We are missing some keys, so check if we happen to have the full dict in
+ # the cache.
+ #
+ # We don't update the last access time for this cache fetch, as we
+ # aren't explicitly interested in the full dict and so we don't want
+ # requests for explicit dict keys to keep the full dict in the cache.
+ entry = self.cache.get(
+ (key, _FullCacheKey.KEY),
+ _Sentinel.sentinel,
+ update_last_access=False,
+ )
+ if entry is _Sentinel.sentinel:
+ # Not in the cache, return the subset of keys we found.
+ return DictionaryEntry(False, known_absent, values)
+
+ # We have the full dict!
+ assert isinstance(entry, dict)
+
+ for dict_key in missing:
+ # We explicitly add each dict key to the cache, so that cache hit
+ # rates and LRU times for each key can be tracked separately.
+ value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type]
+ self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+ if value is not _Sentinel.sentinel:
+ values[dict_key] = value
+
+ return DictionaryEntry(True, set(), values)
+
+ def _get_full_dict(
+ self,
+ key: KT,
+ ) -> DictionaryEntry:
+ """Fetch the full dict for the given key."""
+
+ # First we check if we have cached the full dict.
+ entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+ if entry is not _Sentinel.sentinel:
+ assert isinstance(entry, dict)
+ return DictionaryEntry(True, set(), entry)
return DictionaryEntry(False, set(), {})
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(key, None)
+
+ # We want to drop all information about the dict for the given key, so
+ # we use `del_multi` to delete it all in one go.
+ #
+ # We ignore the type error here: `del_multi` accepts a truncated key
+ # (when the key type is a tuple).
+ self.cache.del_multi((key,)) # type: ignore[arg-type]
def invalidate_all(self) -> None:
self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
value: Dict[DKT, DV],
fetched_keys: Optional[Iterable[DKT]] = None,
) -> None:
- """Updates the entry in the cache
+ """Updates the entry in the cache.
+
+ Note: This does *not* invalidate any existing entries for the `key`.
+ In particular, if we add an entry for the cached "full dict" with
+ `fetched_keys=None`, existing entries for individual dict keys are
+ not invalidated. Likewise, adding entries for individual keys does
+ not invalidate any cached value for the full dict.
+
+ In other words: if the underlying data is *changed*, the cache must
+ be explicitly invalidated via `.invalidate()`.
Args:
sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Only update the cache if the cache's sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if fetched_keys is None:
- self._insert(key, value, set())
+ self.cache[(key, _FullCacheKey.KEY)] = value
else:
- self._update_or_insert(key, value, fetched_keys)
+ self._update_subset(key, value, fetched_keys)
- def _update_or_insert(
- self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+ def _update_subset(
+ self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
) -> None:
- # We pop and reinsert as we need to tell the cache the size may have
- # changed
+ """Add the given dictionary values as explicit keys in the cache.
+
+ Args:
+ key: top-level cache key
+ value: The dictionary with all the values that we should cache
+ fetched_keys: The full set of dict keys that were looked up. Any keys
+ here not in `value` should be marked as "known absent".
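+
+ For example (sketch): `_update_subset(1, {1: "a"}, {1, 2})` caches
+ `(1, 1) -> _PerKeyValue("a")` and marks dict key `2` as known absent
+ via `(1, 2) -> _PerKeyValue(_Sentinel.sentinel)`.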
+ """
+
+ for dict_key, dict_value in value.items():
+ self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
- entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
- entry.value.update(value)
- entry.known_absent.update(known_absent)
- self.cache[key] = entry
+ for dict_key in fetched_keys:
+ if dict_key in value:
+ continue
- def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
- self.cache[key] = DictionaryEntry(True, known_absent, value)
+ self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
Collection,
Dict,
Generic,
+ Iterable,
List,
Optional,
+ Tuple,
Type,
TypeVar,
Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+ TreeCache,
+ iterate_tree_cache_entry,
+ iterate_tree_cache_items,
+)
from synapse.util.linked_list import ListNode
if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
default: Literal[None] = None,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Optional[VT]:
...
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
default: T,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Union[T, VT]:
...
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
+ update_last_access: bool = True,
) -> Union[None, T, VT]:
+ """Look up a key in the cache
+
+ Args:
+ key
+ default
+ callbacks: A collection of callbacks that will fire when the
+ node is removed from the cache (either due to invalidation
+ or expiry).
+ update_metrics: Whether to update the hit rate metrics
+ update_last_access: Whether to update the last access time of
+ the node if it is successfully fetched. The last access time
+ is used to determine when to evict the node from the cache.
+ Set to False if this fetch should *not* prevent the node
+ from being evicted.
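+
+ For example (sketch): `cache.get(key, update_last_access=False)`
+ peeks at the cached value without bumping the entry towards the
+ front of the LRU list.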
+ """
node = cache.get(key, None)
if node is not None:
- move_node_to_front(node)
+ if update_last_access:
+ move_node_to_front(node)
node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
metrics.inc_misses()
return default
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: Literal[None] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: T,
+ update_metrics: bool = True,
+ ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @synchronized
+ def cache_get_multi(
+ key: tuple,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+ """Returns a generator yielding all entries under the given key.
+
+ Can only be used if backed by a tree cache.
+
+ Example:
+
+ cache = LruCache(10, cache_type=TreeCache)
+ cache[(1, 1)] = "a"
+ cache[(1, 2)] = "b"
+ cache[(2, 1)] = "c"
+
+ items = cache.get_multi((1,))
+ assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+ Returns:
+ Either default if the key doesn't exist, or a generator of the
+ key/value pairs.
+ """
+
+ assert isinstance(cache, TreeCache)
+
+ node = cache.get(key, None)
+ if node is not None:
+ if update_metrics and metrics:
+ metrics.inc_hits()
+
+ # We store entries in the `TreeCache` with values of type `_Node`,
+ # which we need to unwrap.
+ return (
+ (full_key, lru_node.value)
+ for full_key, lru_node in iterate_tree_cache_items(key, node)
+ )
+ else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
+ return default
+
@synchronized
def cache_set(
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
self.setdefault = cache_set_default
self.pop = cache_pop
self.del_multi = cache_del_multi
+ if cache_type is TreeCache:
+ self.get_multi = cache_get_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_del_multi
@@ -748,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics)
+ async def get_external(
+ self,
+ key: KT,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Optional[VT]:
+ # This method should fetch from any configured external cache; here it is a no-op.
+ return None
+
+ def get_local(
+ self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+ ) -> Optional[VT]:
+ return self._lru_cache.get(key, update_metrics=update_metrics)
+
async def set(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value)
+ def set_local(self, key: KT, value: VT) -> None:
+ self._lru_cache.set(key, value)
+
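+ # Illustrative (sketch): a subclass backed by a real external cache
+ # might first try memory and only then the external store:
+ #
+ # value = cache.get_local(key)
+ # if value is None:
+ # value = await cache.get_external(key)
+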
async def invalidate(self, key: KT) -> None:
# This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
self.size += 1
def get(self, key, default=None):
+ """When `key` is a full key, fetches the value for the given key (if
+ any).
+
+ If `key` is only a partial key (i.e. a truncated tuple) then returns a
+ `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+ functions to iterate over all entries in the cache with keys that start
+ with the given partial key.
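+
+ For example (sketch): with `cache[(1, 1)] = "a"`, `cache.get((1, 1))`
+ returns `"a"`, while `cache.get((1,))` returns the `TreeCacheNode`
+ covering all keys that begin with `(1,)`.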
+ """
+
node = self.root
for k in key[:-1]:
node = node.get(k, None)
@@ -126,6 +135,9 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
+ def items(self):
+ return iterate_tree_cache_items((), self.root)
+
def __len__(self) -> int:
return self.size
@@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d):
yield from iterate_tree_cache_entry(value_d)
else:
yield d
+
+
+def iterate_tree_cache_items(key, value):
+ """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+ can contain dicts.
+
+ The provided key is a tuple that will get prepended to the returned keys.
+
+ Example:
+
+ cache = TreeCache()
+ cache[(1, 1)] = "a"
+ cache[(1, 2)] = "b"
+ cache[(2, 1)] = "c"
+
+ tree_node = cache.get((1,))
+
+ items = iterate_tree_cache_items((1,), tree_node)
+ assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+ Returns:
+ A generator yielding key/value pairs.
+ """
+ if isinstance(value, TreeCacheNode):
+ for sub_key, sub_value in value.items():
+ yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+ else:
+ # we've reached a leaf of the tree.
+ yield key, value
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index dfe628c97e..f678b52cb4 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,15 +18,19 @@ import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
+from prometheus_client.core import Counter
+
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock
if typing.TYPE_CHECKING:
@@ -35,8 +39,34 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter("synapse_rate_limit_sleep", "")
+rate_limit_reject_counter = Counter("synapse_rate_limit_reject", "")
+queue_wait_timer = Histogram(
+ "synapse_rate_limit_queue_wait_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ "+Inf",
+ ),
+)
+
+
class FederationRateLimiter:
- def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+ def __init__(self, clock: Clock, config: FederationRatelimitSettings):
def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(clock=clock, config=config)
@@ -44,6 +74,27 @@ class FederationRateLimiter:
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
+ # We track the number of affected hosts per time-period so we can
+ # differentiate one really noisy homeserver from a general
+ # ratelimit tuning problem across the federation.
+ LaterGauge(
+ "synapse_rate_limit_sleep_affected_hosts",
+ "Number of hosts that had requests put to sleep",
+ [],
+ lambda: sum(
+ ratelimiter.should_sleep() for ratelimiter in self.ratelimiters.values()
+ ),
+ )
+ LaterGauge(
+ "synapse_rate_limit_reject_affected_hosts",
+ "Number of hosts that had requests rejected",
+ [],
+ lambda: sum(
+ ratelimiter.should_reject()
+ for ratelimiter in self.ratelimiters.values()
+ ),
+ )
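+
+ # (Sketch) On the Prometheus side, e.g.
+ # `rate(synapse_rate_limit_reject_total[5m])` graphed alongside
+ # `synapse_rate_limit_reject_affected_hosts` helps tell one noisy
+ # homeserver apart from a federation-wide tuning problem.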
+
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
@@ -59,11 +110,11 @@ class FederationRateLimiter:
Returns:
context manager which returns a deferred.
"""
- return self.ratelimiters[host].ratelimit()
+ return self.ratelimiters[host].ratelimit(host)
class _PerHostRatelimiter:
- def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+ def __init__(self, clock: Clock, config: FederationRatelimitSettings):
"""
Args:
clock
@@ -94,19 +145,42 @@ class _PerHostRatelimiter:
self.request_times: List[int] = []
@contextlib.contextmanager
- def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+ def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
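+ # Stash the host for logging. Each `_PerHostRatelimiter` handles a
+ # single host, so this is stable across calls.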
+ self.host = host
+
request_id = object()
- ret = self._on_enter(request_id)
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
try:
yield ret
finally:
self._on_exit(request_id)
+ def should_reject(self) -> bool:
+ """
+ Whether to reject the request if we already have too many queued up
+ (either sleeping or in the ready queue).
+ """
+ queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+ return queue_size > self.reject_limit
+
+ def should_sleep(self) -> bool:
+ """
+ Whether to sleep the request if we already have too many requests coming
+ through within the window.
+ """
+ return len(self.request_times) > self.sleep_limit
+
+ async def _on_enter_with_tracing(self, request_id: object) -> None:
+ with start_active_span("ratelimit wait"), queue_wait_timer.time():
+ await self._on_enter(request_id)
+
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
@@ -117,8 +191,9 @@ class _PerHostRatelimiter:
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
- queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
- if queue_size > self.reject_limit:
+ if self.should_reject():
+ logger.debug("Ratelimiter(%s): rejecting request", self.host)
+ rate_limit_reject_counter.inc()
raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit)
)
@@ -130,7 +205,8 @@ class _PerHostRatelimiter:
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
- "Ratelimiter: queueing request (queue now %i items)",
+ "Ratelimiter(%s): queueing request (queue now %i items)",
+ self.host,
len(self.ready_request_queue),
)
@@ -139,19 +215,28 @@ class _PerHostRatelimiter:
return defer.succeed(None)
logger.debug(
- "Ratelimit [%s]: len(self.request_times)=%d",
+ "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+ self.host,
id(request_id),
len(self.request_times),
)
- if len(self.request_times) > self.sleep_limit:
- logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+ if self.should_sleep():
+ logger.debug(
+ "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+ self.host,
+ id(request_id),
+ self.sleep_sec,
+ )
+ rate_limit_sleep_counter.inc()
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
- logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+ )
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -161,7 +246,9 @@ class _PerHostRatelimiter:
ret_defer = queue_request()
def on_start(r: object) -> object:
- logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+ )
self.current_processing.add(request_id)
return r
@@ -183,7 +270,7 @@ class _PerHostRatelimiter:
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id: object) -> None:
- logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+ logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 9abbaa5a64..c810a05907 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
+from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
@@ -51,6 +52,7 @@ MEMBERSHIP_PRIORITY = (
_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
+@trace
async def filter_events_for_client(
storage: StorageControllers,
user_id: str,
@@ -71,8 +73,8 @@ async def filter_events_for_client(
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
- always_include_ids: set of event ids to specifically
- include (unless sender is ignored)
+ always_include_ids: set of event ids to specifically include, if present
+ in events (unless sender is ignored)
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
|