Diffstat (limited to 'synapse')
48 files changed, 1612 insertions, 685 deletions
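The first hunk below annotates the constants in synapse/api/constants.py with typing_extensions.Final, so a type checker can reject accidental reassignment while runtime behaviour stays identical. A minimal sketch of the effect, assuming mypy; the reassignment at the end is a hypothetical misuse, not part of this changeset:

from typing_extensions import Final

class Membership:
    """Represents the membership states of a user in a room."""

    INVITE: Final = "invite"
    JOIN: Final = "join"

# Runtime behaviour is unchanged, but mypy now flags this, roughly:
# error: Cannot assign to final attribute "JOIN"
Membership.JOIN = "joined"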
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a33ac34161..f7d29b4319 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +from typing_extensions import Final + # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -39,125 +41,125 @@ class Membership: """Represents the membership states of a user in a room.""" - INVITE = "invite" - JOIN = "join" - KNOCK = "knock" - LEAVE = "leave" - BAN = "ban" - LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) + INVITE: Final = "invite" + JOIN: Final = "join" + KNOCK: Final = "knock" + LEAVE: Final = "leave" + BAN: Final = "ban" + LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState: """Represents the presence state of a user.""" - OFFLINE = "offline" - UNAVAILABLE = "unavailable" - ONLINE = "online" - BUSY = "org.matrix.msc3026.busy" + OFFLINE: Final = "offline" + UNAVAILABLE: Final = "unavailable" + ONLINE: Final = "online" + BUSY: Final = "org.matrix.msc3026.busy" class JoinRules: - PUBLIC = "public" - KNOCK = "knock" - INVITE = "invite" - PRIVATE = "private" + PUBLIC: Final = "public" + KNOCK: Final = "knock" + INVITE: Final = "invite" + PRIVATE: Final = "private" # As defined for MSC3083. - RESTRICTED = "restricted" + RESTRICTED: Final = "restricted" class RestrictedJoinRuleTypes: """Understood types for the allow rules in restricted join rules.""" - ROOM_MEMBERSHIP = "m.room_membership" + ROOM_MEMBERSHIP: Final = "m.room_membership" class LoginType: - PASSWORD = "m.login.password" - EMAIL_IDENTITY = "m.login.email.identity" - MSISDN = "m.login.msisdn" - RECAPTCHA = "m.login.recaptcha" - TERMS = "m.login.terms" - SSO = "m.login.sso" - DUMMY = "m.login.dummy" - REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" + PASSWORD: Final = "m.login.password" + EMAIL_IDENTITY: Final = "m.login.email.identity" + MSISDN: Final = "m.login.msisdn" + RECAPTCHA: Final = "m.login.recaptcha" + TERMS: Final = "m.login.terms" + SSO: Final = "m.login.sso" + DUMMY: Final = "m.login.dummy" + REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by # an appservice to register a new user. 
-APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service" +APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service" class EventTypes: - Member = "m.room.member" - Create = "m.room.create" - Tombstone = "m.room.tombstone" - JoinRules = "m.room.join_rules" - PowerLevels = "m.room.power_levels" - Aliases = "m.room.aliases" - Redaction = "m.room.redaction" - ThirdPartyInvite = "m.room.third_party_invite" - RelatedGroups = "m.room.related_groups" - - RoomHistoryVisibility = "m.room.history_visibility" - CanonicalAlias = "m.room.canonical_alias" - Encrypted = "m.room.encrypted" - RoomAvatar = "m.room.avatar" - RoomEncryption = "m.room.encryption" - GuestAccess = "m.room.guest_access" + Member: Final = "m.room.member" + Create: Final = "m.room.create" + Tombstone: Final = "m.room.tombstone" + JoinRules: Final = "m.room.join_rules" + PowerLevels: Final = "m.room.power_levels" + Aliases: Final = "m.room.aliases" + Redaction: Final = "m.room.redaction" + ThirdPartyInvite: Final = "m.room.third_party_invite" + RelatedGroups: Final = "m.room.related_groups" + + RoomHistoryVisibility: Final = "m.room.history_visibility" + CanonicalAlias: Final = "m.room.canonical_alias" + Encrypted: Final = "m.room.encrypted" + RoomAvatar: Final = "m.room.avatar" + RoomEncryption: Final = "m.room.encryption" + GuestAccess: Final = "m.room.guest_access" # These are used for validation - Message = "m.room.message" - Topic = "m.room.topic" - Name = "m.room.name" + Message: Final = "m.room.message" + Topic: Final = "m.room.topic" + Name: Final = "m.room.name" - ServerACL = "m.room.server_acl" - Pinned = "m.room.pinned_events" + ServerACL: Final = "m.room.server_acl" + Pinned: Final = "m.room.pinned_events" - Retention = "m.room.retention" + Retention: Final = "m.room.retention" - Dummy = "org.matrix.dummy_event" + Dummy: Final = "org.matrix.dummy_event" - SpaceChild = "m.space.child" - SpaceParent = "m.space.parent" + SpaceChild: Final = "m.space.child" + SpaceParent: Final = "m.space.parent" - MSC2716_INSERTION = "org.matrix.msc2716.insertion" - MSC2716_BATCH = "org.matrix.msc2716.batch" - MSC2716_MARKER = "org.matrix.msc2716.marker" + MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion" + MSC2716_BATCH: Final = "org.matrix.msc2716.batch" + MSC2716_MARKER: Final = "org.matrix.msc2716.marker" class ToDeviceEventTypes: - RoomKeyRequest = "m.room_key_request" + RoomKeyRequest: Final = "m.room_key_request" class DeviceKeyAlgorithms: """Spec'd algorithms for the generation of per-device keys""" - ED25519 = "ed25519" - CURVE25519 = "curve25519" - SIGNED_CURVE25519 = "signed_curve25519" + ED25519: Final = "ed25519" + CURVE25519: Final = "curve25519" + SIGNED_CURVE25519: Final = "signed_curve25519" class EduTypes: - Presence = "m.presence" + Presence: Final = "m.presence" class RejectedReason: - AUTH_ERROR = "auth_error" + AUTH_ERROR: Final = "auth_error" class RoomCreationPreset: - PRIVATE_CHAT = "private_chat" - PUBLIC_CHAT = "public_chat" - TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + PRIVATE_CHAT: Final = "private_chat" + PUBLIC_CHAT: Final = "public_chat" + TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat" class ThirdPartyEntityKind: - USER = "user" - LOCATION = "location" + USER: Final = "user" + LOCATION: Final = "location" -ServerNoticeMsgType = "m.server_notice" -ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" +ServerNoticeMsgType: Final = "m.server_notice" +ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached" class UserTypes: @@ -165,91 +167,91 @@ class 
UserTypes: 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ - SUPPORT = "support" - BOT = "bot" - ALL_USER_TYPES = (SUPPORT, BOT) + SUPPORT: Final = "support" + BOT: Final = "bot" + ALL_USER_TYPES: Final = (SUPPORT, BOT) class RelationTypes: """The types of relations known to this server.""" - ANNOTATION = "m.annotation" - REPLACE = "m.replace" - REFERENCE = "m.reference" - THREAD = "io.element.thread" + ANNOTATION: Final = "m.annotation" + REPLACE: Final = "m.replace" + REFERENCE: Final = "m.reference" + THREAD: Final = "io.element.thread" class LimitBlockingTypes: """Reasons that a server may be blocked""" - MONTHLY_ACTIVE_USER = "monthly_active_user" - HS_DISABLED = "hs_disabled" + MONTHLY_ACTIVE_USER: Final = "monthly_active_user" + HS_DISABLED: Final = "hs_disabled" class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - LABELS = "org.matrix.labels" + LABELS: Final = "org.matrix.labels" # Timestamp to delete the event after # cf https://github.com/matrix-org/matrix-doc/pull/2228 - SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after" # cf https://github.com/matrix-org/matrix-doc/pull/1772 - ROOM_TYPE = "type" + ROOM_TYPE: Final = "type" # Whether a room can federate. - FEDERATE = "m.federate" + FEDERATE: Final = "m.federate" # The creator of the room, as used in `m.room.create` events. - ROOM_CREATOR = "creator" + ROOM_CREATOR: Final = "creator" # Used in m.room.guest_access events. - GUEST_ACCESS = "guest_access" + GUEST_ACCESS: Final = "guest_access" # Used on normal messages to indicate they were historically imported after the fact - MSC2716_HISTORICAL = "org.matrix.msc2716.historical" + MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next batch ID should be in # order to connect to it - MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id" + MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id" # Used on "batch" events to indicate which insertion event it connects to - MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id" + MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id" # For "marker" events - MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" + MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion" # The authorising user for joining a restricted room. 
- AUTHORISING_USER = "join_authorised_via_users_server" + AUTHORISING_USER: Final = "join_authorised_via_users_server" class RoomTypes: """Understood values of the room_type field of m.room.create events.""" - SPACE = "m.space" + SPACE: Final = "m.space" class RoomEncryptionAlgorithms: - MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" - DEFAULT = MEGOLM_V1_AES_SHA2 + MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2" + DEFAULT: Final = MEGOLM_V1_AES_SHA2 class AccountDataTypes: - DIRECT = "m.direct" - IGNORED_USER_LIST = "m.ignored_user_list" + DIRECT: Final = "m.direct" + IGNORED_USER_LIST: Final = "m.ignored_user_list" class HistoryVisibility: - INVITED = "invited" - JOINED = "joined" - SHARED = "shared" - WORLD_READABLE = "world_readable" + INVITED: Final = "invited" + JOINED: Final = "joined" + SHARED: Final = "shared" + WORLD_READABLE: Final = "world_readable" class GuestAccess: - CAN_JOIN = "can_join" + CAN_JOIN: Final = "can_join" # anything that is not "can_join" is considered "forbidden", but for completeness: - FORBIDDEN = "forbidden" + FORBIDDEN: Final = "forbidden" class ReadReceiptEventFields: - MSC2285_HIDDEN = "org.matrix.msc2285.hidden" + MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 502cc8e8d1..b4bed5bf40 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -113,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import ( ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.storage.databases.main.room_batch import RoomBatchStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.stats import StatsStore @@ -240,6 +241,7 @@ class GenericWorkerSlavedStore( SlavedEventStore, SlavedKeyStore, RoomWorkerStore, + RoomBatchStore, DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7e09530ad2..52541faab2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -194,6 +194,7 @@ class SynapseHomeServer(HomeServer): { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, + "/_matrix/client/v1": client_resource, "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 61e569d412..1ddad7cb70 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError @@ -113,32 +114,24 @@ class RegistrationConfig(Config): self.session_lifetime = session_lifetime # The `refreshable_access_token_lifetime` applies for tokens that can be renewed - # using a refresh token, as per MSC2918. If it is `None`, the refresh - # token mechanism is disabled. - # - # Since it is incompatible with the `session_lifetime` mechanism, it is set to - # `None` by default if a `session_lifetime` is set. + # using a refresh token, as per MSC2918. 
+ # If it is `None`, the refresh token mechanism is disabled. refreshable_access_token_lifetime = config.get( "refreshable_access_token_lifetime", - "5m" if session_lifetime is None else None, + "5m", ) if refreshable_access_token_lifetime is not None: refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime ) - self.refreshable_access_token_lifetime = refreshable_access_token_lifetime - - if ( - session_lifetime is not None - and refreshable_access_token_lifetime is not None - ): - raise ConfigError( - "The refresh token mechanism is incompatible with the " - "`session_lifetime` option. Consider disabling the " - "`session_lifetime` option or disabling the refresh token " - "mechanism by removing the `refreshable_access_token_lifetime` " - "option." - ) + self.refreshable_access_token_lifetime: Optional[ + int + ] = refreshable_access_token_lifetime + + refresh_token_lifetime = config.get("refresh_token_lifetime") + if refresh_token_lifetime is not None: + refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) + self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 4cda439ad9..993b04099e 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -667,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name, ) + request: JsonDict = {} + for queue_value in keys_to_fetch: + # there may be multiple requests for each server, so we have to merge + # them intelligently. + request_for_server = { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids + } + request.setdefault(queue_value.server_name, {}).update(request_for_server) + + logger.debug("Request to notary server %s: %s", perspective_name, request) + try: query_response = await self.client.post_json( destination=perspective_name, path="/_matrix/key/v2/query", - data={ - "server_keys": { - queue_value.server_name: { - key_id: { - "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, - } - for key_id in queue_value.key_ids - } - for queue_value in keys_to_fetch - } - }, + data={"server_keys": request}, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -689,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + logger.debug( + "Response from notary server %s: %s", perspective_name, query_response + ) + keys: Dict[str, Dict[str, FetchKeyResult]] = {} added_keys: List[Tuple[str, str, FetchKeyResult]] = [] diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index d7527008c4..f251402ed8 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext): attributes by loading from the database. """ if self.state_group is None: + # No state group means the event is an outlier. Usually the state_ids dicts are also + # pre-set to empty dicts, but they get reset when the context is serialized, so set + # them to empty dicts again here. 
+ self._current_state_ids = {} + self._prev_state_ids = {} return current_state_ids = await self._storage.state.get_state_ids_for_group( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 3b85b135e0..bc3f96c1fc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1395,11 +1395,28 @@ class FederationClient(FederationBase): async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: - res = await self.transport_layer.get_room_hierarchy( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - ) + try: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fall back to the unstable endpoint. Otherwise consider it a + # legitimate error and raise. + if not self._is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API" + ) + + res = await self.transport_layer.get_room_hierarchy_unstable( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) room = res.get("room") if not isinstance(room, dict): @@ -1449,6 +1466,10 @@ class FederationClient(FederationBase): if e.code != 502: raise + logger.debug( + "Couldn't fetch room hierarchy, falling back to the spaces API" + ) + # Fall back to the old federation API and translate the results if # no servers implement the new API. # diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9a8758e9a6..8fbc75aa65 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -613,8 +613,11 @@ class FederationServer(FederationBase): state = await self.store.get_events(state_ids) time_now = self._clock.time_msec() + event_json = event.get_pdu_json() return { - "org.matrix.msc3083.v2.event": event.get_pdu_json(), + # TODO Remove the unstable prefix when servers have updated. + "org.matrix.msc3083.v2.event": event_json, + "event": event_json, "state": [p.get_pdu_json(time_now) for p in state.values()], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 10b5aa5af8..fe29bcfd4b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1192,10 +1192,24 @@ class TransportLayerClient: ) async def get_room_hierarchy( - self, - destination: str, - room_id: str, - suggested_only: bool, + self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. + suggested_only: if True, only suggested rooms will be returned + """ + path = _create_v1_path("/hierarchy/%s", room_id) + + return await self.client.get_json( + destination=destination, + path=path, + args={"suggested_only": "true" if suggested_only else "false"}, + ) + + async def get_room_hierarchy_unstable( + self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: """ Args: @@ -1317,15 +1331,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]): prefix + "auth_chain.item", use_float=True, ) - self._coro_event = ijson.kvitems_coro( + # TODO Remove the unstable prefix when servers have updated.
+ # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + self._coro_unstable_event = ijson.kvitems_coro( _event_parser(self._response.event_dict), prefix + "org.matrix.msc3083.v2.event", use_float=True, ) + self._coro_event = ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ) def write(self, data: bytes) -> int: self._coro_state.send(data) self._coro_auth.send(data) + self._coro_unstable_event.send(data) self._coro_event.send(data) return len(data) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 2fdf6cc99e..66e915228c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -611,7 +611,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): class FederationRoomHierarchyServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/hierarchy/(?P<room_id>[^/]*)" def __init__( @@ -637,6 +636,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): ) +class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + + class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -701,6 +704,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( RoomComplexityServlet, FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, + FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4b66a9862f..4d9c4e5834 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,6 +18,7 @@ import time import unicodedata import urllib.parse from binascii import crc32 +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -756,53 +757,109 @@ class AuthHandler: async def refresh_token( self, refresh_token: str, - valid_until_ms: Optional[int], - ) -> Tuple[str, str]: + access_token_valid_until_ms: Optional[int], + refresh_token_valid_until_ms: Optional[int], + ) -> Tuple[str, str, Optional[int]]: """ Consumes a refresh token and generates both a new access token and a new refresh token from it. The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + The lifetime of both the access token and refresh token will be capped so that they + do not exceed the session's ultimate expiry time, if applicable. + Args: refresh_token: The token to consume. - valid_until_ms: The expiration timestamp of the new access token. - + access_token_valid_until_ms: The expiration timestamp of the new access token. + None if the access token does not expire. + refresh_token_valid_until_ms: The expiration timestamp of the new refresh token. + None if the refresh token does not expire. Returns: - A tuple containing the new access token and refresh token + A tuple containing: + - the new access token + - the new refresh token + - the actual expiry time of the access token, which may be earlier than + `access_token_valid_until_ms`.
""" # Verify the token signature first before looking up the token if not self._verify_refresh_token(refresh_token): - raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN + ) existing_token = await self.store.lookup_refresh_token(refresh_token) if existing_token is None: - raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, + "refresh token does not exist", + Codes.UNKNOWN_TOKEN, + ) if ( existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed ): raise SynapseError( - 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + HTTPStatus.FORBIDDEN, + "refresh token isn't valid anymore", + Codes.FORBIDDEN, + ) + + now_ms = self._clock.time_msec() + + if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: + + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The supplied refresh token has expired", + Codes.FORBIDDEN, ) + if existing_token.ultimate_session_expiry_ts is not None: + # This session has a bounded lifetime, even across refreshes. + + if access_token_valid_until_ms is not None: + access_token_valid_until_ms = min( + access_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + + if refresh_token_valid_until_ms is not None: + refresh_token_valid_until_ms = min( + refresh_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + if existing_token.ultimate_session_expiry_ts < now_ms: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The session has expired and can no longer be refreshed", + Codes.FORBIDDEN, + ) + ( new_refresh_token, new_refresh_token_id, ) = await self.create_refresh_token_for_user_id( - user_id=existing_token.user_id, device_id=existing_token.device_id + user_id=existing_token.user_id, + device_id=existing_token.device_id, + expiry_ts=refresh_token_valid_until_ms, + ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts, ) access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_valid_until_ms, refresh_token_id=new_refresh_token_id, ) await self.store.replace_refresh_token( existing_token.token_id, new_refresh_token_id ) - return access_token, new_refresh_token + return access_token, new_refresh_token, access_token_valid_until_ms def _verify_refresh_token(self, token: str) -> bool: """ @@ -836,6 +893,8 @@ class AuthHandler: self, user_id: str, device_id: str, + expiry_ts: Optional[int], + ultimate_session_expiry_ts: Optional[int], ) -> Tuple[str, int]: """ Creates a new refresh token for the user with the given user ID. @@ -843,6 +902,13 @@ class AuthHandler: Args: user_id: canonical user ID device_id: the device ID to associate with the token. + expiry_ts (milliseconds since the epoch): Time after which the + refresh token cannot be used. + If None, the refresh token never expires until it has been used. + ultimate_session_expiry_ts (milliseconds since the epoch): + Time at which the session will end and can not be extended any + further. + If None, the session can be refreshed indefinitely. 
Returns: The newly created refresh token and its ID in the database @@ -852,6 +918,8 @@ class AuthHandler: user_id=user_id, token=refresh_token, device_id=device_id, + expiry_ts=expiry_ts, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) return refresh_token, refresh_token_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 448a36108e..24ca11b924 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -119,6 +119,7 @@ class RegistrationHandler: self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime init_counters_for_auth_provider("") @@ -793,13 +794,13 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app - valid_until_ms = None + access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - valid_until_ms = self.clock.time_msec() + self.session_lifetime + access_token_expiry = self.clock.time_msec() + self.session_lifetime refresh_token = None refresh_token_id = None @@ -808,25 +809,57 @@ class RegistrationHandler: user_id, device_id, initial_display_name ) if is_guest: - assert valid_until_ms is None + assert access_token_expiry is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: if should_issue_refresh_token: + # A refreshable access token lifetime must be configured + # since we're told to issue a refresh token (the caller checks + # that this value is set before setting this flag). + assert self.refreshable_access_token_lifetime is not None + + now_ms = self.clock.time_msec() + + # Set the expiry time of the refreshable access token + access_token_expiry = now_ms + self.refreshable_access_token_lifetime + + # Set the refresh token expiry time (if configured) + refresh_token_expiry = None + if self.refresh_token_lifetime is not None: + refresh_token_expiry = now_ms + self.refresh_token_lifetime + + # Set an ultimate session expiry time (if configured) + ultimate_session_expiry_ts = None + if self.session_lifetime is not None: + ultimate_session_expiry_ts = now_ms + self.session_lifetime + + # Also ensure that the issued tokens don't outlive the + # session. + # (It would be weird to configure a homeserver with a shorter + # session lifetime than token lifetime, but may as well handle + # it.) 
+ access_token_expiry = min( + access_token_expiry, ultimate_session_expiry_ts + ) + if refresh_token_expiry is not None: + refresh_token_expiry = min( + refresh_token_expiry, ultimate_session_expiry_ts + ) + ( refresh_token, refresh_token_id, ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, - ) - valid_until_ms = ( - self.clock.time_msec() + self.refreshable_access_token_lifetime + expiry_ts=refresh_token_expiry, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_expiry, is_appservice_ghost=is_appservice_ghost, refresh_token_id=refresh_token_id, ) @@ -834,7 +867,7 @@ class RegistrationHandler: return { "device_id": registered_device_id, "access_token": access_token, - "valid_until_ms": valid_until_ms, + "valid_until_ms": access_token_expiry, "refresh_token": refresh_token, } diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 8181cc0b52..b2cfe537df 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,8 +36,9 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -93,6 +94,9 @@ class RoomSummaryHandler: self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() + self._ratelimiter = Ratelimiter( + store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + ) # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. @@ -249,7 +253,7 @@ class RoomSummaryHandler: async def get_room_hierarchy( self, - requester: str, + requester: Requester, requested_room_id: str, suggested_only: bool = False, max_depth: Optional[int] = None, @@ -276,6 +280,8 @@ class RoomSummaryHandler: Returns: The JSON hierarchy dictionary. """ + await self._ratelimiter.ratelimit(requester) + # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. # @@ -283,7 +289,7 @@ class RoomSummaryHandler: # to process multiple requests for the same page will result in errors. 
return await self._pagination_response_cache.wrap( ( - requester, + requester.user.to_string(), requested_room_id, suggested_only, max_depth, @@ -291,7 +297,7 @@ class RoomSummaryHandler: from_token, ), self._get_room_hierarchy, - requester, + requester.user.to_string(), requested_room_id, suggested_only, max_depth, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 96d7a8f2a9..a8154168be 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -24,6 +24,7 @@ from typing import ( List, Optional, Tuple, + TypeVar, Union, ) @@ -81,10 +82,19 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import ( + defer_to_thread, + make_deferred_yieldable, + run_in_background, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse from synapse.storage import DataStore +from synapse.storage.background_updates import ( + DEFAULT_BATCH_SIZE_CALLBACK, + MIN_BATCH_SIZE_CALLBACK, + ON_UPDATE_CALLBACK, +) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -104,6 +114,9 @@ if TYPE_CHECKING: from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer + +T = TypeVar("T") + """ This package defines the 'stable' API which can be used by extension modules which are loaded into Synapse. @@ -307,7 +320,25 @@ class ModuleApi: auth_checkers=auth_checkers, ) - def register_web_resource(self, path: str, resource: Resource): + def register_background_update_controller_callbacks( + self, + on_update: ON_UPDATE_CALLBACK, + default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, + min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None, + ) -> None: + """Registers background update controller callbacks. + + Added in Synapse v1.49.0. + """ + + for db in self._hs.get_datastores().databases: + db.updates.register_update_controller_callbacks( + on_update=on_update, + default_batch_size=default_batch_size, + min_batch_size=min_batch_size, + ) + + def register_web_resource(self, path: str, resource: Resource) -> None: """Registers a web resource to be served at the given path. This function should be called during initialisation of the module. @@ -432,7 +463,7 @@ class ModuleApi: username: provided user id Returns: - str: qualified @user:id + qualified @user:id """ if username.startswith("@"): return username @@ -468,7 +499,7 @@ class ModuleApi: """ return await self._store.user_get_threepids(user_id) - def check_user_exists(self, user_id: str): + def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": """Check if user exists. Added in Synapse v0.25.0. @@ -477,13 +508,18 @@ class ModuleApi: user_id: Complete @user:id Returns: - Deferred[str|None]: Canonical (case-corrected) user_id, or None + Canonical (case-corrected) user_id, or None if the user is not registered. 
""" return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register(self, localpart, displayname=None, emails: Optional[List[str]] = None): + def register( + self, + localpart: str, + displayname: Optional[str] = None, + emails: Optional[List[str]] = None, + ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]: """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -495,12 +531,12 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. - emails (List[str]): Emails to bind to the new user. + localpart: The localpart of the new user. + displayname: The displayname of the new user. + emails: Emails to bind to the new user. Returns: - Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token) + a 2-tuple of (user_id, access_token) """ logger.warning( "Using deprecated ModuleApi.register which creates a dummy user device." @@ -510,23 +546,26 @@ class ModuleApi: return user_id, access_token def register_user( - self, localpart, displayname=None, emails: Optional[List[str]] = None - ): + self, + localpart: str, + displayname: Optional[str] = None, + emails: Optional[List[str]] = None, + ) -> "defer.Deferred[str]": """Registers a new user with given localpart and optional displayname, emails. Added in Synapse v1.2.0. Args: - localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. - emails (List[str]): Emails to bind to the new user. + localpart: The localpart of the new user. + displayname: The displayname of the new user. + emails: Emails to bind to the new user. Raises: SynapseError if there is an error performing the registration. Check the 'errcode' property for more information on the reason for failure Returns: - defer.Deferred[str]: user_id + user_id """ return defer.ensureDeferred( self._hs.get_registration_handler().register_user( @@ -536,20 +575,25 @@ class ModuleApi: ) ) - def register_device(self, user_id, device_id=None, initial_display_name=None): + def register_device( + self, + user_id: str, + device_id: Optional[str] = None, + initial_display_name: Optional[str] = None, + ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]": """Register a device for a user and generate an access token. Added in Synapse v1.2.0. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. - initial_display_name (str|None): An optional display name for the + initial_display_name: An optional display name for the device. Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ return defer.ensureDeferred( self._hs.get_registration_handler().register_device( @@ -603,7 +647,9 @@ class ModuleApi: ) @defer.inlineCallbacks - def invalidate_access_token(self, access_token): + def invalidate_access_token( + self, access_token: str + ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user Added in Synapse v0.25.0. 
@@ -635,14 +681,20 @@ class ModuleApi: self._auth_handler.delete_access_token(access_token) ) - def run_db_interaction(self, desc, func, *args, **kwargs): + def run_db_interaction( + self, + desc: str, + func: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> "defer.Deferred[T]": """Run a function with a database connection Added in Synapse v0.25.0. Args: - desc (str): description for the transaction, for metrics etc - func (func): function to be run. Passed a database cursor object + desc: description for the transaction, for metrics etc + func: function to be run. Passed a database cursor object as well as *args and **kwargs *args: positional args to be passed to func **kwargs: named args to be passed to func @@ -656,7 +708,7 @@ def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str - ): + ) -> None: """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. @@ -686,7 +738,7 @@ client_redirect_url: str, new_user: bool = False, auth_provider_id: str = "<unknown>", - ): + ) -> None: """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. @@ -925,11 +977,11 @@ self, f: Callable, msec: float, - *args, + *args: object, desc: Optional[str] = None, run_on_all_instances: bool = False, - **kwargs, - ): + **kwargs: object, + ) -> None: """Wraps a function as a background process and calls it repeatedly. NOTE: Will only run on the instance that is configured to run @@ -970,13 +1022,18 @@ f, ) + async def sleep(self, seconds: float) -> None: + """Sleeps for the given number of seconds.""" + + await self._clock.sleep(seconds) + async def send_mail( self, recipient: str, subject: str, html: str, text: str, - ): + ) -> None: """Send an email on behalf of the homeserver. Added in Synapse v1.39.0. @@ -1124,6 +1181,26 @@ return {key: state_events[event_id] for key, event_id in state_ids.items()} + async def defer_to_thread( + self, + f: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> T: + """Runs the given function in a separate thread from Synapse's thread pool. + + Added in Synapse v1.49.0. + + Args: + f: The function to run. + args: The function's arguments. + kwargs: The function's keyword arguments. + + Returns: + The return value of the function once run in a thread.
+ """ + return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index cf5abdfbda..4f13c0418a 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -21,6 +21,8 @@ from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer +from synapse.push.push_types import EmailReason +from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.util.threepids import validate_email if TYPE_CHECKING: @@ -190,7 +192,7 @@ class EmailPusher(Pusher): # we then consider all previously outstanding notifications # to be delivered. - reason = { + reason: EmailReason = { "room_id": push_action["room_id"], "now": self.clock.time_msec(), "received_at": received_at, @@ -275,7 +277,7 @@ class EmailPusher(Pusher): return may_send_at async def sent_notif_update_throttle( - self, room_id: str, notified_push_action: dict + self, room_id: str, notified_push_action: EmailPushAction ) -> None: # We have sent a notification, so update the throttle accordingly. # If the event that triggered the notif happened more than @@ -315,7 +317,9 @@ class EmailPusher(Pusher): self.pusher_id, room_id, self.throttle_params[room_id] ) - async def send_notification(self, push_actions: List[dict], reason: dict) -> None: + async def send_notification( + self, push_actions: List[EmailPushAction], reason: EmailReason + ) -> None: logger.info("Sending notif email for user %r", self.user_id) await self.mailer.send_notification_mail( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index dbf4ad7f97..3fa603ccb7 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -26,6 +26,7 @@ from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException +from synapse.storage.databases.main.event_push_actions import HttpPushAction from . 
import push_rule_evaluator, push_tools @@ -273,7 +274,7 @@ class HttpPusher(Pusher): ) break - async def _process_one(self, push_action: dict) -> bool: + async def _process_one(self, push_action: HttpPushAction) -> bool: if "notify" not in push_action["actions"]: return True diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ce299ba3da..ba4f866487 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -14,7 +14,7 @@ import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 @@ -28,6 +28,14 @@ from synapse.push.presentable_names import ( descriptor_from_member_events, name_from_member_event, ) +from synapse.push.push_types import ( + EmailReason, + MessageVars, + NotifVars, + RoomVars, + TemplateVars, +) +from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute @@ -135,7 +143,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -165,7 +173,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -196,7 +204,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -210,8 +218,8 @@ class Mailer: app_id: str, user_id: str, email_address: str, - push_actions: Iterable[Dict[str, Any]], - reason: Dict[str, Any], + push_actions: Iterable[EmailPushAction], + reason: EmailReason, ) -> None: """ Send email regarding a user's room notifications @@ -230,7 +238,7 @@ class Mailer: [pa["event_id"] for pa in push_actions] ) - notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} + notifs_by_room: Dict[str, List[EmailPushAction]] = {} for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -258,7 +266,7 @@ class Mailer: # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) - rooms: List[Dict[str, Any]] = [] + rooms: List[RoomVars] = [] for r in rooms_in_order: roomvars = await self._get_room_vars( @@ -289,7 +297,7 @@ class Mailer: notifs_by_room, state_by_room, notif_events, reason ) - template_vars = { + template_vars: TemplateVars = { "user_display_name": user_display_name, "unsubscribe_link": self._make_unsubscribe_link( user_id, app_id, email_address @@ -302,10 +310,10 @@ class Mailer: await self.send_email(email_address, summary_text, template_vars) async def send_email( - self, email_address: str, subject: str, extra_template_vars: Dict[str, Any] + self, email_address: str, subject: str, extra_template_vars: TemplateVars ) -> None: """Send an email with the given information and template text""" - template_vars = { + template_vars: TemplateVars = { "app_name": self.app_name, "server_name": self.hs.config.server.server_name, } @@ -327,10 +335,10 @@ class Mailer: self, room_id: str, user_id: str, - notifs: Iterable[Dict[str, Any]], + notifs: Iterable[EmailPushAction], notif_events: Dict[str, EventBase], room_state_ids: StateMap[str], - ) -> Dict[str, Any]: + ) -> RoomVars: """ 
Generate the variables for notifications on a per-room basis. @@ -356,7 +364,7 @@ class Mailer: room_name = await calculate_room_name(self.store, room_state_ids, user_id) - room_vars: Dict[str, Any] = { + room_vars: RoomVars = { "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], @@ -417,11 +425,11 @@ class Mailer: async def _get_notif_vars( self, - notif: Dict[str, Any], + notif: EmailPushAction, user_id: str, notif_event: EventBase, room_state_ids: StateMap[str], - ) -> Dict[str, Any]: + ) -> NotifVars: """ Generate the variables for a single notification. @@ -442,7 +450,7 @@ class Mailer: after_limit=CONTEXT_AFTER, ) - ret = { + ret: NotifVars = { "link": self._make_notif_link(notif), "ts": notif["received_ts"], "messages": [], @@ -461,8 +469,8 @@ class Mailer: return ret async def _get_message_vars( - self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] - ) -> Optional[Dict[str, Any]]: + self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str] + ) -> Optional[MessageVars]: """ Generate the variables for a single event, if possible. @@ -494,7 +502,9 @@ class Mailer: if sender_state_event: sender_name = name_from_member_event(sender_state_event) - sender_avatar_url = sender_state_event.content.get("avatar_url") + sender_avatar_url: Optional[str] = sender_state_event.content.get( + "avatar_url" + ) else: # No state could be found, fallback to the MXID. sender_name = event.sender @@ -504,7 +514,7 @@ class Mailer: # sender_hash % the number of default images to choose from sender_hash = string_ordinal_total(event.sender) - ret = { + ret: MessageVars = { "event_type": event.type, "is_historical": event.event_id != notif["event_id"], "id": event.event_id, @@ -519,6 +529,8 @@ class Mailer: return ret msgtype = event.content.get("msgtype") + if not isinstance(msgtype, str): + msgtype = None ret["msgtype"] = msgtype @@ -533,7 +545,7 @@ class Mailer: return ret def _add_text_message_vars( - self, messagevars: Dict[str, Any], event: EventBase + self, messagevars: MessageVars, event: EventBase ) -> None: """ Potentially add a sanitised message body to the message variables. @@ -543,8 +555,8 @@ class Mailer: event: The event under consideration. """ msgformat = event.content.get("format") - - messagevars["format"] = msgformat + if not isinstance(msgformat, str): + msgformat = None formatted_body = event.content.get("formatted_body") body = event.content.get("body") @@ -555,7 +567,7 @@ class Mailer: messagevars["body_text_html"] = safe_text(body) def _add_image_message_vars( - self, messagevars: Dict[str, Any], event: EventBase + self, messagevars: MessageVars, event: EventBase ) -> None: """ Potentially add an image URL to the message variables. @@ -570,7 +582,7 @@ class Mailer: async def _make_summary_text_single_room( self, room_id: str, - notifs: List[Dict[str, Any]], + notifs: List[EmailPushAction], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], user_id: str, @@ -685,10 +697,10 @@ class Mailer: async def _make_summary_text( self, - notifs_by_room: Dict[str, List[Dict[str, Any]]], + notifs_by_room: Dict[str, List[EmailPushAction]], room_state_ids: Dict[str, StateMap[str]], notif_events: Dict[str, EventBase], - reason: Dict[str, Any], + reason: EmailReason, ) -> str: """ Make a summary text for the email when multiple rooms have notifications. 
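The Dict[str, Any] parameters replaced above are covered by the TypedDicts in the new synapse/push/push_types.py module that follows. Most of them are declared with total=False, which makes every key optional so the mailer can build them up incrementally; a small hedged sketch of that pattern (the values are illustrative, not from this changeset):

from synapse.push.push_types import TemplateVars

# total=False: a partial dict type-checks, and keys can be added later...
template_vars: TemplateVars = {"link": "https://example.com/verify"}
template_vars["app_name"] = "Matrix"
# ...but mypy still rejects a wrongly-typed value for a known key:
# template_vars["app_name"] = 1  # error: expected "str"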
@@ -718,7 +730,7 @@ class Mailer: async def _make_summary_text_from_member_events( self, room_id: str, - notifs: List[Dict[str, Any]], + notifs: List[EmailPushAction], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], ) -> str: @@ -805,7 +817,7 @@ base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def _make_notif_link(self, notif: Dict[str, str]) -> str: + def _make_notif_link(self, notif: EmailPushAction) -> str: """ Generate a link to open an event in the web client. diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py new file mode 100644 index 0000000000..8d16ab62ce --- /dev/null +++ b/synapse/push/push_types.py @@ -0,0 +1,136 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +from typing_extensions import TypedDict + + +class EmailReason(TypedDict, total=False): + """ + Information on the event that triggered the email to be sent + + room_id: the ID of the room the event was sent in + now: timestamp in ms when the email is being sent out + room_name: a human-readable name for the room the event was sent in + received_at: the time in milliseconds at which the event was received + delay_before_mail_ms: the amount of time in milliseconds Synapse always waits + before ever emailing about a notification (to give the user a chance to respond + to other push or notice the window) + last_sent_ts: the time in milliseconds at which a notification was last sent + for an event in this room + throttle_ms: the minimum amount of time in milliseconds allowed between two + notifications for this room + """ + + room_id: str + now: int + room_name: Optional[str] + received_at: int + delay_before_mail_ms: int + last_sent_ts: int + throttle_ms: int + + +class MessageVars(TypedDict, total=False): + """ + Details about a specific message to include in a notification + + event_type: the type of the event + is_historical: a boolean, which is `False` if the message is the one + that triggered the notification, `True` otherwise + id: the ID of the event + ts: the time in milliseconds at which the event was sent + sender_name: the display name for the event's sender + sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's + sender + sender_hash: a hash of the user ID of the sender + msgtype: the type of the message + body_text_html: html representation of the message + body_text_plain: plaintext representation of the message + image_url: mxc url of an image, when "msgtype" is "m.image" + """ + + event_type: str + is_historical: bool + id: str + ts: int + sender_name: str + sender_avatar_url: Optional[str] + sender_hash: int + msgtype: Optional[str] + body_text_html: str + body_text_plain: str + image_url: str + + +class NotifVars(TypedDict): + """ + Details about an event we are about to include in a notification + + link: a `matrix.to` link to the event + ts: the time in milliseconds at which the event was received + messages: a list of
messages containing one message before the event, the + message in the event, and one message after the event. + """ + + link: str + ts: Optional[int] + messages: List[MessageVars] + + +class RoomVars(TypedDict): + """ + Represents a room containing events to include in the email. + + title: a human-readable name for the room + hash: a hash of the ID of the room + invite: a boolean, which is `True` if the room is an invite the user hasn't + accepted yet, `False` otherwise + notifs: a list of events, or an empty list if `invite` is `True`. + link: a `matrix.to` link to the room + avatar_url: URL of the room's avatar + """ + + title: Optional[str] + hash: int + invite: bool + notifs: List[NotifVars] + link: str + avatar_url: Optional[str] + + +class TemplateVars(TypedDict, total=False): + """ + Generic structure for passing to the email sender; it can hold all the fields used in email templates. + + app_name: name of the app/service this homeserver is associated with + server_name: name of our own homeserver + link: a link to include into the email to be sent + user_display_name: the display name for the user receiving the notification + unsubscribe_link: the link users can click to unsubscribe from email notifications + summary_text: a summary of the notification(s). The text used can be customised + by configuring the various settings in the `email.subjects` section of the + configuration file. + rooms: a list of rooms containing events to include in the email + reason: information on the event that triggered the email to be sent + """ + + app_name: str + server_name: str + link: str + user_display_name: str + unsubscribe_link: str + summary_text: str + rooms: List[RoomVars] + reason: EmailReason diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 154e5b7028..7d26954244 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -86,7 +86,7 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.0", + "ijson>=3.1", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 8c1bf9227a..fa132d10b4 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -14,10 +14,18 @@ from typing import List, Optional, Tuple from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import _load_current_id +from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id -class SlavedIdTracker: +class SlavedIdTracker(AbstractStreamIdTracker): + """Tracks the "current" stream ID of a stream with a single writer. + + See `AbstractStreamIdTracker` for more details. + + Note that this class does not work correctly when there are multiple + writers. + """ + def __init__( self, db_conn: LoggingDatabaseConnection, @@ -36,17 +44,7 @@ class SlavedIdTracker: self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: - """ - - Returns: - int - """ return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to - `get_current_token`.
- """ return self.get_current_token() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 4d5f862862..7541e21de9 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - # We assert this for the benefit of mypy - assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) - if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a030e9299e..a390cfcb74 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple, Type import attr @@ -157,7 +157,7 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table - event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( + event_rows = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count ) @@ -191,7 +191,7 @@ class EventsStream(Stream): # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. 
- ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ee4a5e481b..c51a029bf3 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -17,6 +17,7 @@ import logging import platform +from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple import synapse @@ -98,7 +99,7 @@ class VersionServlet(RestServlet): } def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return 200, self.res + return HTTPStatus.OK, self.res class PurgeHistoryRestServlet(RestServlet): @@ -130,7 +131,7 @@ class PurgeHistoryRestServlet(RestServlet): event = await self.store.get_event(event_id) if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None @@ -144,7 +145,9 @@ class PurgeHistoryRestServlet(RestServlet): ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, ) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) @@ -160,7 +163,9 @@ class PurgeHistoryRestServlet(RestServlet): stream_ordering, ) raise SynapseError( - 404, "there is no event to be purged", errcode=Codes.NOT_FOUND + HTTPStatus.NOT_FOUND, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) @@ -173,7 +178,7 @@ class PurgeHistoryRestServlet(RestServlet): ) else: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "must specify purge_up_to_event_id or purge_up_to_ts", errcode=Codes.BAD_JSON, ) @@ -182,7 +187,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - return 200, {"purge_id": purge_id} + return HTTPStatus.OK, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): @@ -201,7 +206,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return 200, purge_status.asdict() + return HTTPStatus.OK, purge_status.asdict() ######################################################################################## diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index d9a2f6ca15..399b205aaf 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from http import HTTPStatus from typing import Iterable, Pattern from synapse.api.auth import Auth @@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """ is_admin = await auth.is_server_admin(user_id) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 80fbf32f17..2e5a6600d3 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError @@ -53,7 +54,7 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( target_user.to_string(), device_id ) - return 200, device + return HTTPStatus.OK, device async def on_DELETE( self, request: SynapseRequest, user_id: str, device_id: str @@ -71,14 +72,14 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") await self.device_handler.delete_device(target_user.to_string(), device_id) - return 200, {} + return HTTPStatus.OK, {} async def on_PUT( self, request: SynapseRequest, user_id: str, device_id: str @@ -87,7 +88,7 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet): await self.device_handler.update_device( target_user.to_string(), device_id, body ) - return 200, {} + return HTTPStatus.OK, {} class DevicesRestServlet(RestServlet): @@ -124,14 +125,14 @@ class DevicesRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices, "total": len(devices)} + return HTTPStatus.OK, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): @@ -155,7 +156,7 @@ class DeleteDevicesRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -167,4 +168,4 @@ class DeleteDevicesRestServlet(RestServlet): await self.device_handler.delete_devices( target_user.to_string(), body["devices"] ) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index bbfcaf723b..5ee8b11110 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -66,21 +67,23 @@ class EventReportsRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The start parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The limit parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) event_reports, total = await self.store.get_event_reports_paginate( @@ -90,7 +93,7 @@ class EventReportsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(event_reports) - return 200, ret + return HTTPStatus.OK, ret class EventReportDetailRestServlet(RestServlet): @@ -127,13 +130,17 @@ class EventReportDetailRestServlet(RestServlet): try: resolved_report_id = int(report_id) except ValueError: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) if resolved_report_id < 0: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index 68a3ba3cb7..a27110388f 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError @@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): - raise SynapseError(400, "Can only delete local groups") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") await self.group_server.delete_group(group_id, requester.user.to_string()) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 30a687d234..9e23e2d8fc 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByUser(RestServlet): @@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet): user_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByID(RestServlet): @@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet): server_name, media_id, requester.user.to_string() ) - return 200, {} + return HTTPStatus.OK, {} class UnquarantineMediaByID(RestServlet): @@ -147,7 +148,7 @@ class UnquarantineMediaByID(RestServlet): # Remove from quarantine this media id await self.store.quarantine_media_by_id(server_name, media_id, None) - return 200, {} + return HTTPStatus.OK, {} class ProtectMediaByID(RestServlet): @@ -170,7 +171,7 @@ class ProtectMediaByID(RestServlet): # Protect this media id await self.store.mark_local_media_as_safe(media_id, safe=True) - return 200, {} + return HTTPStatus.OK, {} class UnprotectMediaByID(RestServlet): @@ -193,7 +194,7 @@ class UnprotectMediaByID(RestServlet): # Unprotect this media id await self.store.mark_local_media_as_safe(media_id, safe=False) - return 200, {} + return HTTPStatus.OK, {} class ListMediaInRoom(RestServlet): @@ -211,11 +212,11 @@ class ListMediaInRoom(RestServlet): requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) - return 200, {"local": local_mxcs, "remote": remote_mxcs} + return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -233,13 +234,13 @@ class PurgeMediaCacheRestServlet(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. 
" + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, @@ -247,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = await self.media_repository.delete_old_remote_media(before_ts) - return 200, ret + return HTTPStatus.OK, ret class DeleteMediaByID(RestServlet): @@ -267,7 +268,7 @@ class DeleteMediaByID(RestServlet): await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") if await self.store.get_local_media(media_id) is None: raise NotFoundError("Unknown media") @@ -277,7 +278,7 @@ class DeleteMediaByID(RestServlet): deleted_media, total = await self.media_repository.delete_local_media_ids( [media_id] ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class DeleteMediaByDateSize(RestServlet): @@ -304,26 +305,26 @@ class DeleteMediaByDateSize(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter size_gt must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" @@ -333,7 +334,7 @@ class DeleteMediaByDateSize(RestServlet): deleted_media, total = await self.media_repository.delete_old_local_media( before_ts, size_gt, keep_profiles ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class UserMediaRestServlet(RestServlet): @@ -369,7 +370,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -380,14 +381,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -425,7 +426,7 @@ class UserMediaRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(media) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -436,7 +437,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local 
users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -447,14 +448,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -492,7 +493,7 @@ class UserMediaRestServlet(RestServlet): ([row["media_id"] for row in media]) ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index aba48f6e7b..891b98c088 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -14,6 +14,7 @@ import logging import string +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -77,7 +78,7 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return 200, {"registration_tokens": token_list} + return HTTPStatus.OK, {"registration_tokens": token_list} class NewRegistrationTokenRestServlet(RestServlet): @@ -123,16 +124,20 @@ class NewRegistrationTokenRestServlet(RestServlet): if "token" in body: token = body["token"] if not isinstance(token, str): - raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "token must be a string", + Codes.INVALID_PARAM, + ) if not (0 < len(token) <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must not be empty and must not be longer than 64 characters", Codes.INVALID_PARAM, ) if not set(token).issubset(self.allowed_chars_set): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must consist only of characters matched by the regex [A-Za-z0-9-_]", Codes.INVALID_PARAM, ) @@ -142,11 +147,13 @@ class NewRegistrationTokenRestServlet(RestServlet): length = body.get("length", 16) if not isinstance(length, int): raise SynapseError( - 400, "length must be an integer", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "length must be an integer", + Codes.INVALID_PARAM, ) if not (0 < length <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "length must be greater than zero and not greater than 64", Codes.INVALID_PARAM, ) @@ -162,7 +169,7 @@ class NewRegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -170,11 +177,15 @@ class NewRegistrationTokenRestServlet(RestServlet): expiry_time = body.get("expiry_time", None) if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( 
- 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) created = await self.store.create_registration_token( @@ -182,7 +193,9 @@ class NewRegistrationTokenRestServlet(RestServlet): ) if not created: raise SynapseError( - 400, f"Token already exists: {token}", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + f"Token already exists: {token}", + Codes.INVALID_PARAM, ) resp = { @@ -192,7 +205,7 @@ class NewRegistrationTokenRestServlet(RestServlet): "completed": 0, "expiry_time": expiry_time, } - return 200, resp + return HTTPStatus.OK, resp class RegistrationTokenRestServlet(RestServlet): @@ -261,7 +274,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Update a registration token.""" @@ -277,7 +290,7 @@ class RegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -287,11 +300,15 @@ class RegistrationTokenRestServlet(RestServlet): expiry_time = body["expiry_time"] if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) new_attributes["expiry_time"] = expiry_time @@ -307,7 +324,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_DELETE( self, request: SynapseRequest, token: str @@ -316,6 +333,6 @@ class RegistrationTokenRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if await self.store.delete_registration_token(token): - return 200, {} + return HTTPStatus.OK, {} raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a89dda1ba5..6bbc5510f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -102,7 +102,9 @@ class RoomRestV2Servlet(RestServlet): ) if not RoomID.is_valid(room_id): - raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) if not await self._store.get_room(room_id): raise NotFoundError("Unknown room id %s" % (room_id,)) @@ -118,7 +120,7 @@ class RoomRestV2Servlet(RestServlet): force_purge=force_purge, ) - return 200, {"delete_id": delete_id} + return HTTPStatus.OK, {"delete_id": delete_id} class DeleteRoomStatusByRoomIdRestServlet(RestServlet): @@ -137,7 +139,9 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): - raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + 
) delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) if delete_ids is None: @@ -153,7 +157,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): **delete.asdict(), } ] - return 200, {"results": cast(JsonDict, response)} + return HTTPStatus.OK, {"results": cast(JsonDict, response)} class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): @@ -175,7 +179,7 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): if delete_status is None: raise NotFoundError("delete id '%s' not found" % delete_id) - return 200, cast(JsonDict, delete_status.asdict()) + return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) class ListRoomRestServlet(RestServlet): @@ -217,7 +221,7 @@ class ListRoomRestServlet(RestServlet): RoomSortOrder.STATE_EVENTS.value, ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -225,7 +229,7 @@ class ListRoomRestServlet(RestServlet): search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "search_term cannot be an empty string", errcode=Codes.INVALID_PARAM, ) @@ -233,7 +237,9 @@ class ListRoomRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) reverse_order = True if direction == "b" else False @@ -265,7 +271,7 @@ class ListRoomRestServlet(RestServlet): else: response["prev_batch"] = 0 - return 200, response + return HTTPStatus.OK, response class RoomRestServlet(RestServlet): @@ -310,7 +316,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -386,7 +392,7 @@ class RoomRestServlet(RestServlet): # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 # for some discussion on why this is necessary. Either way, # `ret` is an opaque dictionary blob as far as the rest of the app cares. 
- return 200, cast(JsonDict, ret) + return HTTPStatus.OK, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -413,7 +419,7 @@ class RoomMembersRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret = {"members": members, "total": len(members)} - return 200, ret + return HTTPStatus.OK, ret class RoomStateRestServlet(RestServlet): @@ -452,7 +458,7 @@ class RoomStateRestServlet(RestServlet): ) ret = {"state": room_state} - return 200, ret + return HTTPStatus.OK, ret class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): @@ -481,7 +487,10 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): target_user = UserID.from_string(content["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -527,7 +536,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ratelimit=False, ) - return 200, {"room_id": room_id} + return HTTPStatus.OK, {"room_id": room_id} class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): @@ -568,7 +577,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Figure out which local users currently have power in the room, if any. room_state = await self.state_handler.get_current_state(room_id) if not room_state: - raise SynapseError(400, "Server not in room") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") create_event = room_state[(EventTypes.Create, "")] power_levels = room_state.get((EventTypes.PowerLevels, "")) @@ -582,7 +591,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_users.sort(key=lambda user: user_power[user]) if not admin_users: - raise SynapseError(400, "No local admin user in room") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "No local admin user in room" + ) admin_user_id = None @@ -599,7 +610,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if not admin_user_id: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -610,7 +621,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -639,7 +650,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): except AuthError: # The admin user we found turned out not to have enough power. raise SynapseError( - 400, "No local admin user in room with power to update power levels." 
+ HTTPStatus.BAD_REQUEST, + "No local admin user in room with power to update power levels.", ) # Now we check if the user we're granting admin rights to is already in @@ -653,7 +665,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): ) if is_joined: - return 200, {} + return HTTPStatus.OK, {} join_rules = room_state.get((EventTypes.JoinRules, "")) is_public = False @@ -661,7 +673,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC if is_public: - return 200, {} + return HTTPStatus.OK, {} await self.room_member_handler.update_membership( fake_requester, @@ -670,7 +682,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): action=Membership.INVITE, ) - return 200, {} + return HTTPStatus.OK, {} class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): @@ -702,7 +714,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, {"deleted": deleted_count} + return HTTPStatus.OK, {"deleted": deleted_count} async def on_GET( self, request: SynapseRequest, room_identifier: str @@ -713,7 +725,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, {"count": len(extremities), "results": extremities} + return HTTPStatus.OK, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): @@ -762,7 +774,9 @@ class RoomEventContextServlet(RestServlet): ) if not results: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError( + HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND + ) time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( @@ -781,7 +795,7 @@ class RoomEventContextServlet(RestServlet): bundle_relations=False, ) - return 200, results + return HTTPStatus.OK, results class BlockRoomRestServlet(RestServlet): diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 19f84f33f2..b295fb078b 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes @@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet): # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). 
if not self.server_notices_manager.is_enabled(): - raise SynapseError(400, "Server notices are not enabled on this server") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server" + ) target_user = UserID.from_string(body["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Server notices can only be sent to local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet): txn_id=txn_id, ) - return 200, {"event_id": event.event_id} + return HTTPStatus.OK, {"event_id": event.event_id} def on_PUT( self, request: SynapseRequest, txn_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 948de94ccd..ca41fd45f2 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError @@ -53,7 +54,7 @@ class UserMediaStatisticsRestServlet(RestServlet): UserSortOrder.DISPLAYNAME.value, ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -61,7 +62,7 @@ class UserMediaStatisticsRestServlet(RestServlet): start = parse_integer(request, "from", default=0) if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -69,7 +70,7 @@ class UserMediaStatisticsRestServlet(RestServlet): limit = parse_integer(request, "limit", default=100) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -77,7 +78,7 @@ class UserMediaStatisticsRestServlet(RestServlet): from_ts = parse_integer(request, "from_ts", default=0) if from_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -86,13 +87,13 @@ class UserMediaStatisticsRestServlet(RestServlet): if until_ts is not None: if until_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if until_ts <= from_ts: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be greater than from_ts.", errcode=Codes.INVALID_PARAM, ) @@ -100,7 +101,7 @@ class UserMediaStatisticsRestServlet(RestServlet): search_term = parse_string(request, "search_term") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter search_term cannot be an empty string.", errcode=Codes.INVALID_PARAM, ) @@ -108,7 +109,9 @@ class UserMediaStatisticsRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) users_media, total = await self.store.get_users_media_usage_paginate( @@ -118,4 +121,4 @@ class UserMediaStatisticsRestServlet(RestServlet): if (start + 
limit) < total: ret["next_token"] = start + len(users_media) - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index ccd9a2a175..2a60b602b1 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -79,14 +79,14 @@ class UsersRestServletV2(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -122,7 +122,7 @@ class UsersRestServletV2(RestServlet): if (start + limit) < total: ret["next_token"] = str(start + len(users)) - return 200, ret + return HTTPStatus.OK, ret class UserRestServletV2(RestServlet): @@ -172,14 +172,14 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) if not ret: raise NotFoundError("User not found") - return 200, ret + return HTTPStatus.OK, ret async def on_PUT( self, request: SynapseRequest, user_id: str @@ -191,7 +191,10 @@ class UserRestServletV2(RestServlet): body = parse_json_object_from_request(request) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() @@ -210,7 +213,7 @@ class UserRestServletV2(RestServlet): user_type = body.get("user_type", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") set_admin_to = body.get("admin", False) if not isinstance(set_admin_to, bool): @@ -223,11 +226,13 @@ class UserRestServletV2(RestServlet): password = body.get("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): - raise SynapseError(400, "'deactivated' parameter is not of type boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" + ) # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: @@ -282,7 +287,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -293,7 +300,9 @@ class UserRestServletV2(RestServlet): if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "You may not demote yourself." 
+ ) await self.store.set_server_admin(target_user, set_admin_to) @@ -319,7 +328,8 @@ class UserRestServletV2(RestServlet): and self.auth_handler.can_change_password() ): raise SynapseError( - 400, "Must provide a password to re-activate an account." + HTTPStatus.BAD_REQUEST, + "Must provide a password to re-activate an account.", ) await self.deactivate_account_handler.activate_account( @@ -332,7 +342,7 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) assert user is not None - return 200, user + return HTTPStatus.OK, user else: # create user displayname = body.get("displayname", None) @@ -381,7 +391,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -429,51 +441,61 @@ class UserRegisterServlet(RestServlet): nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return 200, {"nonce": nonce} + return HTTPStatus.OK, {"nonce": nonce} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" + ) body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "nonce must be specified", + errcode=Codes.BAD_JSON, + ) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError(400, "unrecognised nonce") + raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "username must be specified", + errcode=Codes.BAD_JSON, ) else: if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") username = body["username"].encode("utf-8") if b"\x00" in username: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "password must be specified", + errcode=Codes.BAD_JSON, ) else: password = body["password"] if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_bytes = password.encode("utf-8") if b"\x00" in password_bytes: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -482,10 +504,12 @@ class UserRegisterServlet(RestServlet): displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") if "mac" not in body: - 
raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON + ) got_mac = body["mac"] @@ -507,7 +531,7 @@ class UserRegisterServlet(RestServlet): want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(403, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.register import RegisterRestServlet @@ -524,7 +548,7 @@ class UserRegisterServlet(RestServlet): ) result = await register._create_registration_details(user_id, body) - return 200, result + return HTTPStatus.OK, result class WhoisRestServlet(RestServlet): @@ -552,11 +576,11 @@ class WhoisRestServlet(RestServlet): await assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only whois a local user") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) - return 200, ret + return HTTPStatus.OK, ret class DeactivateAccountRestServlet(RestServlet): @@ -575,7 +599,9 @@ class DeactivateAccountRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine(UserID.from_string(target_user_id)): - raise SynapseError(400, "Can only deactivate local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Can only deactivate local users" + ) if not await self.store.get_user_by_id(target_user_id): raise NotFoundError("User not found") @@ -597,7 +623,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return 200, {"id_server_unbind_result": id_server_unbind_result} + return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} class AccountValidityRenewServlet(RestServlet): @@ -620,7 +646,7 @@ class AccountValidityRenewServlet(RestServlet): if "user_id" not in body: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Missing property 'user_id' in the request body", ) @@ -631,7 +657,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - return 200, res + return HTTPStatus.OK, res class ResetPasswordRestServlet(RestServlet): @@ -678,7 +704,7 @@ class ResetPasswordRestServlet(RestServlet): await self._set_password_handler.set_password( target_user_id, new_password_hash, logout_devices, requester ) - return 200, {} + return HTTPStatus.OK, {} class SearchUsersRestServlet(RestServlet): @@ -712,16 +738,16 @@ class SearchUsersRestServlet(RestServlet): # To allow all users to get the users list # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = await self.store.search_users(term) - return 200, ret + return HTTPStatus.OK, ret class UserAdminServlet(RestServlet): @@ -765,11 +791,14 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + raise SynapseError( + 
HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) is_admin = await self.store.is_server_admin(target_user) - return 200, {"admin": is_admin} + return HTTPStatus.OK, {"admin": is_admin} async def on_PUT( self, request: SynapseRequest, user_id: str @@ -785,16 +814,19 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) set_admin_to = bool(body["admin"]) if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) - return 200, {} + return HTTPStatus.OK, {} class UserMembershipRestServlet(RestServlet): @@ -816,7 +848,7 @@ class UserMembershipRestServlet(RestServlet): room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return 200, ret + return HTTPStatus.OK, ret class PushersRestServlet(RestServlet): @@ -845,7 +877,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -854,7 +886,10 @@ class PushersRestServlet(RestServlet): filtered_pushers = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + return HTTPStatus.OK, { + "pushers": filtered_pushers, + "total": len(filtered_pushers), + } class UserTokenRestServlet(RestServlet): @@ -887,16 +922,22 @@ class UserTokenRestServlet(RestServlet): auth_user = requester.user if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be logged in as") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" + ) body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") if valid_until_ms and not isinstance(valid_until_ms, int): - raise SynapseError(400, "'valid_until_ms' parameter must be an int") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" + ) if auth_user.to_string() == user_id: - raise SynapseError(400, "Cannot use admin API to login as self") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self" + ) token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), @@ -905,7 +946,7 @@ class UserTokenRestServlet(RestServlet): puppets_user_id=user_id, ) - return 200, {"access_token": token} + return HTTPStatus.OK, {"access_token": token} class ShadowBanRestServlet(RestServlet): @@ -947,11 +988,13 @@ class ShadowBanRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), True) - return 200, {} + return HTTPStatus.OK, {} async def on_DELETE( self, request: SynapseRequest, 
user_id: str @@ -959,11 +1002,13 @@ class ShadowBanRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), False) - return 200, {} + return HTTPStatus.OK, {} class RateLimitRestServlet(RestServlet): @@ -995,7 +1040,7 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1016,7 +1061,7 @@ class RateLimitRestServlet(RestServlet): else: ret = {} - return 200, ret + return HTTPStatus.OK, ret async def on_POST( self, request: SynapseRequest, user_id: str @@ -1024,7 +1069,9 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1036,14 +1083,14 @@ class RateLimitRestServlet(RestServlet): if not isinstance(messages_per_second, int) or messages_per_second < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) if not isinstance(burst_count, int) or burst_count < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), errcode=Codes.INVALID_PARAM, ) @@ -1059,7 +1106,7 @@ class RateLimitRestServlet(RestServlet): "burst_count": ratelimit.burst_count, } - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -1067,11 +1114,13 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") await self.store.delete_ratelimit_for_user(user_id) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 67e03dca04..09f378f919 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -14,7 +14,17 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) from typing_extensions import TypedDict @@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -155,11 +164,14 @@ class LoginRestServlet(RestServlet): login_submission = parse_json_object_from_request(request) if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per - # MSC2918 - 
should_issue_refresh_token = parse_boolean( - request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + # Check if this login should also issue a refresh token, as per MSC2918 + should_issue_refresh_token = login_submission.get( + "org.matrix.msc2918.refresh_token", False ) + if not isinstance(should_issue_refresh_token, bool): + raise SynapseError( + 400, "`org.matrix.msc2918.refresh_token` should be true or false." + ) else: should_issue_refresh_token = False @@ -458,6 +470,7 @@ class RefreshTokenServlet(RestServlet): self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -467,22 +480,33 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = ( - self._clock.time_msec() + self.refreshable_access_token_lifetime - ) - access_token, refresh_token = await self._auth_handler.refresh_token( - token, valid_until_ms - ) - expires_in_ms = valid_until_ms - self._clock.time_msec() - return ( - 200, - { - "access_token": access_token, - "refresh_token": refresh_token, - "expires_in_ms": expires_in_ms, - }, + now = self._clock.time_msec() + access_valid_until_ms = None + if self.refreshable_access_token_lifetime is not None: + access_valid_until_ms = now + self.refreshable_access_token_lifetime + refresh_valid_until_ms = None + if self.refresh_token_lifetime is not None: + refresh_valid_until_ms = now + self.refresh_token_lifetime + + ( + access_token, + refresh_token, + actual_access_token_expiry, + ) = await self._auth_handler.refresh_token( + token, access_valid_until_ms, refresh_valid_until_ms ) + response: Dict[str, Union[str, int]] = { + "access_token": access_token, + "refresh_token": refresh_token, + } + + # expires_in_ms is only present if the token expires + if actual_access_token_expiry is not None: + response["expires_in_ms"] = actual_access_token_expiry - now + + return 200, response + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index d2b11e39d9..11fd6cd24d 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_json_object_from_request, parse_string, ) @@ -449,9 +448,13 @@ class RegisterRestServlet(RestServlet): if self._msc2918_enabled: # Check if this registration should also issue a refresh token, as # per MSC2918 - should_issue_refresh_token = parse_boolean( - request, name="org.matrix.msc2918.refresh_token", default=False + should_issue_refresh_token = body.get( + "org.matrix.msc2918.refresh_token", False ) + if not isinstance(should_issue_refresh_token, bool): + raise SynapseError( + 400, "`org.matrix.msc2918.refresh_token` should be true or false." 
+ ) else: should_issue_refresh_token = False diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 955d4e8641..73d0f7c950 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): - PATTERNS = ( + PATTERNS = [ re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" + "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" "/rooms/(?P<room_id>[^/]*)/hierarchy$" ), - ) + ] def __init__(self, hs: "HomeServer"): super().__init__() @@ -1168,7 +1168,7 @@ class RoomHierarchyRestServlet(RestServlet): ) return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), + requester, room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), max_depth=max_depth, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1605411b00..446204dbe5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -764,7 +764,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Iterable[str], allow_rejected: bool = False + self, event_ids: Collection[str], allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 6edadea550..499a328201 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,6 +17,7 @@ import logging from typing import ( Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -44,7 +45,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0623da9aa1..3056e64ff5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import StreamToken, get_domain_from_id +from synapse.types import get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta): self, stream_name: str, instance_name: str, - token: StreamToken, + token: int, rows: Iterable[Any], ) -> None: pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bc8364400d..d64910aded 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + Awaitable, + Callable, + Dict, + Iterable, + Optional, +) + +import attr from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.types import Connection from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import Clock, json_encoder from . 
import engines

@@ -28,6 +38,45 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)


+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks how long a background update is taking to update its items"""

@@ -84,20 +133,22 @@ class BackgroundUpdater:
     MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100

     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database

+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None

+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False

         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True

+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered an `on_update` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            The target duration in milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""

@@ -135,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:

         return not update_exists

-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update

         Returns once some amount of work is done.

         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -252,7 +377,19 @@ class BackgroundUpdater:

         self._current_background_update = upd["update_name"]

-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+ assert self._current_background_update is not None + update_info = self._background_update_handlers[self._current_background_update] + + async with self._get_context_manager_for_update( + sleep=sleep, + update_name=self._current_background_update, + database_name=self._database_name, + oneshot=update_info.oneshot, + ) as desired_duration_ms: + await self._do_background_update(desired_duration_ms) + return False async def _do_background_update(self, desired_duration_ms: float) -> int: @@ -260,7 +397,7 @@ class BackgroundUpdater: update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) - update_handler = self._background_update_handlers[update_name] + update_handler = self._background_update_handlers[update_name].callback performance = self._background_update_performance.get(update_name) @@ -273,9 +410,14 @@ class BackgroundUpdater: if items_per_ms is not None: batch_size = int(desired_duration_ms * items_per_ms) # Clamp the batch size so that we always make progress - batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE) + batch_size = max( + batch_size, + await self._min_batch_size(update_name, self._database_name), + ) else: - batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE + batch_size = await self._default_batch_size( + update_name, self._database_name + ) progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", @@ -294,6 +436,8 @@ class BackgroundUpdater: duration_ms = time_stop - time_start + performance.update(items_updated, duration_ms) + logger.info( "Running background update %r. Processed %r items in %rms." " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", @@ -306,8 +450,6 @@ class BackgroundUpdater: batch_size, ) - performance.update(items_updated, duration_ms) - return len(self._background_update_performance) def register_background_update_handler( @@ -331,7 +473,9 @@ class BackgroundUpdater: update_name: The name of the update that this code handles. update_handler: The function that does the update. """ - self._background_update_handlers[update_name] = update_handler + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + update_handler + ) def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. @@ -453,7 +597,9 @@ class BackgroundUpdater: await self._end_background_update(update_name) return 1 - self.register_background_update_handler(update_name, updater) + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. 
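Taken together, the hooks above let a module take over scheduling of background updates entirely. The following is a minimal sketch of such a module; it assumes the module API exposes the registration as `register_background_update_controller_callbacks` (mirroring `register_update_controller_callbacks` above) and passes the module an `api` handle at construction. The duration and batch-size numbers are arbitrary illustrative values, not recommendations.

    from contextlib import asynccontextmanager
    from typing import Any, AsyncIterator


    class RunFlatOutController:
        """Illustrative controller module: a 500ms target per iteration, no
        sleeping between iterations, and larger batches. Callback shapes follow
        the ON_UPDATE_CALLBACK / DEFAULT_BATCH_SIZE_CALLBACK /
        MIN_BATCH_SIZE_CALLBACK aliases above; the module-API method name is an
        assumption."""

        def __init__(self, config: dict, api: Any) -> None:
            api.register_background_update_controller_callbacks(
                on_update=self.on_update,
                default_batch_size=self.default_batch_size,
                min_batch_size=self.min_batch_size,
            )

        @asynccontextmanager
        async def on_update(
            self, update_name: str, database_name: str, oneshot: bool
        ) -> AsyncIterator[int]:
            # Yield the target duration (ms) for one iteration; the value is
            # ignored when `oneshot` is True (e.g. index creation).
            yield 500

        async def default_batch_size(self, update_name: str, database_name: str) -> int:
            # First-iteration batch size, used before any performance data exists.
            return 500

        async def min_batch_size(self, update_name: str, database_name: str) -> int:
            # Lower bound applied when clamping the batch size; must be > 0.
            return 10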
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d957e770dc..3efdd0c920 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr +from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json @@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] +class BasePushAction(TypedDict): + event_id: str + actions: List[Union[dict, str]] + + +class HttpPushAction(BasePushAction): + room_id: str + stream_ordering: int + + +class EmailPushAction(HttpPushAction): + received_ts: Optional[int] + + def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[dict]: + ) -> List[HttpPushAction]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. @@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[dict]: + ) -> List[EmailPushAction]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 06832221ad..4171b904eb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict, namedtuple +from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, @@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -64,9 +65,6 @@ event_counter = Counter( ) -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -108,23 +106,30 @@ class PersistEventsStore: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. 
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen - # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" + # Since we have been configured to write, we ought to have id generators, + # rather than id trackers. + assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) + assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) + + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], + *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - backfilled: bool = False, + use_negative_stream_ordering: bool = False, + inhibit_local_membership_updates: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -137,7 +142,14 @@ class PersistEventsStore: room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - backfilled + use_negative_stream_ordering: Whether to start stream_ordering on + the negative side and decrement. This should be set as True + for backfilled events because backfilled events get a negative + stream ordering so they don't come down incremental `/sync`. + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. Returns: Resolves when the events have been persisted @@ -159,7 +171,7 @@ class PersistEventsStore: # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. - if backfilled: + if use_negative_stream_ordering: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -176,13 +188,13 @@ class PersistEventsStore: "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if not backfilled: + if stream < 0: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. 
synapse.metrics.event_persisted_position.set( @@ -316,8 +328,9 @@ class PersistEventsStore: def _persist_events_txn( self, txn: LoggingTransaction, + *, events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool, + inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -330,7 +343,10 @@ class PersistEventsStore: Args: txn events_and_contexts: events to persist - backfilled: True if the events were backfilled + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -363,9 +379,7 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn( - txn, events_and_contexts=events_and_contexts, backfilled=backfilled - ) + self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. @@ -398,7 +412,7 @@ class PersistEventsStore: txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, ) # We call this last as it assumes we've inserted the events into @@ -1200,7 +1214,6 @@ class PersistEventsStore: self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool, ): """Update min_depth for each room @@ -1208,13 +1221,18 @@ class PersistEventsStore: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting - backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - if not backfilled: + # Then update the `stream_ordering` position to mark the latest + # event as the front of the room. This should not be done for + # backfilled events because backfilled events have negative + # stream_ordering and happened in the past so we know that we don't + # need to update the stream_ordering tip/front for the room. + assert event.internal_metadata.stream_ordering is not None + if event.internal_metadata.stream_ordering >= 0: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1427,7 +1445,12 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, txn, events_and_contexts, all_events_and_contexts, backfilled + self, + txn, + *, + events_and_contexts, + all_events_and_contexts, + inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1439,7 +1462,10 @@ class PersistEventsStore: events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - backfilled (bool): True if the events were backfilled + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. 
This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. """ # Insert all the push actions into the event_push_actions table. @@ -1513,7 +1539,7 @@ class PersistEventsStore: for event, _ in events_and_contexts if event.type == EventTypes.Member ], - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, ) # Insert event_reference_hashes table. @@ -1553,11 +1579,13 @@ class PersistEventsStore: for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) + self.store._get_event_cache.set( + (cache_entry.event.event_id,), cache_entry + ) txn.call_after(prefill) @@ -1638,8 +1666,19 @@ class PersistEventsStore: txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database.""" + def _store_room_members_txn( + self, txn, events, *, inhibit_local_membership_updates: bool = False + ): + """ + Store a room member in the database. + Args: + txn: The transaction to use. + events: List of events to store. + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. + """ def non_null_str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) and "\u0000" not in val else None @@ -1682,7 +1721,7 @@ class PersistEventsStore: # band membership", like a remote invite or a rejection of a remote invite. 
            if (
                self.is_mine_id(event.state_key)
-                and not backfilled
+                and not inhibit_local_membership_updates
                and event.internal_metadata.is_outlier()
                and event.internal_metadata.is_out_of_band_membership()
            ):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@ import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )

@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)

-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thin air to make initial sync run faster
 # on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(


 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]

@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool

@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres then we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
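The new `AbstractStreamIdTracker` annotations encode the worker/writer split: every process can read the stream position, but only a process holding the `AbstractStreamIdGenerator` subtype may allocate new ids. A condensed sketch of the pattern, with `is_writer` and the table/column names purely illustrative:

    from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
    from synapse.storage.util.id_generators import (
        AbstractStreamIdGenerator,
        AbstractStreamIdTracker,
        StreamIdGenerator,
    )


    def make_stream_id_gen(db_conn, is_writer: bool) -> AbstractStreamIdTracker:
        # Writers get a generator (a tracker that can also allocate ids);
        # read-only workers get a plain tracker, advanced via replication.
        if is_writer:
            return StreamIdGenerator(db_conn, "events", "stream_ordering")
        return SlavedIdTracker(db_conn, "events", "stream_ordering")


    async def persist_one(stream_id_gen: AbstractStreamIdTracker) -> None:
        # Any process can ask where the stream currently stands...
        position = stream_id_gen.get_current_token()
        # ...but only the writer subtype may hand out new ids.
        assert isinstance(stream_id_gen, AbstractStreamIdGenerator)
        async with stream_id_gen.get_next() as new_id:
            assert new_id > position  # write the new row at new_id here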
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache = LruCache( + self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore): # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, _EventCacheEntry]] + str, ObservableDeferred[Dict[str, EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] + self._event_fetch_list: List[ + Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] + ] = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn): + def get_chain_id_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore): id_column="chain_id", ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: Literal[False] = False, - check_room_id: Optional[str] = None, + get_prev_content: bool = ..., + allow_rejected: bool = ..., + allow_none: Literal[False] = ..., + check_room_id: Optional[str] = ..., ) -> EventBase: ... @@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: Literal[True] = False, - check_room_id: Optional[str] = None, + get_prev_content: bool = ..., + allow_rejected: bool = ..., + allow_none: Literal[True] = ..., + check_room_id: Optional[str] = ..., ) -> Optional[EventBase]: ... @@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, - event_ids: Iterable[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore): # same dict into itself N times). 
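The `get_event` overloads above switch their defaults to `...`, which is the convention for `@overload` stubs: the real defaults belong only on the implementation. A toy version of the same `Literal`-typed `allow_none` trick, independent of Synapse:

    from typing import Dict, Optional, overload

    from typing_extensions import Literal

    _STORE: Dict[str, str] = {"$ev1": "event one"}


    @overload
    def get_thing(key: str, allow_none: Literal[False] = ...) -> str:
        ...


    @overload
    def get_thing(key: str, allow_none: Literal[True] = ...) -> Optional[str]:
        ...


    def get_thing(key: str, allow_none: bool = False) -> Optional[str]:
        # mypy narrows the return type at each call site from the overloads:
        # get_thing(k) is a str, get_thing(k, True) is an Optional[str].
        value = _STORE.get(key)
        if value is None and not allow_none:
            raise KeyError(key)
        return value


    event: str = get_thing("$ev1")
    maybe_event: Optional[str] = get_thing("$missing", True)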
already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, _EventCacheEntry]] + ObservableDeferred[Dict[str, EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore): # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ - Dict[str, _EventCacheEntry] - ] = ObservableDeferred(defer.Deferred()) + Dict[str, EventCacheEntry] + ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id): + def _invalidate_get_event_cache(self, event_id: str) -> None: self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _do_fetch(self, conn: Connection) -> None: + def _maybe_start_fetch_thread(self) -> None: + """Starts an event fetch thread if we are not yet at the maximum number.""" + with self._event_fetch_lock: + if ( + self._event_fetch_list + and self._event_fetch_ongoing < EVENT_QUEUE_THREADS + ): + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + # `_event_fetch_ongoing` is decremented in `_fetch_thread`. + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process("fetch_events", self._fetch_thread) + + async def _fetch_thread(self) -> None: + """Services requests for events from `_event_fetch_list`.""" + exc = None + try: + await self.db_pool.runWithConnection(self._fetch_loop) + except BaseException as e: + exc = e + raise + finally: + should_restart = False + event_fetches_to_fail = [] + with self._event_fetch_lock: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + + # There may still be work remaining in `_event_fetch_list` if we + # failed, or it was added in between us deciding to exit and + # decrementing `_event_fetch_ongoing`. + if self._event_fetch_list: + if exc is None: + # We decided to exit, but then some more work was added + # before `_event_fetch_ongoing` was decremented. + # If a new event fetch thread was not started, we should + # restart ourselves since the remaining event fetch threads + # may take a while to get around to the new work. + # + # Unfortunately it is not possible to tell whether a new + # event fetch thread was started, so we restart + # unconditionally. If we are unlucky, we will end up with + # an idle fetch thread, but it will time out after + # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds + # in any case. + # + # Note that multiple fetch threads may run down this path at + # the same time. + should_restart = True + elif isinstance(exc, Exception): + if self._event_fetch_ongoing == 0: + # We were the last remaining fetcher and failed. + # Fail any outstanding fetches since no one else will + # handle them. + event_fetches_to_fail = self._event_fetch_list + self._event_fetch_list = [] + else: + # We weren't the last remaining fetcher, so another + # fetcher will pick up the work. 
This will either happen + # after their existing work, however long that takes, + # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if + # they are idle. + pass + else: + # The exception is a `SystemExit`, `KeyboardInterrupt` or + # `GeneratorExit`. Don't try to do anything clever here. + pass + + if should_restart: + # We exited cleanly but noticed more work. + self._maybe_start_fetch_thread() + + if event_fetches_to_fail: + # We were the last remaining fetcher and failed. + # Fail any outstanding fetches since no one else will handle them. + assert exc is not None + with PreserveLoggingContext(): + for _, deferred in event_fetches_to_fail: + deferred.errback(exc) + + def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - try: - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - break - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + # There are no requests waiting. If we haven't yet reached the + # maximum iteration limit, wait for some more requests to turn up. + # Otherwise, bail out. + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + return + + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) - finally: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + self._fetch_event_list(conn, event_list) def _fetch_event_list( - self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] + self, + conn: LoggingDatabaseConnection, + event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore): ) # We only want to resolve deferreds from the main thread - def fire(): + def fire() -> None: for _, d in event_list: d.callback(row_dict) @@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs, exc): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(exc) + def fire_errback(exc: Exception) -> None: + for _, d in event_list: + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, e) + self.hs.get_reactor().callFromThread(fire_errback, e) async def _get_events_from_db( - self, event_ids: Iterable[str] - ) -> Dict[str, _EventCacheEntry]: + self, event_ids: Collection[str] + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. 
""" - fetched_events = {} + fetched_event_ids: Set[str] = set() + fetched_events: Dict[str, _EventRow] = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids = set() + redaction_ids: Set[str] = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_events[event_id] = row + fetched_event_ids.add(event_id) if row: + fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_events.keys()) + events_to_fetch = redaction_ids.difference(fetched_event_ids) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map = {} + event_map: Dict[str, EventBase] = {} for event_id, row in fetched_events.items(): - if not row: - continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id + room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore): # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map = {} + result_map: Dict[str, EventCacheEntry] = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = _EventCacheEntry( + cache_entry = EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore): that weren't requested. """ - events_d = defer.Deferred() + events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) - self._event_fetch_lock.notify() - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process( - "fetch_events", self.db_pool.runWithConnection, self._do_fetch - ) + self._maybe_start_fetch_thread() logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids: events we are looking for Returns: - set[str]: The events we have already seen. + The set of events we have already seen. 
""" res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore): } results = {x: True for x in cache_results} - def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): + def have_seen_events_txn( + txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] + ) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str): + async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn(self, txn, room_id): + def _get_current_state_event_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> int: """ See get_current_state_event_counts. """ @@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - async def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id: str) -> Dict[str, float]: """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore): more resources. Args: - room_id (str) + room_id: The room ID to query. Returns: - dict[str:int] of complexity version to complexity. + dict[str:float] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self): + def get_current_events_token(self) -> int: """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns new events, for the Events replication stream Args: @@ -1295,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_all_new_forward_event_rows(txn): + def get_all_new_forward_event_rows( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" @@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1332,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. 
""" - def get_ex_outlier_stream_rows_txn(txn): + def get_ex_outlier_stream_rows_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" @@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1386,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows(txn): + def get_all_new_backfill_event_rows( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" @@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates = [(row[0], row[1:]) for row in txn] + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] + row: Tuple[int, str, str, str, str, str, str] + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) limited = False if len(new_event_updates) == limit: @@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - new_event_updates.extend((row[0], row[1:]) for row in txn) + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple], int, bool]: + ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore): * `limited` is whether there are more updates to fetch. 
""" - def get_all_updated_current_state_deltas_txn(txn): + def get_all_updated_current_state_deltas_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) - def get_deltas_for_stream_id_txn(txn, stream_id): + def get_deltas_for_stream_id_txn( + txn: LoggingTransaction, stream_id: int + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple] = await self.db_pool.runInteraction( + rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - async def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1: str, event_id2: str) -> bool: """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id): + async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore): None otherwise. """ - def get_next_event_to_expire_txn(txn): + def get_next_event_to_expire_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, int]]: txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return txn.fetchone() + return cast(Optional[Tuple[str, int]], txn.fetchone()) return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore): return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self): + async def _cleanup_old_transaction_ids(self) -> None: """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn): + def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fa782023d4..3b63267395 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -82,9 +85,9 @@ class PushRulesWorkerStore( super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: Union[ - StreamIdGenerator, SlavedIdTracker - ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0e8c168667..e1ddf06916 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -106,6 +106,15 @@ class RefreshTokenLookupResult: has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" + expiry_ts: Optional[int] + """The time at which the refresh token expires and can not be used. + If None, the refresh token doesn't expire.""" + + ultimate_session_expiry_ts: Optional[int] + """The time at which the session comes to an end and can no longer be + refreshed. + If None, the session can be refreshed indefinitely.""" + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( @@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): rt.user_id, rt.device_id, rt.next_token_id, - (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, - at.used has_next_access_token_been_used + (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, + at.used AS has_next_access_token_been_used, + rt.expiry_ts, + rt.ultimate_session_expiry_ts FROM refresh_tokens rt LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id @@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): has_next_refresh_token_been_refreshed=row[4], # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), + expiry_ts=row[6], + ultimate_session_expiry_ts=row[7], ) return await self.db_pool.runInteraction( @@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: str, token: str, device_id: Optional[str], + expiry_ts: Optional[int], + ultimate_session_expiry_ts: Optional[int], ) -> int: """Adds a refresh token for the given user. @@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: The user ID. token: The new access token to add. 
device_id: ID of the device to associate with the refresh token. + expiry_ts (milliseconds since the epoch): Time after which the + refresh token cannot be used. + If None, the refresh token never expires until it has been used. + ultimate_session_expiry_ts (milliseconds since the epoch): + Time at which the session will end and can not be extended any + further. + If None, the session can be refreshed indefinitely. Raises: StoreError if there was a problem adding this. Returns: @@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "token": token, "next_token_id": None, + "expiry_ts": expiry_ts, + "ultimate_session_expiry_ts": ultimate_session_expiry_ts, }, desc="add_refresh_token_to_user", ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 402f134d89..428d66a617 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -583,7 +583,8 @@ class EventsPersistenceStorage: current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, - backfilled=backfilled, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, ) await self._handle_potentially_left_users(potentially_left_users) diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql new file mode 100644 index 0000000000..bdc491c817 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql @@ -0,0 +1,28 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE refresh_tokens + -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens. + -- They may not be used after they have expired. + -- If null, then the refresh token's lifetime is unlimited. + ADD COLUMN expiry_ts BIGINT DEFAULT NULL; + +ALTER TABLE refresh_tokens + -- We also add an ultimate session expiry time (in milliseconds since the Epoch). + -- No matter how much the access and refresh tokens are refreshed, they cannot + -- be extended past this time. + -- If null, then the session length is unlimited. + ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL; diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ac56bc9a05..4ff3013908 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -89,31 +89,77 @@ def _load_current_id( return (max if step > 0 else min)(current_id, step) -class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_next(self) -> AsyncContextManager[int]: - raise NotImplementedError() +class AbstractStreamIdTracker(metaclass=abc.ABCMeta): + """Tracks the "current" stream ID of a stream that may have multiple writers. 
+ + Stream IDs are monotonically increasing or decreasing integers representing write + transactions. The "current" stream ID is the stream ID such that all transactions + with equal or smaller stream IDs have completed. Since transactions may complete out + of order, this is not the same as the stream ID of the last completed transaction. + + Completed transactions include both committed transactions and transactions that + have been rolled back. + """ @abc.abstractmethod - def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + def advance(self, instance_name: str, new_id: int) -> None: + """Advance the position of the named writer to the given ID, if greater + than existing entry. + """ raise NotImplementedError() @abc.abstractmethod def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + + Returns: + The maximum stream id. + """ raise NotImplementedError() @abc.abstractmethod def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to `get_current_token`. + """ + raise NotImplementedError() + + +class AbstractStreamIdGenerator(AbstractStreamIdTracker): + """Generates stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. + + See `AbstractStreamIdTracker` for more details. + """ + + @abc.abstractmethod + def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... + """ raise NotImplementedError() class StreamIdGenerator(AbstractStreamIdGenerator): - """Used to generate new stream ids when persisting events while keeping - track of which transactions have been completed. + """Generates and tracks stream IDs for a stream with a single writer. - This allows us to get the "current" stream id, i.e. the stream id such that - all ids less than or equal to it have completed. This handles the fact that - persistence of events can complete out of order. + This class must only be used when the current Synapse process is the sole + writer for a stream. Args: db_conn(connection): A database connection to use to fetch the @@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator): # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() + def advance(self, instance_name: str, new_id: int) -> None: + # `StreamIdGenerator` should only be used when there is a single writer, + # so replication should never happen. + raise Exception("Replication is not supported by StreamIdGenerator") + def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... 
- """ with self._lock: self._current += self._step next_id = self._current @@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... - """ with self._lock: next_ids = range( self._current + self._step, @@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - - Returns: - The maximum stream id. - """ with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step @@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to - `get_current_token`. - """ return self.get_current_token() class MultiWriterIdGenerator(AbstractStreamIdGenerator): - """An ID generator that tracks a stream that can have multiple writers. + """Generates and tracks stream IDs for a stream with multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other writers will only get updated when `advance` is called (by replication). @@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return stream_ids def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: - """ - Usage: - async with stream_id_gen.get_next_mult(5) as stream_ids: - # ... persist events ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._add_persisted_position(next_id) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - """ - return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer.""" - # If we don't have an entry for the given instance name, we assume it's a # new writer. # @@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): } def advance(self, instance_name: str, new_id: int) -> None: - """Advance the position of the named writer to the given ID, if greater - than existing entry. - """ - new_id *= self._return_factor with self._lock: |