From d26094e92cace20525552e5a0c8b21ff9ce53f11 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 8 Jul 2021 20:25:59 -0500 Subject: Add base starting insertion event when no chunk ID is provided (MSC2716) (#10250) * Add base starting insertion point when no chunk ID is provided This is so we can have the marker event point to this initial insertion event and be able to traverse the events in the first chunk. --- synapse/rest/client/v1/room.py | 112 ++++++++++++++++++++++++++++++++--------- 1 file changed, 89 insertions(+), 23 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 92ebe838fd..9c58e3689e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -349,6 +349,35 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): return depth + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=False) @@ -449,37 +478,68 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): events_to_create = body["events"] - # If provided, connect the chunk to the last insertion point - # The chunk ID passed in comes from the chunk_id in the - # "insertion" event from the previous chunk. + prev_event_ids = prev_events_from_query + inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None if chunk_id_from_query: - last_event_in_chunk = events_to_create[-1] - last_event_in_chunk["content"][ - EventContentFields.MSC2716_CHUNK_ID - ] = chunk_id_from_query + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. 
+ else: + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - # Add an "insertion" event to the start of each chunk (next to the oldest + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + last_event_in_chunk["content"][ + EventContentFields.MSC2716_CHUNK_ID + ] = chunk_id_to_connect_to + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time # event in the chunk) so the next chunk can be connected to this one. - next_chunk_id = random_string(64) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": requester.user.to_string(), - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, + insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, # Since the insertion event is put at the start of the chunk, - # where the oldest event is, copy the origin_server_ts from + # where the oldest-in-time event is, copy the origin_server_ts from # the first event we're inserting - "origin_server_ts": events_to_create[0]["origin_server_ts"], - } + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) # Prepend the insertion event to the start of the chunk events_to_create = [insertion_event] + events_to_create - inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query) - event_ids = [] - prev_event_ids = prev_events_from_query events_to_persist = [] for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) @@ -533,10 +593,16 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): context=context, ) + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + return 200, { "state_events": auth_event_ids, "events": event_ids, - "next_chunk_id": next_chunk_id, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], } def on_GET(self, request, room_id): -- cgit 1.4.1 From 0d5b08ac7ac88ae14cf81f0927084edc2c63a15f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 13 Jul 2021 14:12:33 -0500 Subject: Fix messages from multiple senders in historical chunk (MSC2716) (#10276) Fix messages from multiple senders in historical chunk. This also means that an app service does not need to define `?user_id` when using this endpoint. 
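As a rough illustration, the endpoint will now accept a batch body along these lines (hypothetical users and timestamps; other body fields and query parameters are elided, and each sender must either be the app service's own sender or fall within its registered user namespace):

```python
# Sketch of a historical batch containing events from two different senders.
# Previously every event in the batch had to come from the requesting user.
batch_body = {
    "events": [
        {
            "type": "m.room.message",
            "sender": "@alice:example.com",
            "origin_server_ts": 1626300000000,
            "content": {"msgtype": "m.text", "body": "Hello from Alice"},
        },
        {
            "type": "m.room.message",
            "sender": "@bob:example.com",
            "origin_server_ts": 1626300001000,
            "content": {"msgtype": "m.text", "body": "Hello from Bob"},
        },
    ],
}
```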
Follow-up to https://github.com/matrix-org/synapse/pull/9247 Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 --- changelog.d/10276.bugfix | 1 + synapse/api/auth.py | 37 +++++++++++++++++++++++++++---- synapse/rest/client/v1/room.py | 49 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10276.bugfix (limited to 'synapse/rest/client') diff --git a/changelog.d/10276.bugfix b/changelog.d/10276.bugfix new file mode 100644 index 0000000000..42adc57ad1 --- /dev/null +++ b/changelog.d/10276.bugfix @@ -0,0 +1 @@ +Fix historical batch send endpoint (MSC2716) rejecting batches with messages from multiple senders. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 307f5f9a94..42476a18e5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -240,6 +240,37 @@ class Auth: except KeyError: raise MissingClientTokenError() + async def validate_appservice_can_control_user_id( + self, app_service: ApplicationService, user_id: str + ): + """Validates that the app service is allowed to control + the given user. + + Args: + app_service: The app service that controls the user + user_id: The author MXID that the app service is controlling + + Raises: + AuthError: If the application service is not allowed to control the user + (user namespace regex does not match, wrong homeserver, etc) + or if the user has not been registered yet. + """ + + # It's ok if the app service is trying to use the sender from their registration + if app_service.sender == user_id: + pass + # Check to make sure the app service is allowed to control the user + elif not app_service.is_interested_in_user(user_id): + raise AuthError( + 403, + "Application service cannot masquerade as this user (%s)." 
% user_id, + ) + # Check to make sure the user is already registered on the homeserver + elif not (await self.store.get_user_by_id(user_id)): + raise AuthError( + 403, "Application service has not registered this user (%s)" % user_id + ) + async def _get_appservice_user_id( self, request: Request ) -> Tuple[Optional[str], Optional[ApplicationService]]: @@ -261,13 +292,11 @@ class Auth: return app_service.sender, app_service user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id(app_service, user_id) + if app_service.sender == user_id: return app_service.sender, app_service - if not app_service.is_interested_in_user(user_id): - raise AuthError(403, "Application service cannot masquerade as this user.") - if not (await self.store.get_user_by_id(user_id)): - raise AuthError(403, "Application service has not registered this user") return user_id, app_service async def get_user_by_access_token( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c58e3689e..ebf4e32230 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.filtering import Filter +from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, @@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + Requester, RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID, + create_requester, ) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -378,6 +381,25 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): return insertion_event + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. 
+ + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=False) @@ -443,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): if event_dict["type"] == EventTypes.Member: membership = event_dict["content"].get("membership", None) event_id, _ = await self.room_member_handler.update_membership( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), target=UserID.from_string(event_dict["state_key"]), room_id=room_id, action=membership, @@ -463,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), event_dict, outlier=True, prev_event_ids=[fake_prev_event_id], @@ -479,7 +505,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): events_to_create = body["events"] prev_event_ids = prev_events_from_query - inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query) + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) # Figure out which chunk to connect to. If they passed in # chunk_id_from_query let's use it. The chunk ID passed in comes @@ -509,7 +537,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): base_insertion_event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), base_insertion_event_dict, prev_event_ids=base_insertion_event_dict.get("prev_events"), auth_event_ids=auth_event_ids, @@ -558,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): } event, context = await self.event_creation_handler.create_event( - requester, + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), event_dict, prev_event_ids=event_dict.get("prev_events"), auth_event_ids=auth_event_ids, @@ -588,7 +621,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): # where topological_ordering is just depth. for (event, context) in reversed(events_to_persist): ev = await self.event_creation_handler.handle_new_client_event( - requester=requester, + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), event=event, context=context, ) -- cgit 1.4.1 From 36dc15412de9fc1bb2ba955c8b6f2da20d2ca20f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 16 Jul 2021 18:11:53 +0200 Subject: Add a module type for account validity (#9884) This adds an API for third-party plugin modules to implement account validity, so they can provide this feature instead of Synapse. The module implementing the current behaviour for this feature can be found at https://github.com/matrix-org/synapse-email-account-validity. 
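As a rough sketch (following the module conventions documented in docs/modules.md, and not the actual implementation of the module linked above), an account validity module built on this API could look like:

```python
from typing import Optional

from synapse.module_api import ModuleApi


class ExampleAccountValidity:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Register the two stable hooks added by this change; the
        # on_legacy_* hooks exist only to ease migration off the
        # built-in feature.
        api.register_account_validity_callbacks(
            is_user_expired=self.is_user_expired,
            on_user_registration=self.on_user_registration,
        )

    async def is_user_expired(self, user_id: str) -> Optional[bool]:
        # Return True/False if this module can tell whether the account
        # has expired, or None to defer to other modules (or to the
        # legacy config, if enabled).
        return None

    async def on_user_registration(self, user_id: str) -> None:
        # Track the new user, e.g. store an expiration timestamp for
        # them in a module-managed table.
        pass
```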
To allow for a smooth transition between the current feature and the new module, hooks have been added to the existing account validity endpoints to allow their behaviours to be overridden by a module. --- changelog.d/9884.feature | 1 + docs/modules.md | 47 ++++- docs/sample_config.yaml | 85 --------- synapse/api/auth.py | 17 +- synapse/config/account_validity.py | 102 ++--------- synapse/handlers/account_validity.py | 128 ++++++++++++- synapse/handlers/register.py | 5 + synapse/module_api/__init__.py | 219 +++++++++++++++++++++-- synapse/module_api/errors.py | 6 +- synapse/push/pusherpool.py | 24 +-- synapse/rest/admin/users.py | 24 ++- synapse/rest/client/v2_alpha/account_validity.py | 7 +- tests/test_state.py | 1 + 13 files changed, 438 insertions(+), 228 deletions(-) create mode 100644 changelog.d/9884.feature (limited to 'synapse/rest/client') diff --git a/changelog.d/9884.feature b/changelog.d/9884.feature new file mode 100644 index 0000000000..525fd2f93c --- /dev/null +++ b/changelog.d/9884.feature @@ -0,0 +1 @@ +Add a module type for the account validity feature. diff --git a/docs/modules.md b/docs/modules.md index bec1c06d15..c4cb7018f7 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following API method: ```python -def ModuleApi.register_web_resource(path: str, resource: IResource) +def ModuleApi.register_web_resource(path: str, resource: IResource) -> None ``` The path is the full absolute path to register the resource at. For example, if you @@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c and is under no obligation to implement all callbacks from the categories it registers callbacks for. +Modules can register callbacks using one of the module API's `register_[...]_callbacks` +methods. The callback functions are passed to these methods as keyword arguments, with +the callback name as the argument name and the function as its value. This is demonstrated +in the example below. A `register_[...]_callbacks` method exists for each module type +documented in this section. + #### Spam checker callbacks -To register one of the callbacks described in this section, a module needs to use the -module API's `register_spam_checker_callbacks` method. The callback functions are passed -to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the -argument name and the function as its value. This is demonstrated in the example below. +Spam checker callbacks allow module developers to implement spam mitigation actions for +Synapse instances. Spam checker callbacks can be registered using the module API's +`register_spam_checker_callbacks` method. The available spam checker callbacks are: @@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (i.e. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). ```python async def user_may_create_room(user: str) -> bool @@ -188,6 +193,36 @@ async def check_media_file_for_spam( Called when storing a local or remote file. The module must return a boolean indicating whether the given file can be stored in the homeserver's media store. 
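+Here is a rough sketch of a module registering one of these callbacks (the
+constructor shape follows the module conventions described above; the policy
+applied in the callback is a placeholder):
+
+```python
+class ExampleSpamChecker:
+    def __init__(self, config, api):
+        self._api = api
+        api.register_spam_checker_callbacks(
+            user_may_invite=self.user_may_invite,
+        )
+
+    async def user_may_invite(
+        self, inviter: str, invitee: str, room_id: str
+    ) -> bool:
+        # Allow every invite; a real module would apply its own policy here.
+        return True
+```
+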
+#### Account validity callbacks + +Account validity callbacks allow module developers to add extra steps to verify the +validity on an account, i.e. see if a user can be granted access to their account on the +Synapse instance. Account validity callbacks can be registered using the module API's +`register_account_validity_callbacks` method. + +The available account validity callbacks are: + +```python +async def is_user_expired(user: str) -> Optional[bool] +``` + +Called when processing any authenticated request (except for logout requests). The module +can return a `bool` to indicate whether the user has expired and should be locked out of +their account, or `None` if the module wasn't able to figure it out. The user is +represented by their Matrix user ID (e.g. `@alice:example.com`). + +If the module returns `True`, the current request will be denied with the error code +`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't +invalidate the user's access token. + +```python +async def on_user_registration(user: str) -> None +``` + +Called after successfully registering a user, in case the module needs to perform extra +operations to keep track of them. (e.g. add them to a database table). The user is +represented by their Matrix user ID. + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a45732a246..f4845a5841 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1310,91 +1310,6 @@ account_threepid_delegates: #auto_join_rooms_for_guests: false -## Account Validity ## - -# Optional account validity configuration. This allows for accounts to be denied -# any request after a given period. -# -# Once this feature is enabled, Synapse will look for registered users without an -# expiration date at startup and will add one to every account it found using the -# current settings at that time. -# This means that, if a validity period is set, and Synapse is restarted (it will -# then derive an expiration date from the current validity period), and some time -# after that the validity period changes and Synapse is restarted, the users' -# expiration dates won't be updated unless their account is manually renewed. This -# date will be randomly selected within a range [now + period - d ; now + period], -# where d is equal to 10% of the validity period. -# -account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. 
- # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - - ## Metrics ### # Enable collection and rendering of performance metrics diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8916e6fa2f..05699714ee 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -62,6 +62,7 @@ class Auth: self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() + self._account_validity_handler = hs.get_account_validity_handler() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -69,9 +70,6 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users @@ -187,12 +185,17 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity_enabled and not allow_expired: - if await self.store.is_account_expired( - user_info.user_id, self.clock.time_msec() + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + user_info.user_id ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired raise AuthError( - 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, ) device_id = user_info.device_id diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 957de7f3a6..6be4eafe55 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -18,6 +18,21 @@ class AccountValidityConfig(Config): section = "account_validity" def read_config(self, config, **kwargs): + """Parses the old account validity config. 
The config format looks like this: + + account_validity: + enabled: true + period: 6w + renew_at: 1w + renew_email_subject: "Renew your %(app)s account" + template_dir: "res/templates" + account_renewed_html_path: "account_renewed.html" + invalid_token_html_path: "invalid_token.html" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ account_validity_config = config.get("account_validity") or {} self.account_validity_enabled = account_validity_config.get("enabled", False) self.account_validity_renew_by_email_enabled = ( @@ -75,90 +90,3 @@ class AccountValidityConfig(Config): ], account_validity_template_dir, ) - - def generate_config_section(self, **kwargs): - return """\ - ## Account Validity ## - - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. 
- # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - """ diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d752cf34f0..078accd634 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -15,9 +15,11 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from synapse.api.errors import StoreError, SynapseError +from twisted.web.http import Request + +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils @@ -27,6 +29,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Types for callbacks to be registered via the module api +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +# Temporary hooks to allow for a transition from `/_matrix/client` endpoints +# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`. +ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] +ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] +ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] + class AccountValidityHandler: def __init__(self, hs: "HomeServer"): @@ -70,6 +81,99 @@ class AccountValidityHandler: if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] + self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self._on_legacy_send_mail_callback: Optional[ + ON_LEGACY_SEND_MAIL_CALLBACK + ] = None + self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + + # The legacy admin requests callback isn't a protected attribute because we need + # to access it from the admin servlet, which is outside of this handler. + self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ): + """Register callbacks from module for each hook.""" + if is_user_expired is not None: + self._is_user_expired_callbacks.append(is_user_expired) + + if on_user_registration is not None: + self._on_user_registration_callbacks.append(on_user_registration) + + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and + # an admin one). 
As part of moving the feature into a module, we need to change + # the path from /_matrix/client/unstable/account_validity/... to + # /_synapse/client/account_validity, because: + # + # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix + # * the way we register servlets means that modules can't register resources + # under /_matrix/client + # + # We need to allow for a transition period between the old and new endpoints + # in order to allow for clients to update (and for emails to be processed). + # + # Once the email-account-validity module is loaded, it will take control of account + # validity by moving the rows from our `account_validity` table into its own table. + # + # Therefore, we need to allow modules (in practice just the one implementing the + # email-based account validity) to temporarily hook into the legacy endpoints so we + # can route the traffic coming into the old endpoints into the module, which is + # why we have the following three temporary hooks. + if on_legacy_send_mail is not None: + if self._on_legacy_send_mail_callback is not None: + raise RuntimeError("Tried to register on_legacy_send_mail twice") + + self._on_legacy_send_mail_callback = on_legacy_send_mail + + if on_legacy_renew is not None: + if self._on_legacy_renew_callback is not None: + raise RuntimeError("Tried to register on_legacy_renew twice") + + self._on_legacy_renew_callback = on_legacy_renew + + if on_legacy_admin_request is not None: + if self.on_legacy_admin_request_callback is not None: + raise RuntimeError("Tried to register on_legacy_admin_request twice") + + self.on_legacy_admin_request_callback = on_legacy_admin_request + + async def is_user_expired(self, user_id: str) -> bool: + """Checks if a user has expired against third-party modules. + + Args: + user_id: The user to check the expiry of. + + Returns: + Whether the user has expired. + """ + for callback in self._is_user_expired_callbacks: + expired = await callback(user_id) + if expired is not None: + return expired + + if self._account_validity_enabled: + # If no module could determine whether the user has expired and the legacy + # configuration is enabled, fall back to it. + return await self.store.is_account_expired(user_id, self.clock.time_msec()) + + return False + + async def on_user_registration(self, user_id: str): + """Tell third-party modules about a user's registration. + + Args: + user_id: The ID of the newly registered user. + """ + for callback in self._on_user_registration_callbacks: + await callback(user_id) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -95,6 +199,17 @@ class AccountValidityHandler: Raises: SynapseError if the user is not set to renew. """ + # If a module supports sending a renewal email from here, do that, otherwise do + # the legacy dance. + if self._on_legacy_send_mail_callback is not None: + await self._on_legacy_send_mail_callback(user_id) + return + + if not self._account_validity_renew_by_email_enabled: + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) # If this user isn't set to be expired, raise an error. @@ -209,6 +324,10 @@ class AccountValidityHandler: token is considered stale. A token is stale if the 'token_used_ts_ms' db column is non-null. + This method exists to support handling the legacy account validity /renew + endpoint. 
If a module implements the on_legacy_renew callback, then this process + is delegated to the module instead. + Args: renewal_token: Token sent with the renewal request. Returns: @@ -218,6 +337,11 @@ class AccountValidityHandler: * An int representing the user's expiry timestamp as milliseconds since the epoch, or 0 if the token was invalid. """ + # If a module supports triggering a renew from here, do that, otherwise do the + # legacy dance. + if self._on_legacy_renew_callback is not None: + return await self._on_legacy_renew_callback(renewal_token) + try: ( user_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 26ef016179..056fe5e89f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -77,6 +77,7 @@ class RegistrationHandler(BaseHandler): self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() + self._account_validity_handler = hs.get_account_validity_handler() self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -700,6 +701,10 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + # Only call the account validity module(s) on the main process, to avoid + # repeating e.g. database writes on all of the workers. + await self._account_validity_handler.on_user_registration(user_id) + async def register_device( self, user_id: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 308f045700..f3c78089b7 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -12,18 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import email.utils import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +import jinja2 from twisted.internet import defer from twisted.web.resource import IResource from synapse.events import EventBase from synapse.http.client import SimpleHttpClient +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util import Clock +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi are loaded into Synapse. 
""" -__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] +__all__ = [ + "errors", + "make_deferred_yieldable", + "parse_json_object_from_request", + "respond_with_html", + "run_in_background", + "cached", + "UserID", + "DatabasePool", + "LoggingTransaction", + "DirectServeHtmlResource", + "DirectServeJsonResource", + "ModuleApi", +] logger = logging.getLogger(__name__) @@ -52,12 +89,27 @@ class ModuleApi: self._server_name = hs.hostname self._presence_stream = hs.get_event_sources().sources["presence"] self._state = hs.get_state_handler() + self._clock = hs.get_clock() # type: Clock + self._send_email_handler = hs.get_send_email_handler() + + try: + app_name = self._hs.config.email_app_name + + self._from_string = self._hs.config.email_notif_from % {"app": app_name} + except (KeyError, TypeError): + # If substitution failed (which can happen if the string contains + # placeholders other than just "app", or if the type of the placeholder is + # not a string), fall back to the bare strings. + self._from_string = self._hs.config.email_notif_from + + self._raw_from = email.utils.parseaddr(self._from_string)[1] # We expose these as properties below in order to attach a helpful docstring. self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) self._spam_checker = hs.get_spam_checker() + self._account_validity_handler = hs.get_account_validity_handler() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -67,6 +119,11 @@ class ModuleApi: """Registers callbacks for spam checking capabilities.""" return self._spam_checker.register_callbacks + @property + def register_account_validity_callbacks(self): + """Registers callbacks for account validity capabilities.""" + return self._account_validity_handler.register_account_validity_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. @@ -101,22 +158,56 @@ class ModuleApi: """ return self._public_room_list_manager - def get_user_by_req(self, req, allow_guest=False): + @property + def public_baseurl(self) -> str: + """The configured public base URL for this homeserver.""" + return self._hs.config.public_baseurl + + @property + def email_app_name(self) -> str: + """The application name configured in the homeserver's configuration.""" + return self._hs.config.email.email_app_name + + async def get_user_by_req( + self, + req: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + ) -> Requester: """Check the access_token provided for a request Args: - req (twisted.web.server.Request): Incoming HTTP request - allow_guest (bool): True if guest users should be allowed. If this + req: Incoming HTTP request + allow_guest: True if guest users should be allowed. If this is False, and the access token is for a guest user, an AuthError will be thrown + allow_expired: True if expired users should be allowed. If this + is False, and the access token is for an expired user, an + AuthError will be thrown + Returns: - twisted.internet.defer.Deferred[synapse.types.Requester]: - the requester for this request + The requester for this request + Raises: - synapse.api.errors.AuthError: if no user by that token exists, + InvalidClientCredentialsError: if no user by that token exists, or the token is invalid. 
""" - return self._auth.get_user_by_req(req, allow_guest) + return await self._auth.get_user_by_req( + req, + allow_guest, + allow_expired=allow_expired, + ) + + async def is_user_admin(self, user_id: str) -> bool: + """Checks if a user is a server admin. + + Args: + user_id: The Matrix ID of the user to check. + + Returns: + True if the user is a server admin, False otherwise. + """ + return await self._store.is_server_admin(UserID.from_string(user_id)) def get_qualified_user_id(self, username): """Qualify a user id, if necessary @@ -134,6 +225,32 @@ class ModuleApi: return username return UserID(username, self._hs.hostname).to_string() + async def get_profile_for_user(self, localpart: str) -> ProfileInfo: + """Look up the profile info for the user with the given localpart. + + Args: + localpart: The localpart to look up profile information for. + + Returns: + The profile information (i.e. display name and avatar URL). + """ + return await self._store.get_profileinfo(localpart) + + async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: + """Look up the threepids (email addresses and phone numbers) associated with the + given Matrix user ID. + + Args: + user_id: The Matrix user ID to look up threepids for. + + Returns: + A list of threepids, each threepid being represented by a dictionary + containing a "medium" key which value is "email" for email addresses and + "msisdn" for phone numbers, and an "address" key which value is the + threepid's address. + """ + return await self._store.user_get_threepids(user_id) + def check_user_exists(self, user_id): """Check if user exists. @@ -464,6 +581,88 @@ class ModuleApi: presence_events, destination ) + def looping_background_call( + self, + f: Callable, + msec: float, + *args, + desc: Optional[str] = None, + **kwargs, + ): + """Wraps a function as a background process and calls it repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f: The function to call repeatedly. f can be either synchronous or + asynchronous, and must follow Synapse's logcontext rules. + More info about logcontexts is available at + https://matrix-org.github.io/synapse/latest/log_contexts.html + msec: How long to wait between calls in milliseconds. + *args: Positional arguments to pass to function. + desc: The background task's description. Default to the function's name. + **kwargs: Key arguments to pass to function. + """ + if desc is None: + desc = f.__name__ + + if self._hs.config.run_background_tasks: + self._clock.looping_call( + run_as_background_process, + msec, + desc, + f, + *args, + **kwargs, + ) + else: + logger.warning( + "Not running looping call %s as the configuration forbids it", + f, + ) + + async def send_mail( + self, + recipient: str, + subject: str, + html: str, + text: str, + ): + """Send an email on behalf of the homeserver. + + Args: + recipient: The email address for the recipient. + subject: The email's subject. + html: The email's HTML content. + text: The email's text content. + """ + await self._send_email_handler.send_email( + email_address=recipient, + subject=subject, + app_name=self.email_app_name, + html=html, + text=text, + ) + + def read_templates( + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Read and load the content of the template files at the given location. + By default, Synapse will look for these templates in its configured template + directory, but another directory to search in can be provided. 
+ + Args: + filenames: The name of the template files to look for. + custom_template_directory: An additional directory to look for the files in. + + Returns: + A list containing the loaded templates, with the orders matching the one of + the filenames parameter. + """ + return self._hs.config.read_templates(filenames, custom_template_directory) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 02bbb0be39..98ea911a81 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -14,5 +14,9 @@ """Exception types which are exposed as part of the stable module API""" -from synapse.api.errors import RedirectException, SynapseError # noqa: F401 +from synapse.api.errors import ( # noqa: F401 + InvalidClientCredentialsError, + RedirectException, + SynapseError, +) from synapse.config._base import ConfigError # noqa: F401 diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2519ad76db..85621f33ef 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,10 +62,6 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) - # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -89,6 +85,8 @@ class PusherPool: # map from user id to app_id:pushkey to pusher self.pushers: Dict[str, Dict[str, Pusher]] = {} + self._account_validity_handler = hs.get_account_validity_handler() + def start(self) -> None: """Starts the pushers off in a background process.""" if not self._should_start_pushers: @@ -238,12 +236,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): @@ -268,12 +263,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7d75564758..06e6ccee42 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - body = parse_json_object_from_request(request) + if self.account_activity_handler.on_legacy_admin_request_callback: + expiration_ts = await ( + self.account_activity_handler.on_legacy_admin_request_callback(request) + ) + else: + body = parse_json_object_from_request(request) - if "user_id" not in body: - raise SynapseError(400, "Missing property 'user_id' in the request body") + if "user_id" not in body: + raise SynapseError( + 400, + "Missing property 'user_id' in the request body", + ) - expiration_ts = await 
self.account_activity_handler.renew_account_for_user( - body["user_id"], - body.get("expiration_ts"), - not body.get("enable_renewal_emails", True), - ) + expiration_ts = await self.account_activity_handler.renew_account_for_user( + body["user_id"], + body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) res = {"expiration_ts": expiration_ts} return 200, res diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2d1ad3d3fb..3ebe401861 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -14,7 +14,7 @@ import logging -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet @@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet): ) async def on_POST(self, request): - if not self.account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." - ) - requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) diff --git a/tests/test_state.py b/tests/test_state.py index 780eba823c..e5488df1ac 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase): "get_state_handler", "get_clock", "get_state_resolution_handler", + "get_account_validity_handler", "hostname", ] ) -- cgit 1.4.1 From 98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Fri, 16 Jul 2021 19:22:36 +0200 Subject: Use inline type hints in `handlers/` and `rest/`. 
(#10382) --- changelog.d/10382.misc | 1 + synapse/handlers/_base.py | 8 +++--- synapse/handlers/admin.py | 4 +-- synapse/handlers/appservice.py | 6 ++-- synapse/handlers/auth.py | 16 +++++------ synapse/handlers/cas.py | 2 +- synapse/handlers/device.py | 14 +++++----- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/directory.py | 6 ++-- synapse/handlers/e2e_keys.py | 40 +++++++++++++-------------- synapse/handlers/events.py | 6 ++-- synapse/handlers/federation.py | 22 +++++++-------- synapse/handlers/groups_local.py | 4 +-- synapse/handlers/initial_sync.py | 14 ++++++++-- synapse/handlers/message.py | 18 +++++------- synapse/handlers/oidc.py | 18 ++++++------ synapse/handlers/pagination.py | 4 +-- synapse/handlers/presence.py | 28 +++++++++---------- synapse/handlers/profile.py | 4 +-- synapse/handlers/receipts.py | 4 +-- synapse/handlers/room.py | 16 +++++------ synapse/handlers/room_list.py | 18 ++++++------ synapse/handlers/saml.py | 6 ++-- synapse/handlers/search.py | 8 +++--- synapse/handlers/space_summary.py | 16 +++++------ synapse/handlers/sso.py | 12 ++++---- synapse/handlers/stats.py | 10 +++---- synapse/handlers/sync.py | 32 +++++++++++---------- synapse/handlers/typing.py | 14 +++++----- synapse/handlers/user_directory.py | 2 +- synapse/rest/admin/rooms.py | 8 ++---- synapse/rest/admin/users.py | 2 +- synapse/rest/client/v1/login.py | 8 +++--- synapse/rest/client/v1/room.py | 10 ++----- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/consent/consent_resource.py | 4 +-- synapse/rest/key/v2/remote_key_resource.py | 4 +-- synapse/rest/media/v1/_base.py | 2 +- synapse/rest/media/v1/media_repository.py | 10 +++---- synapse/rest/media/v1/media_storage.py | 4 +-- synapse/rest/media/v1/preview_url_resource.py | 8 +++--- synapse/rest/media/v1/upload_resource.py | 6 ++-- synapse/rest/synapse/client/pick_username.py | 4 +-- 43 files changed, 212 insertions(+), 215 deletions(-) create mode 100644 changelog.d/10382.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/10382.misc b/changelog.d/10382.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10382.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. 
\ No newline at end of file diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d800e16912..525f3d39b1 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -38,10 +38,10 @@ class BaseHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() # type: synapse.storage.DataStore + self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler + self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs @@ -55,12 +55,12 @@ class BaseHandler: # Check whether ratelimiting room admin message redaction is enabled # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: - self.admin_redaction_ratelimiter = Ratelimiter( + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) # type: Optional[Ratelimiter] + ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d75a8b15c3..bfa7f2c545 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -139,7 +139,7 @@ class AdminHandler(BaseHandler): to_key = RoomStreamToken(None, stream_ordering) # Events that we've processed in this room - written_events = set() # type: Set[str] + written_events: Set[str] = set() # We need to track gaps in the events stream so that we can then # write out the state at those events. We do this by keeping track @@ -152,7 +152,7 @@ class AdminHandler(BaseHandler): # The reverse mapping to above, i.e. map from unseen event to events # that have the unseen event in their prev_events, i.e. the unseen # events "children". 
- unseen_to_child_events = {} # type: Dict[str, Set[str]] + unseen_to_child_events: Dict[str, Set[str]] = {} # We fetch events in the room the user could see by fetching *all* # events that we have and then filtering, this isn't the most diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 862638cc4f..21a17cd2e8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -96,7 +96,7 @@ class ApplicationServicesHandler: self.current_max, limit ) - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -275,7 +275,7 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] ) -> List[JsonDict]: - events = [] # type: List[JsonDict] + events: List[JsonDict] = [] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" @@ -375,7 +375,7 @@ class ApplicationServicesHandler: self, only_protocol: Optional[str] = None ) -> Dict[str, JsonDict]: services = self.store.get_app_services() - protocols = {} # type: Dict[str, List[JsonDict]] + protocols: Dict[str, List[JsonDict]] = {} # Collect up all the individual protocol responses out of the ASes for s in services: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e2ac595a62..22a8552241 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -191,7 +191,7 @@ class AuthHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] + self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): @@ -296,7 +296,7 @@ class AuthHandler(BaseHandler): # A mapping of user ID to extra attributes to include in the login # response. - self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {} async def validate_user_via_ui_auth( self, @@ -500,7 +500,7 @@ class AuthHandler(BaseHandler): all the stages in any of the permitted flows. 
""" - sid = None # type: Optional[str] + sid: Optional[str] = None authdict = clientdict.pop("auth", {}) if "session" in authdict: sid = authdict["session"] @@ -588,9 +588,9 @@ class AuthHandler(BaseHandler): ) # check auth type currently being presented - errordict = {} # type: Dict[str, Any] + errordict: Dict[str, Any] = {} if "type" in authdict: - login_type = authdict["type"] # type: str + login_type: str = authdict["type"] try: result = await self._check_auth_dict(authdict, clientip) if result: @@ -766,7 +766,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} # type: Dict[str, Any] + params: Dict[str, Any] = {} for f in public_flows: for stage in f: @@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler): except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - user_id_to_verify = await self.get_session_data( + user_id_to_verify: str = await self.get_session_data( session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) idps = await self.hs.get_sso_handler().get_identity_providers_for_user( user_id_to_verify diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 7346ccfe93..b681d208bc 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -171,7 +171,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} # type: Dict[str, List[Optional[str]]] + attributes: Dict[str, List[Optional[str]]] = {} for child in root[0]: if child.tag.endswith("user"): user = child.text diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 95bdc5902a..46ee834407 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id ) - hosts = set() # type: Set[str] + hosts: Set[str] = set() if self.hs.is_mine_id(user_id): hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) @@ -613,20 +613,20 @@ class DeviceListUpdater: self._remote_edu_linearizer = Linearizer(name="remote_device_list") # user_id -> list of updates waiting to be handled. - self._pending_updates = ( - {} - ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] + self._pending_updates: Dict[ + str, List[Tuple[str, str, Iterable[str], JsonDict]] + ] = {} # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious # resyncs. - self._seen_updates = ExpiringCache( + self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache( cache_name="device_update_edu", clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, iterable=True, - ) # type: ExpiringCache[str, Set[str]] + ) # Attempt to resync out of sync device lists every 30s. self._resync_retry_in_progress = False @@ -755,7 +755,7 @@ class DeviceListUpdater: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. 
""" - seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] + seen_updates: Set[str] = self._seen_updates.get(user_id, set()) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 580b941595..679b47f081 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -203,7 +203,7 @@ class DeviceMessageHandler: log_kv({"number_of_to_device_messages": len(messages)}) set_tag("sender", sender_user_id) local_messages = {} - remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # Ratelimit local cross-user key requests by the sending device. if ( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 06d7012bac..d487fee627 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler): async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result = await self.get_association_from_room_alias( - room_alias - ) # type: Optional[RoomAliasMapping] + result: Optional[ + RoomAliasMapping + ] = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 3972849d4d..d92370859f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -115,9 +115,9 @@ class E2eKeysHandler: the number of in-flight queries at a time. """ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query = query_body.get( + device_keys_query: Dict[str, Iterable[str]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Iterable[str]] + ) # separate users by domain. # make a map from domain to user_id to device_ids @@ -136,7 +136,7 @@ class E2eKeysHandler: # First get local devices. # A map of destination -> failure response. - failures = {} # type: Dict[str, JsonDict] + failures: Dict[str, JsonDict] = {} results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -151,11 +151,9 @@ class E2eKeysHandler: # Now attempt to get any remote devices from our local cache. # A map of destination -> user ID -> device IDs. 
- remote_queries_not_in_cache = ( - {} - ) # type: Dict[str, Dict[str, Iterable[str]]] + remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {} if remote_queries: - query_list = [] # type: List[Tuple[str, Optional[str]]] + query_list: List[Tuple[str, Optional[str]]] = [] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend( @@ -362,9 +360,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] # type: List[Tuple[str, Optional[str]]] + local_query: List[Tuple[str, Optional[str]]] = [] - result_dict = {} # type: Dict[str, Dict[str, dict]] + result_dict: Dict[str, Dict[str, dict]] = {} for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -402,9 +400,9 @@ class E2eKeysHandler: self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: """Handle a device key query from a federated server""" - device_keys_query = query_body.get( + device_keys_query: Dict[str, Optional[List[str]]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Optional[List[str]]] + ) res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -421,8 +419,8 @@ class E2eKeysHandler: async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int ) -> JsonDict: - local_query = [] # type: List[Tuple[str, str, str]] - remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] + local_query: List[Tuple[str, str, str]] = [] + remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids @@ -439,8 +437,8 @@ class E2eKeysHandler: results = await self.store.claim_e2e_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. - json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] - failures = {} # type: Dict[str, JsonDict] + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + failures: Dict[str, JsonDict] = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): @@ -768,8 +766,8 @@ class E2eKeysHandler: Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -930,8 +928,8 @@ class E2eKeysHandler: Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. 
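One subtlety in the `signature_list` hunks above: a type comment is just text, so it is automatically a forward reference, while an inline annotation is an expression. At module or class scope that expression is evaluated at import time, so a name defined further down the file must be quoted, which is presumably why the conversion writes `List["SignatureListItem"]`. A sketch of the runtime difference (hypothetical classes):

```python
from typing import List

# Quoted forward reference: fine, even though Job is defined below.
pending: List["Job"] = []

# Unquoted, this would raise NameError at import time:
#     pending_bad: List[Job] = []


class Job:
    def __init__(self, name: str) -> None:
        self.name = name


pending.append(Job("device-list resync"))
print(pending[0].name)
```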
- self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] + self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {} async def incoming_signing_key_update( self, origin: str, edu_content: JsonDict @@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] # type: List[str] + device_ids: List[str] = [] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f134f1e234..4b3f037072 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler): # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] # type: List[JsonDict] + to_add: List[JsonDict] = [] for event in events: if not isinstance(event, EventBase): continue @@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.store.get_users_in_room( + users: Iterable[str] = await self.store.get_users_in_room( event.room_id - ) # type: Iterable[str] + ) else: users = [event.state_key] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0209aee186..5c4463583e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -181,7 +181,7 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. - self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] + self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") self._room_backfill = Linearizer("room_backfill") @@ -368,7 +368,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list(ours.values()) # type: List[StateMap[str]] + state_maps: List[StateMap[str]] = list(ours.values()) # we don't need this any more, let's delete it. del ours @@ -845,7 +845,7 @@ class FederationHandler(BaseHandler): # exact key to expect. Otherwise check it matches any key we # have for that device. - current_keys = [] # type: Container[str] + current_keys: Container[str] = [] if device: keys = device.get("keys", {}).get("keys", {}) @@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler): if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} # type: Dict[str, int] + joined_domains: Dict[str, int] = {} for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version(room_id) - event_map = {} # type: Dict[str, EventBase] + event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): @@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler): # Ask the remote server to create a valid knock event for us. 
Once received, # we sign the event - params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + params: Dict[str, Iterable[str]] = {"ver": supported_room_versions} origin, event, event_format_version = await self._make_and_verify_event( target_hosts, room_id, knockee, Membership.KNOCK, content, params=params ) @@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler): state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] + state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = { + current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() - } # type: StateMap[str] + } else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler): """ # exclude the state key of the new event from the current_state in the context. if event.is_state(): - event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) else: event_key = None state_updates = { @@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler): logger.debug("Checking auth on event %r", event.content) - last_exception = None # type: Optional[Exception] + last_exception: Optional[Exception] = None # for each public key in the 3pid invite event for public_key_object in event_auth.get_public_keys(invite_event): diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 157f2ff218..1a6c5c64a2 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler: async def bulk_get_publicised_groups( self, user_ids: Iterable[str], proxy: bool = True ) -> JsonDict: - destinations = {} # type: Dict[str, Set[str]] + destinations: Dict[str, Set[str]] = {} local_users = set() for user_id in user_ids: @@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] # type: List[str] + failed_results: List[str] = [] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 76242865ae..5d49640760 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache( - hs.get_clock(), "initial_sync_cache" - ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] + self.snapshot_cache: ResponseCache[ + Tuple[ + str, + Optional[StreamToken], + Optional[StreamToken], + str, + Optional[int], + bool, + bool, + ] + ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e06655f3d4..c7fe4ff89e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ 
class MessageHandler: # The scheduled call to self._expire_event. None if no call is currently # scheduled. - self._scheduled_expiry = None # type: Optional[IDelayedCall] + self._scheduled_expiry: Optional[IDelayedCall] = None if not hs.config.worker_app: run_as_background_process( @@ -196,9 +196,7 @@ class MessageHandler: room_state_events = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) - room_state = room_state_events[ - event.event_id - ] # type: Mapping[Any, EventBase] + room_state: Mapping[Any, EventBase] = room_state_events[event.event_id] else: raise AuthError( 403, @@ -421,9 +419,9 @@ class EventCreationHandler: self.action_generator = hs.get_action_generator() self.spam_checker = hs.get_spam_checker() - self.third_party_event_rules = ( + self.third_party_event_rules: "ThirdPartyEventRules" = ( self.hs.get_third_party_event_rules() - ) # type: ThirdPartyEventRules + ) self._block_events_without_consent_error = ( self.config.block_events_without_consent_error @@ -440,7 +438,7 @@ class EventCreationHandler: # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. self._dummy_events_threshold = hs.config.dummy_events_threshold @@ -465,9 +463,7 @@ class EventCreationHandler: # Stores the state groups we've recently added to the joined hosts # external cache. Note that the timeout must be significantly less than # the TTL on the external cache. - self._external_cache_joined_hosts_updates = ( - None - ) # type: Optional[ExpiringCache] + self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None if self._external_cache.is_enabled(): self._external_cache_joined_hosts_updates = ExpiringCache( "_external_cache_joined_hosts_updates", @@ -1299,7 +1295,7 @@ class EventCreationHandler: # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = [] # type: List[str] + original_alt_aliases: List[str] = [] original_event_id = event.unsigned.get("replaces_state") if original_event_id: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index ee6e41c0e4..a330c48fa7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -105,9 +105,9 @@ class OidcHandler: assert provider_confs self._token_generator = OidcSessionTokenGenerator(hs) - self._providers = { + self._providers: Dict[str, "OidcProvider"] = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } # type: Dict[str, OidcProvider] + } async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. @@ -178,7 +178,7 @@ class OidcHandler: # are two. for cookie_name, _ in _SESSION_COOKIES: - session = request.getCookie(cookie_name) # type: Optional[bytes] + session: Optional[bytes] = request.getCookie(cookie_name) if session is not None: break else: @@ -277,7 +277,7 @@ class OidcProvider: self._token_generator = token_generator self._config = provider - self._callback_url = hs.config.oidc_callback_url # type: str + self._callback_url: str = hs.config.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. 
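Where the annotated value is a parameterized generic (the `ResponseCache` and `ExpiringCache` attributes in the hunks above), the rewrite also moves the type arguments off a trailing comment and into the annotation, where black can wrap them sensibly. A sketch with a toy generic cache standing in for Synapse's (hypothetical, simplified API):

```python
from typing import Dict, Generic, Optional, Tuple, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class Cache(Generic[K, V]):
    """Toy stand-in for ResponseCache/ExpiringCache."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._data: Dict[K, V] = {}

    def set(self, key: K, value: V) -> None:
        self._data[key] = value

    def get(self, key: K) -> Optional[V]:
        return self._data.get(key)


class SyncHandlerSketch:
    def __init__(self) -> None:
        # The (long) type parameters live in the annotation; the
        # right-hand side stays a plain constructor call.
        self.snapshot_cache: Cache[
            Tuple[str, Optional[str], bool], str
        ] = Cache("initial_sync_cache")


h = SyncHandlerSketch()
h.snapshot_cache.set(("@alice:example.org", None, True), "cached response")
print(h.snapshot_cache.get(("@alice:example.org", None, True)))
```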
@@ -290,7 +290,7 @@ class OidcProvider: self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method - client_secret = None # type: Union[None, str, JwtClientSecret] + client_secret: Optional[Union[str, JwtClientSecret]] = None if provider.client_secret: client_secret = provider.client_secret elif provider.client_secret_jwt_key: @@ -305,7 +305,7 @@ class OidcProvider: provider.client_id, client_secret, provider.client_auth_method, - ) # type: ClientAuth + ) self._client_auth_method = provider.client_auth_method # cache of metadata for the identity provider (endpoint uris, mostly). This is @@ -324,7 +324,7 @@ class OidcProvider: self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() - self._server_name = hs.config.server_name # type: str + self._server_name: str = hs.config.server_name # identifier for the external_ids table self.idp_id = provider.idp_id @@ -1381,7 +1381,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - emails = [] # type: List[str] + emails: List[str] = [] email = render_template_field(self._config.email_template) if email: emails.append(email) @@ -1391,7 +1391,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: - extras = {} # type: Dict[str, str] + extras: Dict[str, str] = {} for key, template in self._config.extra_attributes.items(): try: extras[key] = template.render(user=userinfo).strip() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1e1186c29e..1dbafd253d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -81,9 +81,9 @@ class PaginationHandler: self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() # type: Set[str] + self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus - self._purges_by_id = {} # type: Dict[str, PurgeStatus] + self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 44ed7a0712..016c5df2ca 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. - self._user_to_num_current_syncs = {} # type: Dict[str, int] + self._user_to_num_current_syncs: Dict[str, int] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() # user_id -> last_sync_ms. 
Lists the users that have stopped syncing but # we haven't notified the presence writer of that yet - self.users_going_offline = {} # type: Dict[str, int] + self.users_going_offline: Dict[str, int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() # type: Set[str] + self.unpersisted_users_changes: Set[str] = set() hs.get_reactor().addSystemEventTrigger( "before", @@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} # type: Dict[str, int] + self.user_to_num_current_syncs: Dict[str, int] = {} # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] - self.external_process_last_updated_ms = {} # type: Dict[str, int] + self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -1581,9 +1581,7 @@ class PresenceEventSource: # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users = ( - set() - ) # type: Union[Set[str], FrozenSet[str]] + interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() if from_key: # First get all users that have had a presence update @@ -1950,8 +1948,8 @@ async def get_interested_parties( A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] - users_to_states = {} # type: Dict[str, List[UserPresenceState]] + room_ids_to_states: Dict[str, List[UserPresenceState]] = {} + users_to_states: Dict[str, List[UserPresenceState]] = {} for state in states: room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: @@ -2063,12 +2061,12 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] + self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] self._next_id = 1 # Map from instance name to current token - self._current_tokens = {} # type: Dict[str, int] + self._current_tokens: Dict[str, int] = {} if self._queue_presence_updates: self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) @@ -2168,7 +2166,7 @@ class PresenceFederationQueue: # handle the case where `from_token` stream ID has already been dropped. 
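The `interested_and_updated_users` hunk further up shows why several of these annotations are deliberately wider than their initializer: mypy pins a variable to the type of its first assignment, so a variable that starts as a mutable `set` and may later be rebound to a `frozenset` needs the union spelled out. A rough sketch of the shape (the logic here is invented for illustration):

```python
from typing import FrozenSet, Set, Union


def interested_users(from_key: int, everyone: FrozenSet[str]) -> FrozenSet[str]:
    # Without the union, mypy infers Set[str] from the first assignment
    # and rejects the frozenset rebinding in the else branch.
    interested: Union[Set[str], FrozenSet[str]] = set()
    if from_key:
        interested = {"@alice:example.org", "@bob:example.org"}
    else:
        # Everyone is interesting; reuse the immutable set as-is.
        interested = everyone
    return frozenset(interested)


print(sorted(interested_users(0, frozenset({"@carol:example.org"}))))
```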
start_idx = max(from_token + 1 - self._next_id, -len(self._queue)) - to_send = [] # type: List[Tuple[int, Tuple[str, str]]] + to_send: List[Tuple[int, Tuple[str, str]]] = [] limited = False new_id = upto_token for _, stream_id, destinations, user_ids in self._queue[start_idx:]: @@ -2216,7 +2214,7 @@ class PresenceFederationQueue: if not self._federation: return - hosts_to_users = {} # type: Dict[str, Set[str]] + hosts_to_users: Dict[str, Set[str]] = {} for row in rows: hosts_to_users.setdefault(row.destination, set()).add(row.user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 05b4a97b59..20a033d0ba 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler): 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - displayname_to_set = new_displayname # type: Optional[str] + displayname_to_set: Optional[str] = new_displayname if new_displayname == "": displayname_to_set = None @@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - avatar_url_to_set = new_avatar_url # type: Optional[str] + avatar_url_to_set: Optional[str] = new_avatar_url if new_avatar_url == "": avatar_url_to_set = None diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 0059ad0f56..283483fc2c 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler): async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier.""" - min_batch_id = None # type: Optional[int] - max_batch_id = None # type: Optional[int] + min_batch_id: Optional[int] = None + max_batch_id: Optional[int] = None for receipt in receipts: res = await self.store.insert_receipt( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 579b1b93c5..64656fda22 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler): self.config = hs.config # Room state based off defined presets - self._presets_dict = { + self._presets_dict: Dict[str, Dict[str, Any]] = { RoomCreationPreset.PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": HistoryVisibility.SHARED, @@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler): "guest_can_join": False, "power_level_content_override": {}, }, - } # type: Dict[str, Dict[str, Any]] + } # Modify presets to selectively enable encryption by default per homeserver config for preset_name, preset_config in self._presets_dict.items(): @@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler): # If a user tries to update the same room multiple times in quick # succession, only process the first attempt and return its result to # subsequent requests - self._upgrade_response_cache = ResponseCache( + self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) # type: ResponseCache[Tuple[str, str]] + ) self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler): if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") - creation_content = { + creation_content: JsonDict = { "room_version": new_room_version.identifier, 
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, - } # type: JsonDict + } # Check if old room was non-federatable @@ -936,7 +936,7 @@ class RoomCreationHandler(BaseHandler): etype=EventTypes.PowerLevels, content=pl_content ) else: - power_level_content = { + power_level_content: JsonDict = { "users": {creator_id: 100}, "users_default": 0, "events": { @@ -955,7 +955,7 @@ class RoomCreationHandler(BaseHandler): "kick": 50, "redact": 50, "invite": 50, - } # type: JsonDict + } if config["original_invitees_have_ops"]: for invitee in invite_list: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c6bfa5451f..6284bcdfbc 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -47,12 +47,12 @@ class RoomListHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search - self.response_cache = ResponseCache( - hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] - self.remote_response_cache = ResponseCache( - hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 - ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] + self.response_cache: ResponseCache[ + Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] + ] = ResponseCache(hs.get_clock(), "room_list") + self.remote_response_cache: ResponseCache[ + Tuple[str, Optional[int], Optional[str], bool, Optional[str]] + ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000) async def get_local_public_room_list( self, @@ -139,10 +139,10 @@ class RoomListHandler(BaseHandler): if since_token: batch_token = RoomListNextBatch.from_token(since_token) - bounds = ( + bounds: Optional[Tuple[int, str]] = ( batch_token.last_joined_members, batch_token.last_room_id, - ) # type: Optional[Tuple[int, str]] + ) forwards = batch_token.direction_is_forward has_batch_token = True else: @@ -182,7 +182,7 @@ class RoomListHandler(BaseHandler): results = [build_room_entry(r) for r in results] - response = {} # type: JsonDict + response: JsonDict = {} num_results = len(results) if limit is not None: more_to_come = num_results == probing_limit diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 80ba65b9e0..72f54c9403 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -83,7 +83,7 @@ class SamlHandler(BaseHandler): self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] + self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) @@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str: return username -MXID_MAPPER_MAP = { +MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = { "hexencode": map_username_to_mxid_localpart, "dotreplace": dot_replace_for_mxid, -} # type: Dict[str, Callable[[str], str]] +} @attr.s diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 4e718d3f63..8226d6f5a1 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -192,7 +192,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] # type: List[str] + 
historical_room_ids: List[str] = [] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -216,9 +216,9 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] # Holds result of grouping by room, if applicable - room_groups = {} # type: Dict[str, JsonDict] + room_groups: Dict[str, JsonDict] = {} # Holds result of grouping by sender, if applicable - sender_group = {} # type: Dict[str, JsonDict] + sender_group: Dict[str, JsonDict] = {} # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -262,7 +262,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] # type: List[EventBase] + room_events: List[EventBase] = [] i = 0 pagination_token = batch_token diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 366e6211e5..5f7d4602bd 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -90,14 +90,14 @@ class SpaceSummaryHandler: room_queue = deque((_RoomQueueEntry(room_id, ()),)) # rooms we have already processed - processed_rooms = set() # type: Set[str] + processed_rooms: Set[str] = set() # events we have already processed. We don't necessarily have their event ids, # so instead we key on (room id, state key) - processed_events = set() # type: Set[Tuple[str, str]] + processed_events: Set[Tuple[str, str]] = set() - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: queue_entry = room_queue.popleft() @@ -272,10 +272,10 @@ class SpaceSummaryHandler: # the set of rooms that we should not walk further. Initialise it with the # excluded-rooms list; we will add other rooms as we process them so that # we do not loop. 
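Almost every site in this patch, like the grouping dicts just above, is an empty-container initializer, which is exactly where mypy needs help: `{}` and `[]` carry no element type, so without an annotation mypy typically reports "Need type annotation for ...". The annotation is what later `.append` and `.setdefault` calls are checked against. A sketch (the `JsonDict` alias is a stand-in mirroring `synapse.types.JsonDict`):

```python
from typing import Any, Dict, List

JsonDict = Dict[str, Any]  # stand-in for synapse.types.JsonDict

# Empty literals need an explicit type...
room_groups: Dict[str, JsonDict] = {}
event_ids: List[str] = []

# ...which is what makes these operations checkable.
room_groups.setdefault("!room:example.org", {"order": 1, "results": []})
event_ids.append("$someevent:example.org")
print(room_groups, event_ids)
```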
- processed_rooms = set(exclude_rooms) # type: Set[str] + processed_rooms: Set[str] = set(exclude_rooms) - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: room_id = room_queue.popleft() @@ -353,7 +353,7 @@ class SpaceSummaryHandler: max_children = MAX_ROOMS_PER_SPACE now = self._clock.time_msec() - events_result = [] # type: List[JsonDict] + events_result: List[JsonDict] = [] for edge_event in itertools.islice(child_events, max_children): events_result.append( await self._event_serializer.serialize_event( diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0b297e54c4..1b855a685c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -202,10 +202,10 @@ class SsoHandler: self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) # a map from session id to session data - self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {} # map from idp_id to SsoIdentityProvider - self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._identity_providers: Dict[str, SsoIdentityProvider] = {} self._consent_at_registration = hs.config.consent.user_consent_at_registration @@ -296,7 +296,7 @@ class SsoHandler: ) # if the client chose an IdP, use that - idp = None # type: Optional[SsoIdentityProvider] + idp: Optional[SsoIdentityProvider] = None if idp_id: idp = self._identity_providers.get(idp_id) if not idp: @@ -669,9 +669,9 @@ class SsoHandler: remote_user_id, ) - user_id_to_verify = await self._auth_handler.get_session_data( + user_id_to_verify: str = await self._auth_handler.get_session_data( ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) if not user_id: logger.warning( @@ -793,7 +793,7 @@ class SsoHandler: session.use_display_name = use_display_name emails_from_idp = set(session.emails) - filtered_emails = set() # type: Set[str] + filtered_emails: Set[str] = set() # we iterate through the list rather than just building a set conjunction, so # that we can log attempts to use unknown addresses diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 814d08efcb..3fd89af2a4 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -49,7 +49,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -131,10 +131,10 @@ class StatsHandler: mapping from room/user ID to changes in the various fields. 
""" - room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + room_to_stats_deltas: Dict[str, CounterType[str]] = {} + user_to_stats_deltas: Dict[str, CounterType[str]] = {} - room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] + room_to_state_updates: Dict[str, Dict[str, Any]] = {} for delta in deltas: typ = delta["type"] @@ -164,7 +164,7 @@ class StatsHandler: ) continue - event_content = {} # type: JsonDict + event_content: JsonDict = {} if event_id is not None: event = await self.store.get_event(event_id, allow_none=True) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b9a0361059..722c4ae670 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -278,12 +278,14 @@ class SyncHandler: self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) - self.lazy_loaded_members_cache = ExpiringCache( + self.lazy_loaded_members_cache: ExpiringCache[ + Tuple[str, Optional[str]], LruCache[str, str] + ] = ExpiringCache( "lazy_loaded_members_cache", self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, - ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] + ) async def wait_for_sync_for_user( self, @@ -440,7 +442,7 @@ class SyncHandler: ) now_token = now_token.copy_and_replace("typing_key", typing_key) - ephemeral_by_room = {} # type: JsonDict + ephemeral_by_room: JsonDict = {} for event in typing: # we want to exclude the room_id from the event, but modifying the @@ -502,7 +504,7 @@ class SyncHandler: # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline - current_state_ids = frozenset() # type: FrozenSet[str] + current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): current_state_ids_map = await self.store.get_current_state_ids( room_id @@ -783,9 +785,9 @@ class SyncHandler: def get_lazy_loaded_members_cache( self, cache_key: Tuple[str, Optional[str]] ) -> LruCache[str, str]: - cache = self.lazy_loaded_members_cache.get( + cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get( cache_key - ) # type: Optional[LruCache[str, str]] + ) if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) @@ -984,7 +986,7 @@ class SyncHandler: if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} # type: Dict[str, EventBase] + state: Dict[str, EventBase] = {} if state_ids: state = await self.store.get_events(list(state_ids.values())) @@ -1088,8 +1090,8 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts = {} # type: JsonDict - unused_fallback_key_types = [] # type: List[str] + one_time_key_counts: JsonDict = {} + unused_fallback_key_types: List[str] = [] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id @@ -1437,7 +1439,7 @@ class SyncHandler: ) if block_all_room_ephemeral: - ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] + ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, @@ -1468,7 +1470,7 @@ class SyncHandler: # If there is ignored users account data and it matches the proper type, # then use it. 
- ignored_users = frozenset() # type: FrozenSet[str] + ignored_users: FrozenSet[str] = frozenset() if ignored_account_data: ignored_users_data = ignored_account_data.get("ignored_users", {}) if isinstance(ignored_users_data, dict): @@ -1586,7 +1588,7 @@ class SyncHandler: user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] + mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1722,7 +1724,7 @@ class SyncHandler: # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] # type: Optional[List[EventBase]] + batch_events: Optional[List[EventBase]] = [leave_event] else: batch_events = None @@ -1971,7 +1973,7 @@ class SyncHandler: room_id, batch, sync_config, since_token, now_token, full_state=full_state ) - summary = {} # type: Optional[JsonDict] + summary: Optional[JsonDict] = {} # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -1995,7 +1997,7 @@ class SyncHandler: ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, int] + unread_notifications: Dict[str, int] = {} room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c0a8364755..0cb651a400 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,11 +68,11 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} # type: Dict[str, int] + self._room_serials: Dict[str, int] = {} # map room IDs to sets of users currently typing - self._room_typing = {} # type: Dict[str, Set[str]] + self._room_typing: Dict[str, Set[str]] = {} - self._member_last_federation_poke = {} # type: Dict[RoomMember, int] + self._member_last_federation_poke: Dict[RoomMember, int] = {} self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 @@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) # clock time we expect to stop - self._member_typing_until = {} # type: Dict[RoomMember, int] + self._member_typing_until: Dict[RoomMember, int] = {} # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( @@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler): if last_id == current_id: return [], current_id, False - changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id - ) # type: Optional[Iterable[str]] + changed_rooms: Optional[ + Iterable[str] + ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) if changed_rooms is None: changed_rooms = self._room_serials diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index dacc4f3076..6edb1da50a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.search_all_users = hs.config.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas 
one at a time self._is_processing = False diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3c51a742bf..40ee33646c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): # Get the room ID from the identifier. try: - remote_room_hosts = [ + remote_room_hosts: Optional[List[str]] = [ x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] + ] except Exception: remote_room_hosts = None room_id, remote_room_hosts = await self.resolve_room_id( @@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 06e6ccee42..589e47fa47 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth_handler = hs.get_auth_handler() self.reactor = hs.get_reactor() - self.nonces = {} # type: Dict[str, int] + self.nonces: Dict[str, int] = {} self.hs = hs def _clear_old_nonces(self): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cbcb60fe31..99d02cb355 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -121,7 +121,7 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = { + sso_flow: JsonDict = { "type": LoginRestServlet.SSO_TYPE, "identity_providers": [ _get_auth_flow_dict_for_idp( @@ -129,7 +129,7 @@ class LoginRestServlet(RestServlet): ) for idp in self._sso_handler.get_identity_providers().values() ], - } # type: JsonDict + } if self._msc2858_enabled: # backwards-compatibility support for clients which don't @@ -447,7 +447,7 @@ def _get_auth_flow_dict_for_idp( use_unstable_brands: whether we should use brand identifiers suitable for the unstable API """ - e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: @@ -561,7 +561,7 @@ class SsoRedirectServlet(RestServlet): finish_request(request) return - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) sso_url = await self._sso_handler.handle_redirect_request( request, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index ebf4e32230..31a1193cd3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", 
encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index f8dcee603c..d537d811d8 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response = (200, {}) # type: Tuple[int, dict] + response: Tuple[int, dict] = (200, {}) return response diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index e52570cd8e..4282e2b228 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource): has_consented = False public_version = username == "" if not public_version: - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac_bytes = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac_bytes) @@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource): """ version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index d56a1ae482..63a40b1852 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource): async def _async_render_GET(self, request): if len(request.postpath) == 1: (server,) = request.postpath - query = {server.decode("ascii"): {}} # type: dict + query: dict = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") @@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() # Note that the value is unused. - cache_misses = {} # type: Dict[str, Dict[str, int]] + cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id, _), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 0fb4cd81f1..90364ebcf7 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # The type on postpath seems incorrect in Twisted 21.2.0. 
- postpath = request.postpath # type: List[bytes] # type: ignore + postpath: List[bytes] = request.postpath # type: ignore assert postpath # This allows users to append e.g. /test.png to the URL. Useful for diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 21c43c340c..4f702f890c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -78,16 +78,16 @@ class MediaRepository: Thumbnailer.set_limits(self.max_image_pixels) - self.primary_base_path = hs.config.media_store_path # type: str - self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths + self.primary_base_path: str = hs.config.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] - self.recently_accessed_locals = set() # type: Set[str] + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -711,7 +711,7 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. - thumbnails = {} # type: Dict[Tuple[int, int, str], str] + thumbnails: Dict[Tuple[int, int, str], str] = {} for r_width, r_height, r_method, r_type in requirements: if r_method == "crop": thumbnails.setdefault((r_width, r_height, r_type), r_method) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index c7fd97c46c..56cdc1b4ed 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -191,7 +191,7 @@ class MediaStorage: for provider in self.storage_providers: for path in paths: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -233,7 +233,7 @@ class MediaStorage: os.makedirs(dirname) for provider in self.storage_providers: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: with res: consumer = BackgroundFileConsumer( diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0adfb1a70f..8e7fead3a2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource): # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata - self._cache = ExpiringCache( + self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, - ) # type: ExpiringCache[str, ObservableDeferred] + ) if self._worker_run_media_background_jobs: self._cleaner_loop = self.clock.looping_call( @@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. 
- url_to_download = url # type: Optional[str] + url_to_download: Optional[str] = url oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # "og:video:height" : "720", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - og = {} # type: Dict[str, Optional[str]] + og: Dict[str, Optional[str]] = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: # if we've got more than 50 tags, someone is taking the piss diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 62dc4aae2d..146adca8f1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource): errcode=Codes.TOO_LARGE, ) - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: try: - upload_name = upload_name_bytes.decode("utf8") # type: Optional[str] + upload_name: Optional[str] = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 @@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource): # TODO(markjh): parse content-dispostion try: - content = request.content # type: IO # type: ignore + content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( media_type, upload_name, content, content_length, requester.user ) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 9b002cc15e..ab24ec0a8e 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource): use_display_name = parse_boolean(request, "use_display_name", default=False) try: - emails_to_use = [ + emails_to_use: List[str] = [ val.decode("utf-8") for val in request.args.get(b"use_email", []) - ] # type: List[str] + ] except ValueError: raise SynapseError(400, "Query parameter use_email must be utf-8") except SynapseError as e: -- cgit 1.4.1 From 95e47b2e782b5e7afa5fd2afd1d0ea7745eaac36 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 19 Jul 2021 16:28:05 +0200 Subject: [pyupgrade] `synapse/` (#10348) This PR is tantamount to running ``` pyupgrade --py36-plus --keep-percent-format `find synapse/ -type f -name "*.py"` ``` Part of #9744 --- changelog.d/10348.misc | 1 + synapse/app/generic_worker.py | 6 ++-- synapse/app/homeserver.py | 6 ++-- synapse/config/appservice.py | 2 +- synapse/config/tls.py | 6 ++-- synapse/handlers/cas.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 4 +-- synapse/handlers/oidc.py | 38 ++++++++++++++------------ synapse/handlers/register.py | 15 ++++------ synapse/handlers/saml.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/proxyagent.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/opentracing.py | 2 +- synapse/metrics/_exposition.py | 26 ++++++++---------- synapse/metrics/background_process_metrics.py | 3 +- synapse/rest/client/v1/login.py | 25 ++++++----------- synapse/rest/media/v1/__init__.py | 4 +-- synapse/storage/database.py | 2 +- 
synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/group_server.py | 6 +++- synapse/storage/databases/main/roommember.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/types.py | 4 +-- synapse/util/caches/lrucache.py | 3 +- synapse/util/caches/treecache.py | 3 +- synapse/util/daemonize.py | 8 +++--- synapse/visibility.py | 4 +-- 29 files changed, 86 insertions(+), 102 deletions(-) create mode 100644 changelog.d/10348.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/10348.misc b/changelog.d/10348.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10348.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. \ No newline at end of file diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b43d858f59..c3d4992518 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7af56ac136..920b34d97b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index a39d457c56..1ebea88db2 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -64,7 +64,7 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, "r") as f: + with open(config_file) as f: appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fed05ac7be..5679f05e42 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -66,10 +66,8 @@ class TlsConfig(Config): if self.federation_client_minimum_tls_version == "1.3": if getattr(SSL, "OP_NO_TLSv1_3", None) is None: raise ConfigError( - ( - "federation_client_minimum_tls_version cannot be 1.3, " - "your OpenSSL does not support it" - ) + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" ) # Whitelist of domains to not verify certificates for diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index b681d208bc..0325f86e20 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -40,7 +40,7 @@ class CasError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5c4463583e..cf389be3e4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -735,7 +735,7 @@ class FederationHandler(BaseHandler): # we need to make sure we re-load from the database to get the rejected # state 
correct. fetched_events.update( - (await self.store.get_events(missing_desired_events, allow_rejected=True)) + await self.store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 33d16fbf9c..0961dec5ab 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler): ) url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") + url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" content = { "mxid": mxid, @@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler): return data["mxid"] except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - except IOError as e: + except OSError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) return None diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index a330c48fa7..eca8f16040 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -72,26 +72,26 @@ _SESSION_COOKIES = [ (b"oidc_session_no_samesite", b"HttpOnly"), ] + #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. -Token = TypedDict( - "Token", - { - "access_token": str, - "token_type": str, - "id_token": Optional[str], - "refresh_token": Optional[str], - "expires_in": int, - "scope": Optional[str], - }, -) +class Token(TypedDict): + access_token: str + token_type: str + id_token: Optional[str] + refresh_token: Optional[str] + expires_in: int + scope: Optional[str] + #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but #: there is no real point of doing this in our case. JWK = Dict[str, str] + #: A JWK Set, as per RFC7517 sec 5. -JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class JWKS(TypedDict): + keys: List[JWK] class OidcHandler: @@ -255,7 +255,7 @@ class OidcError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error @@ -639,7 +639,7 @@ class OidcProvider: ) logger.warning(description) # Body was still valid JSON. Might be useful to log it for debugging. 
- logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + logger.warning("Code exchange response: %r", resp) raise OidcError("server_error", description) return resp @@ -1217,10 +1217,12 @@ class OidcSessionData: ui_auth_session_id = attr.ib(type=str) -UserAttributeDict = TypedDict( - "UserAttributeDict", - {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, -) +class UserAttributeDict(TypedDict): + localpart: Optional[str] + display_name: Optional[str] + emails: List[str] + + C = TypeVar("C") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 056fe5e89f..8cf614136e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -55,15 +55,12 @@ login_counter = Counter( ["guest", "auth_provider"], ) -LoginDict = TypedDict( - "LoginDict", - { - "device_id": str, - "access_token": str, - "valid_until_ms": Optional[int], - "refresh_token": Optional[str], - }, -) + +class LoginDict(TypedDict): + device_id: str + access_token: str + valid_until_ms: Optional[int] + refresh_token: Optional[str] class RegistrationHandler(BaseHandler): diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 72f54c9403..e6e71e9729 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -372,7 +372,7 @@ class SamlHandler(BaseHandler): DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) + "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),) ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 722c4ae670..150a4f291e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1601,7 +1601,7 @@ class SyncHandler: logger.debug( "Membership changes in %s: [%s]", room_id, - ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)), + ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events), ) non_joins = [e for e in events if e.membership != Membership.JOIN] diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7a6a1717de..f7193e60bd 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase): """ uri = uri.strip() if not _VALID_URI.match(uri): - raise ValueError("Invalid URI {!r}".format(uri)) + raise ValueError(f"Invalid URI {uri!r}") parsed_uri = URI.fromBytes(uri) pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) diff --git a/synapse/http/site.py b/synapse/http/site.py index 3b0a38124e..190084e8aa 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -384,7 +384,7 @@ class SynapseRequest(Request): # authenticated (e.g. and admin is puppetting a user) then we log both. 
requester, authenticated_entity = self.get_authenticated_entity() if authenticated_entity: - requester = "{}.{}".format(authenticated_entity, requester) + requester = f"{authenticated_entity}.{requester}" self.site.access_logger.log( log_level, diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 185844f188..ecd51f1b4a 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"): config = JaegerConfig( config=hs.config.jaeger_config, - service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + service_name=f"{hs.config.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), ) diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 7e49d0d02c..bb9bcb5592 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -34,7 +34,7 @@ from twisted.web.resource import Resource from synapse.util import caches -CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" INF = float("inf") @@ -55,8 +55,8 @@ def floatToGoString(d): # Go switches to exponents sooner than Python. # We only need to care about positive values for le/quantile. if d > 0 and dot > 6: - mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") - return "{0}e+0{1}".format(mantissa, dot - 1) + mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") + return f"{mantissa}e+0{dot - 1}" return s @@ -65,7 +65,7 @@ def sample_line(line, name): labelstr = "{{{0}}}".format( ",".join( [ - '{0}="{1}"'.format( + '{}="{}"'.format( k, v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), ) @@ -78,10 +78,8 @@ def sample_line(line, name): timestamp = "" if line.timestamp is not None: # Convert to milliseconds. - timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) - return "{0}{1} {2}{3}\n".format( - name, labelstr, floatToGoString(line.value), timestamp - ) + timestamp = f" {int(float(line.timestamp) * 1000):d}" + return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) def generate_latest(registry, emit_help=False): @@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False): # Output in the old format for compatibility. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mname, mtype)) + output.append(f"# TYPE {mname} {mtype}\n") om_samples: Dict[str, List[str]] = {} for s in metric.samples: @@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False): for suffix, lines in sorted(om_samples.items()): if emit_help: output.append( - "# HELP {0}{1} {2}\n".format( + "# HELP {}{} {}\n".format( metric.name, suffix, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.append(f"# TYPE {metric.name}{suffix} gauge\n") output.extend(lines) # Get rid of the weird colon things while we're at it @@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False): # Also output in the new format, if it's different. 
if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mnewname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + output.append(f"# TYPE {mnewname} {mtype}\n") for s in metric.samples: # Get rid of the OpenMetrics specific samples (we should already have diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 4455fa71a8..3a14260752 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -137,8 +137,7 @@ class _Collector: _background_process_db_txn_duration, _background_process_db_sched_duration, ): - for r in m.collect(): - yield r + yield from m.collect() REGISTRY.register(_Collector()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 99d02cb355..11567bf32c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,19 +44,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -LoginResponse = TypedDict( - "LoginResponse", - { - "user_id": str, - "access_token": str, - "home_server": str, - "expires_in_ms": Optional[int], - "refresh_token": Optional[str], - "device_id": str, - "well_known": Optional[Dict[str, Any]], - }, - total=False, -) +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] class LoginRestServlet(RestServlet): @@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend( - ({"type": t} for t in self.auth_handler.get_supported_login_types()) - ) + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py index d20186bbd0..3dd16d4bb5 100644 --- a/synapse/rest/media/v1/__init__.py +++ b/synapse/rest/media/v1/__init__.py @@ -17,7 +17,7 @@ import PIL.Image # check for JPEG support. try: PIL.Image._getdecoder("rgb", "jpeg", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder jpeg not available"): raise Exception( "FATAL: jpeg codec not supported. Install pillow correctly! " @@ -32,7 +32,7 @@ except Exception: # check for PNG support. try: PIL.Image._getdecoder("rgb", "zip", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder zip not available"): raise Exception( "FATAL: zip codec not supported. Install pillow correctly! " diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f80d822c12..ccf9ac51ef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -907,7 +907,7 @@ class DatabasePool: # The sort is to ensure that we don't rely on dictionary iteration # order. 
keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) ) for k in keys: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 50e7ddd735..c55508867d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_messages_for_device", delete_messages_for_device_txn ) - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) + log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 66ad363bfb..e70d3649ff 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -27,8 +27,11 @@ from synapse.util import json_encoder _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" + # A room in a group. -_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool}) +class _RoomInGroup(TypedDict): + room_id: str + is_public: bool class GroupServerWorkerStore(SQLBaseStore): @@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore): "is_public": False # Whether this is a public room or not } """ + # TODO: Pagination def _get_rooms_in_group_txn(txn): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4d82c4c26d..68f1b40ea6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) - users_in_room.update((row for row in event_to_memberships.values() if row)) + users_in_room.update(row for row in event_to_memberships.values() if row) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 82a7686df0..61392b9639 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]: def executescript(txn: Cursor, schema_path: str) -> None: - with open(schema_path, "r") as f: + with open(schema_path) as f: execute_statements_from_stream(txn, f) diff --git a/synapse/types.py b/synapse/types.py index fad23c8700..429bb013d2 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -577,10 +577,10 @@ class RoomStreamToken: entries = [] for name, pos in self.instance_map.items(): instance_id = await store.get_id_for_instance(name) - entries.append("{}.{}".format(instance_id, pos)) + entries.append(f"{instance_id}.{pos}") encoded_map = "~".join(entries) - return "m{}~{}".format(self.stream, encoded_map) + return f"m{self.stream}~{encoded_map}" else: return "s%d" % (self.stream,) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index efeba0cb96..5c65d187b6 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -90,8 +90,7 @@ def enumerate_leaves(node, depth): yield node else: for n in node.values(): - for m in 
enumerate_leaves(n, depth - 1): - yield m + yield from enumerate_leaves(n, depth - 1) P = TypeVar("P") diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index a6df81ebff..4138931e7b 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d): """ if isinstance(d, TreeCacheNode): for value_d in d.values(): - for value in iterate_tree_cache_entry(value_d): - yield value + yield from iterate_tree_cache_entry(value_d) else: yield d diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index 31b24dd188..d8532411c2 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # If pidfile already exists, we should read pid from there; to overwrite it, if # locking will fail, because locking attempt somehow purges the file contents. if os.path.isfile(pid_file): - with open(pid_file, "r") as pid_fh: + with open(pid_file) as pid_fh: old_pid = pid_fh.read() # Create a lockfile so that only one instance of this daemon is running at any time. try: lock_fh = open(pid_file, "w") - except IOError: + except OSError: print("Unable to create the pidfile.") sys.exit(1) @@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # Try to get an exclusive lock on the file. This will fail if another process # has the file locked. fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) - except IOError: + except OSError: print("Unable to lock on the pidfile.") # We need to overwrite the pidfile if we got here. # @@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - try: lock_fh.write("%s" % (os.getpid())) lock_fh.flush() - except IOError: + except OSError: logger.error("Unable to write pid to the pidfile.") print("Unable to write pid to the pidfile.") sys.exit(1) diff --git a/synapse/visibility.py b/synapse/visibility.py index 1dc6b90275..17532059e9 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -96,7 +96,7 @@ async def filter_events_for_client( if isinstance(ignored_users_dict, dict): ignore_list = frozenset(ignored_users_dict.keys()) - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -353,7 +353,7 @@ async def filter_events_for_server( ) if not check_history_visibility_only: - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. -- cgit 1.4.1
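A few notes on the recurring patterns in the pyupgrade patch above, each with a small self-contained sketch; the snippets below are illustrative and not part of the Synapse tree.

The bulk of the diff converts `# type:` comments to PEP 526 inline annotations. Both spellings mean the same thing to mypy, but the inline form no longer competes with other trailing comments (the stacked `# type: List[bytes] # type: ignore` lines become an annotation plus a plain `# type: ignore`) and is visible at runtime:

```python
from typing import Dict, Tuple

# Old style: a comment that mypy parses but the interpreter ignores.
thumbnails = {}  # type: Dict[Tuple[int, int, str], str]

# New style: a PEP 526 inline annotation, also recorded at runtime.
thumbnails: Dict[Tuple[int, int, str], str] = {}

print(__annotations__["thumbnails"])  # the annotated type is introspectable
```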
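The `TypedDict` definitions (`Token`, `JWKS`, `LoginDict`, `LoginResponse`, `_RoomInGroup`) move from the functional form to the class-based form. The two are equivalent at runtime; the class body is simply easier to read and to diff. A sketch, with the import hedged: on Python 3.8+ `TypedDict` lives in `typing`, while older interpreters need the `typing_extensions` backport (which import Synapse itself uses is not shown in this diff):

```python
try:
    from typing import TypedDict  # Python 3.8+
except ImportError:
    from typing_extensions import TypedDict  # backport for 3.6/3.7

# Functional form, as the code read before the patch:
LoginDictOld = TypedDict("LoginDictOld", {"device_id": str, "access_token": str})

# Class-based form, as rewritten:
class LoginDict(TypedDict):
    device_id: str
    access_token: str

creds: LoginDict = {"device_id": "ABCDEF", "access_token": "syt_..."}
```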
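The string-formatting churn follows from the flags quoted in the commit message: `--py36-plus` lets pyupgrade turn simple `str.format()` calls into f-strings, while `--keep-percent-format` leaves `%`-interpolation alone, which is why expressions like `"%s (%s)" % (e.event_id, e.membership)` in `sync.py` survive untouched:

```python
error, description = "access_denied", "user rejected the request"

before = "{}: {}".format(error, description)  # rewritten by pyupgrade ...
after = f"{error}: {description}"             # ... into an f-string

kept = "%s: %s" % (error, description)        # left alone: --keep-percent-format
assert before == after == kept
```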
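One change in `oidc.py` is more than a spelling update: `"... {resp!r}".format(resp=resp)` inside a `logger.warning(...)` call becomes `logger.warning("... %r", resp)`, passing the argument through to the logging framework. That defers the (potentially expensive) `repr` and interpolation until a handler actually accepts the record:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
resp = {"error": "invalid_grant"}

# Eager: the f-string is rendered even though WARNING is filtered out here.
logger.warning(f"Code exchange response: {resp!r}")

# Lazy: %-interpolation only happens if the record is actually emitted.
logger.warning("Code exchange response: %r", resp)
```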
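The generator rewrites in `background_process_metrics.py`, `lrucache.py`, and `treecache.py` all replace an explicit `for ...: yield ...` loop with `yield from` (PEP 380), which delegates to the inner iterator directly. Taking `enumerate_leaves` as it reads after the patch, with a toy nested dict to exercise it:

```python
def enumerate_leaves(node, depth):
    if depth == 0:
        yield node
    else:
        for n in node.values():
            yield from enumerate_leaves(n, depth - 1)

tree = {"a": {"x": 1, "y": 2}, "b": {"z": 3}}
print(list(enumerate_leaves(tree, depth=2)))  # [1, 2, 3]
```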
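Finally, two of the smallest substitutions lean on guarantees that have held since early Python 3: `IOError` has been a mere alias of `OSError` since 3.3 (PEP 3151), and `open()` defaults to mode "r", so `except IOError` → `except OSError` and `open(path, "r")` → `open(path)` change nothing at runtime:

```python
assert IOError is OSError  # alias since Python 3.3 (PEP 3151)

try:
    with open("/no/such/path") as f:  # text-mode "r" is the default
        f.read()
except OSError as exc:  # also catches what old code spelled as IOError
    print(type(exc).__name__)  # FileNotFoundError, a subclass of OSError
```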