From 7695ca06187bb6742ed74c5ae060c48a08af99ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Jul 2021 10:35:46 +0100 Subject: Fix a number of logged errors caused by remote servers being down. (#10400) --- synapse/handlers/federation.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) (limited to 'synapse/handlers/federation.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 991ec9919a..0209aee186 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1414,12 +1414,15 @@ class FederationHandler(BaseHandler): Invites must be signed by the invitee's server before distribution. """ - pdu = await self.federation_client.send_invite( - destination=target_host, - room_id=event.room_id, - event_id=event.event_id, - pdu=event, - ) + try: + pdu = await self.federation_client.send_invite( + destination=target_host, + room_id=event.room_id, + event_id=event.event_id, + pdu=event, + ) + except RequestSendFailed: + raise SynapseError(502, f"Can't connect to server {target_host}") return pdu @@ -3031,9 +3034,13 @@ class FederationHandler(BaseHandler): await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - await self.federation_client.forward_third_party_invite( - destinations, room_id, event_dict - ) + + try: + await self.federation_client.forward_third_party_invite( + destinations, room_id, event_dict + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to forward third party invite") async def on_exchange_third_party_invite_request( self, event_dict: JsonDict -- cgit 1.4.1 From 98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Fri, 16 Jul 2021 19:22:36 +0200 Subject: Use inline type hints in `handlers/` and `rest/`. (#10382) --- changelog.d/10382.misc | 1 + synapse/handlers/_base.py | 8 +++--- synapse/handlers/admin.py | 4 +-- synapse/handlers/appservice.py | 6 ++-- synapse/handlers/auth.py | 16 +++++------ synapse/handlers/cas.py | 2 +- synapse/handlers/device.py | 14 +++++----- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/directory.py | 6 ++-- synapse/handlers/e2e_keys.py | 40 +++++++++++++-------------- synapse/handlers/events.py | 6 ++-- synapse/handlers/federation.py | 22 +++++++-------- synapse/handlers/groups_local.py | 4 +-- synapse/handlers/initial_sync.py | 14 ++++++++-- synapse/handlers/message.py | 18 +++++------- synapse/handlers/oidc.py | 18 ++++++------ synapse/handlers/pagination.py | 4 +-- synapse/handlers/presence.py | 28 +++++++++---------- synapse/handlers/profile.py | 4 +-- synapse/handlers/receipts.py | 4 +-- synapse/handlers/room.py | 16 +++++------ synapse/handlers/room_list.py | 18 ++++++------ synapse/handlers/saml.py | 6 ++-- synapse/handlers/search.py | 8 +++--- synapse/handlers/space_summary.py | 16 +++++------ synapse/handlers/sso.py | 12 ++++---- synapse/handlers/stats.py | 10 +++---- synapse/handlers/sync.py | 32 +++++++++++---------- synapse/handlers/typing.py | 14 +++++----- synapse/handlers/user_directory.py | 2 +- synapse/rest/admin/rooms.py | 8 ++---- synapse/rest/admin/users.py | 2 +- synapse/rest/client/v1/login.py | 8 +++--- synapse/rest/client/v1/room.py | 10 ++----- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/consent/consent_resource.py | 4 +-- synapse/rest/key/v2/remote_key_resource.py | 4 +-- synapse/rest/media/v1/_base.py | 2 +- synapse/rest/media/v1/media_repository.py | 10 +++---- synapse/rest/media/v1/media_storage.py | 4 +-- synapse/rest/media/v1/preview_url_resource.py | 8 +++--- synapse/rest/media/v1/upload_resource.py | 6 ++-- synapse/rest/synapse/client/pick_username.py | 4 +-- 43 files changed, 212 insertions(+), 215 deletions(-) create mode 100644 changelog.d/10382.misc (limited to 'synapse/handlers/federation.py') diff --git a/changelog.d/10382.misc b/changelog.d/10382.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10382.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. \ No newline at end of file diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d800e16912..525f3d39b1 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -38,10 +38,10 @@ class BaseHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() # type: synapse.storage.DataStore + self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler + self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs @@ -55,12 +55,12 @@ class BaseHandler: # Check whether ratelimiting room admin message redaction is enabled # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: - self.admin_redaction_ratelimiter = Ratelimiter( + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) # type: Optional[Ratelimiter] + ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d75a8b15c3..bfa7f2c545 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -139,7 +139,7 @@ class AdminHandler(BaseHandler): to_key = RoomStreamToken(None, stream_ordering) # Events that we've processed in this room - written_events = set() # type: Set[str] + written_events: Set[str] = set() # We need to track gaps in the events stream so that we can then # write out the state at those events. We do this by keeping track @@ -152,7 +152,7 @@ class AdminHandler(BaseHandler): # The reverse mapping to above, i.e. map from unseen event to events # that have the unseen event in their prev_events, i.e. the unseen # events "children". - unseen_to_child_events = {} # type: Dict[str, Set[str]] + unseen_to_child_events: Dict[str, Set[str]] = {} # We fetch events in the room the user could see by fetching *all* # events that we have and then filtering, this isn't the most diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 862638cc4f..21a17cd2e8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -96,7 +96,7 @@ class ApplicationServicesHandler: self.current_max, limit ) - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -275,7 +275,7 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] ) -> List[JsonDict]: - events = [] # type: List[JsonDict] + events: List[JsonDict] = [] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" @@ -375,7 +375,7 @@ class ApplicationServicesHandler: self, only_protocol: Optional[str] = None ) -> Dict[str, JsonDict]: services = self.store.get_app_services() - protocols = {} # type: Dict[str, List[JsonDict]] + protocols: Dict[str, List[JsonDict]] = {} # Collect up all the individual protocol responses out of the ASes for s in services: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e2ac595a62..22a8552241 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -191,7 +191,7 @@ class AuthHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] + self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): @@ -296,7 +296,7 @@ class AuthHandler(BaseHandler): # A mapping of user ID to extra attributes to include in the login # response. - self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {} async def validate_user_via_ui_auth( self, @@ -500,7 +500,7 @@ class AuthHandler(BaseHandler): all the stages in any of the permitted flows. """ - sid = None # type: Optional[str] + sid: Optional[str] = None authdict = clientdict.pop("auth", {}) if "session" in authdict: sid = authdict["session"] @@ -588,9 +588,9 @@ class AuthHandler(BaseHandler): ) # check auth type currently being presented - errordict = {} # type: Dict[str, Any] + errordict: Dict[str, Any] = {} if "type" in authdict: - login_type = authdict["type"] # type: str + login_type: str = authdict["type"] try: result = await self._check_auth_dict(authdict, clientip) if result: @@ -766,7 +766,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} # type: Dict[str, Any] + params: Dict[str, Any] = {} for f in public_flows: for stage in f: @@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler): except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - user_id_to_verify = await self.get_session_data( + user_id_to_verify: str = await self.get_session_data( session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) idps = await self.hs.get_sso_handler().get_identity_providers_for_user( user_id_to_verify diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 7346ccfe93..b681d208bc 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -171,7 +171,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} # type: Dict[str, List[Optional[str]]] + attributes: Dict[str, List[Optional[str]]] = {} for child in root[0]: if child.tag.endswith("user"): user = child.text diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 95bdc5902a..46ee834407 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id ) - hosts = set() # type: Set[str] + hosts: Set[str] = set() if self.hs.is_mine_id(user_id): hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) @@ -613,20 +613,20 @@ class DeviceListUpdater: self._remote_edu_linearizer = Linearizer(name="remote_device_list") # user_id -> list of updates waiting to be handled. - self._pending_updates = ( - {} - ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] + self._pending_updates: Dict[ + str, List[Tuple[str, str, Iterable[str], JsonDict]] + ] = {} # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious # resyncs. - self._seen_updates = ExpiringCache( + self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache( cache_name="device_update_edu", clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, iterable=True, - ) # type: ExpiringCache[str, Set[str]] + ) # Attempt to resync out of sync device lists every 30s. self._resync_retry_in_progress = False @@ -755,7 +755,7 @@ class DeviceListUpdater: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ - seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] + seen_updates: Set[str] = self._seen_updates.get(user_id, set()) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 580b941595..679b47f081 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -203,7 +203,7 @@ class DeviceMessageHandler: log_kv({"number_of_to_device_messages": len(messages)}) set_tag("sender", sender_user_id) local_messages = {} - remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # Ratelimit local cross-user key requests by the sending device. if ( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 06d7012bac..d487fee627 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler): async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result = await self.get_association_from_room_alias( - room_alias - ) # type: Optional[RoomAliasMapping] + result: Optional[ + RoomAliasMapping + ] = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 3972849d4d..d92370859f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -115,9 +115,9 @@ class E2eKeysHandler: the number of in-flight queries at a time. """ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query = query_body.get( + device_keys_query: Dict[str, Iterable[str]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Iterable[str]] + ) # separate users by domain. # make a map from domain to user_id to device_ids @@ -136,7 +136,7 @@ class E2eKeysHandler: # First get local devices. # A map of destination -> failure response. - failures = {} # type: Dict[str, JsonDict] + failures: Dict[str, JsonDict] = {} results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -151,11 +151,9 @@ class E2eKeysHandler: # Now attempt to get any remote devices from our local cache. # A map of destination -> user ID -> device IDs. - remote_queries_not_in_cache = ( - {} - ) # type: Dict[str, Dict[str, Iterable[str]]] + remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {} if remote_queries: - query_list = [] # type: List[Tuple[str, Optional[str]]] + query_list: List[Tuple[str, Optional[str]]] = [] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend( @@ -362,9 +360,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] # type: List[Tuple[str, Optional[str]]] + local_query: List[Tuple[str, Optional[str]]] = [] - result_dict = {} # type: Dict[str, Dict[str, dict]] + result_dict: Dict[str, Dict[str, dict]] = {} for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -402,9 +400,9 @@ class E2eKeysHandler: self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: """Handle a device key query from a federated server""" - device_keys_query = query_body.get( + device_keys_query: Dict[str, Optional[List[str]]] = query_body.get( "device_keys", {} - ) # type: Dict[str, Optional[List[str]]] + ) res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -421,8 +419,8 @@ class E2eKeysHandler: async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int ) -> JsonDict: - local_query = [] # type: List[Tuple[str, str, str]] - remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] + local_query: List[Tuple[str, str, str]] = [] + remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids @@ -439,8 +437,8 @@ class E2eKeysHandler: results = await self.store.claim_e2e_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. - json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] - failures = {} # type: Dict[str, JsonDict] + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + failures: Dict[str, JsonDict] = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): @@ -768,8 +766,8 @@ class E2eKeysHandler: Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -930,8 +928,8 @@ class E2eKeysHandler: Raises: SynapseError: if the input is malformed """ - signature_list = [] # type: List[SignatureListItem] - failures = {} # type: Dict[str, Dict[str, JsonDict]] + signature_list: List["SignatureListItem"] = [] + failures: Dict[str, Dict[str, JsonDict]] = {} if not signatures: return signature_list, failures @@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] + self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {} async def incoming_signing_key_update( self, origin: str, edu_content: JsonDict @@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] # type: List[str] + device_ids: List[str] = [] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f134f1e234..4b3f037072 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler): # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] # type: List[JsonDict] + to_add: List[JsonDict] = [] for event in events: if not isinstance(event, EventBase): continue @@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.store.get_users_in_room( + users: Iterable[str] = await self.store.get_users_in_room( event.room_id - ) # type: Iterable[str] + ) else: users = [event.state_key] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0209aee186..5c4463583e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -181,7 +181,7 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. - self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] + self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") self._room_backfill = Linearizer("room_backfill") @@ -368,7 +368,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list(ours.values()) # type: List[StateMap[str]] + state_maps: List[StateMap[str]] = list(ours.values()) # we don't need this any more, let's delete it. del ours @@ -845,7 +845,7 @@ class FederationHandler(BaseHandler): # exact key to expect. Otherwise check it matches any key we # have for that device. - current_keys = [] # type: Container[str] + current_keys: Container[str] = [] if device: keys = device.get("keys", {}).get("keys", {}) @@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler): if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} # type: Dict[str, int] + joined_domains: Dict[str, int] = {} for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version(room_id) - event_map = {} # type: Dict[str, EventBase] + event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): @@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler): # Ask the remote server to create a valid knock event for us. Once received, # we sign the event - params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + params: Dict[str, Iterable[str]] = {"ver": supported_room_versions} origin, event, event_format_version = await self._make_and_verify_event( target_hosts, room_id, knockee, Membership.KNOCK, content, params=params ) @@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler): state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] + state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = { + current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() - } # type: StateMap[str] + } else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler): """ # exclude the state key of the new event from the current_state in the context. if event.is_state(): - event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) else: event_key = None state_updates = { @@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler): logger.debug("Checking auth on event %r", event.content) - last_exception = None # type: Optional[Exception] + last_exception: Optional[Exception] = None # for each public key in the 3pid invite event for public_key_object in event_auth.get_public_keys(invite_event): diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 157f2ff218..1a6c5c64a2 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler: async def bulk_get_publicised_groups( self, user_ids: Iterable[str], proxy: bool = True ) -> JsonDict: - destinations = {} # type: Dict[str, Set[str]] + destinations: Dict[str, Set[str]] = {} local_users = set() for user_id in user_ids: @@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] # type: List[str] + failed_results: List[str] = [] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 76242865ae..5d49640760 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache( - hs.get_clock(), "initial_sync_cache" - ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] + self.snapshot_cache: ResponseCache[ + Tuple[ + str, + Optional[StreamToken], + Optional[StreamToken], + str, + Optional[int], + bool, + bool, + ] + ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e06655f3d4..c7fe4ff89e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ class MessageHandler: # The scheduled call to self._expire_event. None if no call is currently # scheduled. - self._scheduled_expiry = None # type: Optional[IDelayedCall] + self._scheduled_expiry: Optional[IDelayedCall] = None if not hs.config.worker_app: run_as_background_process( @@ -196,9 +196,7 @@ class MessageHandler: room_state_events = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) - room_state = room_state_events[ - event.event_id - ] # type: Mapping[Any, EventBase] + room_state: Mapping[Any, EventBase] = room_state_events[event.event_id] else: raise AuthError( 403, @@ -421,9 +419,9 @@ class EventCreationHandler: self.action_generator = hs.get_action_generator() self.spam_checker = hs.get_spam_checker() - self.third_party_event_rules = ( + self.third_party_event_rules: "ThirdPartyEventRules" = ( self.hs.get_third_party_event_rules() - ) # type: ThirdPartyEventRules + ) self._block_events_without_consent_error = ( self.config.block_events_without_consent_error @@ -440,7 +438,7 @@ class EventCreationHandler: # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. self._dummy_events_threshold = hs.config.dummy_events_threshold @@ -465,9 +463,7 @@ class EventCreationHandler: # Stores the state groups we've recently added to the joined hosts # external cache. Note that the timeout must be significantly less than # the TTL on the external cache. - self._external_cache_joined_hosts_updates = ( - None - ) # type: Optional[ExpiringCache] + self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None if self._external_cache.is_enabled(): self._external_cache_joined_hosts_updates = ExpiringCache( "_external_cache_joined_hosts_updates", @@ -1299,7 +1295,7 @@ class EventCreationHandler: # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = [] # type: List[str] + original_alt_aliases: List[str] = [] original_event_id = event.unsigned.get("replaces_state") if original_event_id: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index ee6e41c0e4..a330c48fa7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -105,9 +105,9 @@ class OidcHandler: assert provider_confs self._token_generator = OidcSessionTokenGenerator(hs) - self._providers = { + self._providers: Dict[str, "OidcProvider"] = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } # type: Dict[str, OidcProvider] + } async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. @@ -178,7 +178,7 @@ class OidcHandler: # are two. for cookie_name, _ in _SESSION_COOKIES: - session = request.getCookie(cookie_name) # type: Optional[bytes] + session: Optional[bytes] = request.getCookie(cookie_name) if session is not None: break else: @@ -277,7 +277,7 @@ class OidcProvider: self._token_generator = token_generator self._config = provider - self._callback_url = hs.config.oidc_callback_url # type: str + self._callback_url: str = hs.config.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. @@ -290,7 +290,7 @@ class OidcProvider: self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method - client_secret = None # type: Union[None, str, JwtClientSecret] + client_secret: Optional[Union[str, JwtClientSecret]] = None if provider.client_secret: client_secret = provider.client_secret elif provider.client_secret_jwt_key: @@ -305,7 +305,7 @@ class OidcProvider: provider.client_id, client_secret, provider.client_auth_method, - ) # type: ClientAuth + ) self._client_auth_method = provider.client_auth_method # cache of metadata for the identity provider (endpoint uris, mostly). This is @@ -324,7 +324,7 @@ class OidcProvider: self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() - self._server_name = hs.config.server_name # type: str + self._server_name: str = hs.config.server_name # identifier for the external_ids table self.idp_id = provider.idp_id @@ -1381,7 +1381,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - emails = [] # type: List[str] + emails: List[str] = [] email = render_template_field(self._config.email_template) if email: emails.append(email) @@ -1391,7 +1391,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: - extras = {} # type: Dict[str, str] + extras: Dict[str, str] = {} for key, template in self._config.extra_attributes.items(): try: extras[key] = template.render(user=userinfo).strip() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1e1186c29e..1dbafd253d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -81,9 +81,9 @@ class PaginationHandler: self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() # type: Set[str] + self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus - self._purges_by_id = {} # type: Dict[str, PurgeStatus] + self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 44ed7a0712..016c5df2ca 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. - self._user_to_num_current_syncs = {} # type: Dict[str, int] + self._user_to_num_current_syncs: Dict[str, int] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() # user_id -> last_sync_ms. Lists the users that have stopped syncing but # we haven't notified the presence writer of that yet - self.users_going_offline = {} # type: Dict[str, int] + self.users_going_offline: Dict[str, int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() # type: Set[str] + self.unpersisted_users_changes: Set[str] = set() hs.get_reactor().addSystemEventTrigger( "before", @@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} # type: Dict[str, int] + self.user_to_num_current_syncs: Dict[str, int] = {} # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] - self.external_process_last_updated_ms = {} # type: Dict[str, int] + self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -1581,9 +1581,7 @@ class PresenceEventSource: # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users = ( - set() - ) # type: Union[Set[str], FrozenSet[str]] + interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() if from_key: # First get all users that have had a presence update @@ -1950,8 +1948,8 @@ async def get_interested_parties( A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] - users_to_states = {} # type: Dict[str, List[UserPresenceState]] + room_ids_to_states: Dict[str, List[UserPresenceState]] = {} + users_to_states: Dict[str, List[UserPresenceState]] = {} for state in states: room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: @@ -2063,12 +2061,12 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] + self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] self._next_id = 1 # Map from instance name to current token - self._current_tokens = {} # type: Dict[str, int] + self._current_tokens: Dict[str, int] = {} if self._queue_presence_updates: self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) @@ -2168,7 +2166,7 @@ class PresenceFederationQueue: # handle the case where `from_token` stream ID has already been dropped. start_idx = max(from_token + 1 - self._next_id, -len(self._queue)) - to_send = [] # type: List[Tuple[int, Tuple[str, str]]] + to_send: List[Tuple[int, Tuple[str, str]]] = [] limited = False new_id = upto_token for _, stream_id, destinations, user_ids in self._queue[start_idx:]: @@ -2216,7 +2214,7 @@ class PresenceFederationQueue: if not self._federation: return - hosts_to_users = {} # type: Dict[str, Set[str]] + hosts_to_users: Dict[str, Set[str]] = {} for row in rows: hosts_to_users.setdefault(row.destination, set()).add(row.user_id) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 05b4a97b59..20a033d0ba 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler): 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - displayname_to_set = new_displayname # type: Optional[str] + displayname_to_set: Optional[str] = new_displayname if new_displayname == "": displayname_to_set = None @@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - avatar_url_to_set = new_avatar_url # type: Optional[str] + avatar_url_to_set: Optional[str] = new_avatar_url if new_avatar_url == "": avatar_url_to_set = None diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 0059ad0f56..283483fc2c 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler): async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier.""" - min_batch_id = None # type: Optional[int] - max_batch_id = None # type: Optional[int] + min_batch_id: Optional[int] = None + max_batch_id: Optional[int] = None for receipt in receipts: res = await self.store.insert_receipt( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 579b1b93c5..64656fda22 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler): self.config = hs.config # Room state based off defined presets - self._presets_dict = { + self._presets_dict: Dict[str, Dict[str, Any]] = { RoomCreationPreset.PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": HistoryVisibility.SHARED, @@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler): "guest_can_join": False, "power_level_content_override": {}, }, - } # type: Dict[str, Dict[str, Any]] + } # Modify presets to selectively enable encryption by default per homeserver config for preset_name, preset_config in self._presets_dict.items(): @@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler): # If a user tries to update the same room multiple times in quick # succession, only process the first attempt and return its result to # subsequent requests - self._upgrade_response_cache = ResponseCache( + self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) # type: ResponseCache[Tuple[str, str]] + ) self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler): if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") - creation_content = { + creation_content: JsonDict = { "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, - } # type: JsonDict + } # Check if old room was non-federatable @@ -936,7 +936,7 @@ class RoomCreationHandler(BaseHandler): etype=EventTypes.PowerLevels, content=pl_content ) else: - power_level_content = { + power_level_content: JsonDict = { "users": {creator_id: 100}, "users_default": 0, "events": { @@ -955,7 +955,7 @@ class RoomCreationHandler(BaseHandler): "kick": 50, "redact": 50, "invite": 50, - } # type: JsonDict + } if config["original_invitees_have_ops"]: for invitee in invite_list: diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c6bfa5451f..6284bcdfbc 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -47,12 +47,12 @@ class RoomListHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search - self.response_cache = ResponseCache( - hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] - self.remote_response_cache = ResponseCache( - hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 - ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] + self.response_cache: ResponseCache[ + Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] + ] = ResponseCache(hs.get_clock(), "room_list") + self.remote_response_cache: ResponseCache[ + Tuple[str, Optional[int], Optional[str], bool, Optional[str]] + ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000) async def get_local_public_room_list( self, @@ -139,10 +139,10 @@ class RoomListHandler(BaseHandler): if since_token: batch_token = RoomListNextBatch.from_token(since_token) - bounds = ( + bounds: Optional[Tuple[int, str]] = ( batch_token.last_joined_members, batch_token.last_room_id, - ) # type: Optional[Tuple[int, str]] + ) forwards = batch_token.direction_is_forward has_batch_token = True else: @@ -182,7 +182,7 @@ class RoomListHandler(BaseHandler): results = [build_room_entry(r) for r in results] - response = {} # type: JsonDict + response: JsonDict = {} num_results = len(results) if limit is not None: more_to_come = num_results == probing_limit diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 80ba65b9e0..72f54c9403 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -83,7 +83,7 @@ class SamlHandler(BaseHandler): self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] + self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) @@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str: return username -MXID_MAPPER_MAP = { +MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = { "hexencode": map_username_to_mxid_localpart, "dotreplace": dot_replace_for_mxid, -} # type: Dict[str, Callable[[str], str]] +} @attr.s diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 4e718d3f63..8226d6f5a1 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -192,7 +192,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] # type: List[str] + historical_room_ids: List[str] = [] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -216,9 +216,9 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] # Holds result of grouping by room, if applicable - room_groups = {} # type: Dict[str, JsonDict] + room_groups: Dict[str, JsonDict] = {} # Holds result of grouping by sender, if applicable - sender_group = {} # type: Dict[str, JsonDict] + sender_group: Dict[str, JsonDict] = {} # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -262,7 +262,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] # type: List[EventBase] + room_events: List[EventBase] = [] i = 0 pagination_token = batch_token diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 366e6211e5..5f7d4602bd 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -90,14 +90,14 @@ class SpaceSummaryHandler: room_queue = deque((_RoomQueueEntry(room_id, ()),)) # rooms we have already processed - processed_rooms = set() # type: Set[str] + processed_rooms: Set[str] = set() # events we have already processed. We don't necessarily have their event ids, # so instead we key on (room id, state key) - processed_events = set() # type: Set[Tuple[str, str]] + processed_events: Set[Tuple[str, str]] = set() - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: queue_entry = room_queue.popleft() @@ -272,10 +272,10 @@ class SpaceSummaryHandler: # the set of rooms that we should not walk further. Initialise it with the # excluded-rooms list; we will add other rooms as we process them so that # we do not loop. - processed_rooms = set(exclude_rooms) # type: Set[str] + processed_rooms: Set[str] = set(exclude_rooms) - rooms_result = [] # type: List[JsonDict] - events_result = [] # type: List[JsonDict] + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] while room_queue and len(rooms_result) < MAX_ROOMS: room_id = room_queue.popleft() @@ -353,7 +353,7 @@ class SpaceSummaryHandler: max_children = MAX_ROOMS_PER_SPACE now = self._clock.time_msec() - events_result = [] # type: List[JsonDict] + events_result: List[JsonDict] = [] for edge_event in itertools.islice(child_events, max_children): events_result.append( await self._event_serializer.serialize_event( diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0b297e54c4..1b855a685c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -202,10 +202,10 @@ class SsoHandler: self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) # a map from session id to session data - self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {} # map from idp_id to SsoIdentityProvider - self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._identity_providers: Dict[str, SsoIdentityProvider] = {} self._consent_at_registration = hs.config.consent.user_consent_at_registration @@ -296,7 +296,7 @@ class SsoHandler: ) # if the client chose an IdP, use that - idp = None # type: Optional[SsoIdentityProvider] + idp: Optional[SsoIdentityProvider] = None if idp_id: idp = self._identity_providers.get(idp_id) if not idp: @@ -669,9 +669,9 @@ class SsoHandler: remote_user_id, ) - user_id_to_verify = await self._auth_handler.get_session_data( + user_id_to_verify: str = await self._auth_handler.get_session_data( ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID - ) # type: str + ) if not user_id: logger.warning( @@ -793,7 +793,7 @@ class SsoHandler: session.use_display_name = use_display_name emails_from_idp = set(session.emails) - filtered_emails = set() # type: Set[str] + filtered_emails: Set[str] = set() # we iterate through the list rather than just building a set conjunction, so # that we can log attempts to use unknown addresses diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 814d08efcb..3fd89af2a4 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -49,7 +49,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -131,10 +131,10 @@ class StatsHandler: mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + room_to_stats_deltas: Dict[str, CounterType[str]] = {} + user_to_stats_deltas: Dict[str, CounterType[str]] = {} - room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] + room_to_state_updates: Dict[str, Dict[str, Any]] = {} for delta in deltas: typ = delta["type"] @@ -164,7 +164,7 @@ class StatsHandler: ) continue - event_content = {} # type: JsonDict + event_content: JsonDict = {} if event_id is not None: event = await self.store.get_event(event_id, allow_none=True) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b9a0361059..722c4ae670 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -278,12 +278,14 @@ class SyncHandler: self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) - self.lazy_loaded_members_cache = ExpiringCache( + self.lazy_loaded_members_cache: ExpiringCache[ + Tuple[str, Optional[str]], LruCache[str, str] + ] = ExpiringCache( "lazy_loaded_members_cache", self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, - ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] + ) async def wait_for_sync_for_user( self, @@ -440,7 +442,7 @@ class SyncHandler: ) now_token = now_token.copy_and_replace("typing_key", typing_key) - ephemeral_by_room = {} # type: JsonDict + ephemeral_by_room: JsonDict = {} for event in typing: # we want to exclude the room_id from the event, but modifying the @@ -502,7 +504,7 @@ class SyncHandler: # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline - current_state_ids = frozenset() # type: FrozenSet[str] + current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): current_state_ids_map = await self.store.get_current_state_ids( room_id @@ -783,9 +785,9 @@ class SyncHandler: def get_lazy_loaded_members_cache( self, cache_key: Tuple[str, Optional[str]] ) -> LruCache[str, str]: - cache = self.lazy_loaded_members_cache.get( + cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get( cache_key - ) # type: Optional[LruCache[str, str]] + ) if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) @@ -984,7 +986,7 @@ class SyncHandler: if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} # type: Dict[str, EventBase] + state: Dict[str, EventBase] = {} if state_ids: state = await self.store.get_events(list(state_ids.values())) @@ -1088,8 +1090,8 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts = {} # type: JsonDict - unused_fallback_key_types = [] # type: List[str] + one_time_key_counts: JsonDict = {} + unused_fallback_key_types: List[str] = [] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id @@ -1437,7 +1439,7 @@ class SyncHandler: ) if block_all_room_ephemeral: - ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] + ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, @@ -1468,7 +1470,7 @@ class SyncHandler: # If there is ignored users account data and it matches the proper type, # then use it. - ignored_users = frozenset() # type: FrozenSet[str] + ignored_users: FrozenSet[str] = frozenset() if ignored_account_data: ignored_users_data = ignored_account_data.get("ignored_users", {}) if isinstance(ignored_users_data, dict): @@ -1586,7 +1588,7 @@ class SyncHandler: user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] + mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1722,7 +1724,7 @@ class SyncHandler: # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] # type: Optional[List[EventBase]] + batch_events: Optional[List[EventBase]] = [leave_event] else: batch_events = None @@ -1971,7 +1973,7 @@ class SyncHandler: room_id, batch, sync_config, since_token, now_token, full_state=full_state ) - summary = {} # type: Optional[JsonDict] + summary: Optional[JsonDict] = {} # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -1995,7 +1997,7 @@ class SyncHandler: ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, int] + unread_notifications: Dict[str, int] = {} room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c0a8364755..0cb651a400 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,11 +68,11 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} # type: Dict[str, int] + self._room_serials: Dict[str, int] = {} # map room IDs to sets of users currently typing - self._room_typing = {} # type: Dict[str, Set[str]] + self._room_typing: Dict[str, Set[str]] = {} - self._member_last_federation_poke = {} # type: Dict[RoomMember, int] + self._member_last_federation_poke: Dict[RoomMember, int] = {} self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 @@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) # clock time we expect to stop - self._member_typing_until = {} # type: Dict[RoomMember, int] + self._member_typing_until: Dict[RoomMember, int] = {} # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( @@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler): if last_id == current_id: return [], current_id, False - changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id - ) # type: Optional[Iterable[str]] + changed_rooms: Optional[ + Iterable[str] + ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) if changed_rooms is None: changed_rooms = self._room_serials diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index dacc4f3076..6edb1da50a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.search_all_users = hs.config.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream - self.pos = None # type: Optional[int] + self.pos: Optional[int] = None # Guard to ensure we only process deltas one at a time self._is_processing = False diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3c51a742bf..40ee33646c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): # Get the room ID from the identifier. try: - remote_room_hosts = [ + remote_room_hosts: Optional[List[str]] = [ x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] + ] except Exception: remote_room_hosts = None room_id, remote_room_hosts = await self.resolve_room_id( @@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 06e6ccee42..589e47fa47 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth_handler = hs.get_auth_handler() self.reactor = hs.get_reactor() - self.nonces = {} # type: Dict[str, int] + self.nonces: Dict[str, int] = {} self.hs = hs def _clear_old_nonces(self): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cbcb60fe31..99d02cb355 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -121,7 +121,7 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = { + sso_flow: JsonDict = { "type": LoginRestServlet.SSO_TYPE, "identity_providers": [ _get_auth_flow_dict_for_idp( @@ -129,7 +129,7 @@ class LoginRestServlet(RestServlet): ) for idp in self._sso_handler.get_identity_providers().values() ], - } # type: JsonDict + } if self._msc2858_enabled: # backwards-compatibility support for clients which don't @@ -447,7 +447,7 @@ def _get_auth_flow_dict_for_idp( use_unstable_brands: whether we should use brand identifiers suitable for the unstable API """ - e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: @@ -561,7 +561,7 @@ class SsoRedirectServlet(RestServlet): finish_request(request) return - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) sso_url = await self._sso_handler.handle_redirect_request( request, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index ebf4e32230..31a1193cd3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index f8dcee603c..d537d811d8 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response = (200, {}) # type: Tuple[int, dict] + response: Tuple[int, dict] = (200, {}) return response diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index e52570cd8e..4282e2b228 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource): has_consented = False public_version = username == "" if not public_version: - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac_bytes = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac_bytes) @@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource): """ version = parse_string(request, "v", required=True) username = parse_string(request, "u", required=True) - args = request.args # type: Dict[bytes, List[bytes]] + args: Dict[bytes, List[bytes]] = request.args userhmac = parse_bytes_from_args(args, "h", required=True) self._check_hash(username, userhmac) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index d56a1ae482..63a40b1852 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource): async def _async_render_GET(self, request): if len(request.postpath) == 1: (server,) = request.postpath - query = {server.decode("ascii"): {}} # type: dict + query: dict = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") @@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() # Note that the value is unused. - cache_misses = {} # type: Dict[str, Dict[str, int]] + cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id, _), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 0fb4cd81f1..90364ebcf7 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # The type on postpath seems incorrect in Twisted 21.2.0. - postpath = request.postpath # type: List[bytes] # type: ignore + postpath: List[bytes] = request.postpath # type: ignore assert postpath # This allows users to append e.g. /test.png to the URL. Useful for diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 21c43c340c..4f702f890c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -78,16 +78,16 @@ class MediaRepository: Thumbnailer.set_limits(self.max_image_pixels) - self.primary_base_path = hs.config.media_store_path # type: str - self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths + self.primary_base_path: str = hs.config.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] - self.recently_accessed_locals = set() # type: Set[str] + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -711,7 +711,7 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. - thumbnails = {} # type: Dict[Tuple[int, int, str], str] + thumbnails: Dict[Tuple[int, int, str], str] = {} for r_width, r_height, r_method, r_type in requirements: if r_method == "crop": thumbnails.setdefault((r_width, r_height, r_type), r_method) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index c7fd97c46c..56cdc1b4ed 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -191,7 +191,7 @@ class MediaStorage: for provider in self.storage_providers: for path in paths: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -233,7 +233,7 @@ class MediaStorage: os.makedirs(dirname) for provider in self.storage_providers: - res = await provider.fetch(path, file_info) # type: Any + res: Any = await provider.fetch(path, file_info) if res: with res: consumer = BackgroundFileConsumer( diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0adfb1a70f..8e7fead3a2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource): # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata - self._cache = ExpiringCache( + self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, - ) # type: ExpiringCache[str, ObservableDeferred] + ) if self._worker_run_media_background_jobs: self._cleaner_loop = self.clock.looping_call( @@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url # type: Optional[str] + url_to_download: Optional[str] = url oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # "og:video:height" : "720", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - og = {} # type: Dict[str, Optional[str]] + og: Dict[str, Optional[str]] = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: # if we've got more than 50 tags, someone is taking the piss diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 62dc4aae2d..146adca8f1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource): errcode=Codes.TOO_LARGE, ) - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: try: - upload_name = upload_name_bytes.decode("utf8") # type: Optional[str] + upload_name: Optional[str] = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 @@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource): # TODO(markjh): parse content-dispostion try: - content = request.content # type: IO # type: ignore + content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( media_type, upload_name, content, content_length, requester.user ) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 9b002cc15e..ab24ec0a8e 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource): use_display_name = parse_boolean(request, "use_display_name", default=False) try: - emails_to_use = [ + emails_to_use: List[str] = [ val.decode("utf-8") for val in request.args.get(b"use_email", []) - ] # type: List[str] + ] except ValueError: raise SynapseError(400, "Query parameter use_email must be utf-8") except SynapseError as e: -- cgit 1.4.1 From 95e47b2e782b5e7afa5fd2afd1d0ea7745eaac36 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 19 Jul 2021 16:28:05 +0200 Subject: [pyupgrade] `synapse/` (#10348) This PR is tantamount to running ``` pyupgrade --py36-plus --keep-percent-format `find synapse/ -type f -name "*.py"` ``` Part of #9744 --- changelog.d/10348.misc | 1 + synapse/app/generic_worker.py | 6 ++-- synapse/app/homeserver.py | 6 ++-- synapse/config/appservice.py | 2 +- synapse/config/tls.py | 6 ++-- synapse/handlers/cas.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 4 +-- synapse/handlers/oidc.py | 38 ++++++++++++++------------ synapse/handlers/register.py | 15 ++++------ synapse/handlers/saml.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/proxyagent.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/opentracing.py | 2 +- synapse/metrics/_exposition.py | 26 ++++++++---------- synapse/metrics/background_process_metrics.py | 3 +- synapse/rest/client/v1/login.py | 25 ++++++----------- synapse/rest/media/v1/__init__.py | 4 +-- synapse/storage/database.py | 2 +- synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/group_server.py | 6 +++- synapse/storage/databases/main/roommember.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/types.py | 4 +-- synapse/util/caches/lrucache.py | 3 +- synapse/util/caches/treecache.py | 3 +- synapse/util/daemonize.py | 8 +++--- synapse/visibility.py | 4 +-- 29 files changed, 86 insertions(+), 102 deletions(-) create mode 100644 changelog.d/10348.misc (limited to 'synapse/handlers/federation.py') diff --git a/changelog.d/10348.misc b/changelog.d/10348.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10348.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. \ No newline at end of file diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b43d858f59..c3d4992518 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7af56ac136..920b34d97b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index a39d457c56..1ebea88db2 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -64,7 +64,7 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, "r") as f: + with open(config_file) as f: appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fed05ac7be..5679f05e42 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -66,10 +66,8 @@ class TlsConfig(Config): if self.federation_client_minimum_tls_version == "1.3": if getattr(SSL, "OP_NO_TLSv1_3", None) is None: raise ConfigError( - ( - "federation_client_minimum_tls_version cannot be 1.3, " - "your OpenSSL does not support it" - ) + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" ) # Whitelist of domains to not verify certificates for diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index b681d208bc..0325f86e20 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -40,7 +40,7 @@ class CasError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5c4463583e..cf389be3e4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -735,7 +735,7 @@ class FederationHandler(BaseHandler): # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - (await self.store.get_events(missing_desired_events, allow_rejected=True)) + await self.store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 33d16fbf9c..0961dec5ab 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler): ) url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") + url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" content = { "mxid": mxid, @@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler): return data["mxid"] except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - except IOError as e: + except OSError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) return None diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index a330c48fa7..eca8f16040 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -72,26 +72,26 @@ _SESSION_COOKIES = [ (b"oidc_session_no_samesite", b"HttpOnly"), ] + #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. -Token = TypedDict( - "Token", - { - "access_token": str, - "token_type": str, - "id_token": Optional[str], - "refresh_token": Optional[str], - "expires_in": int, - "scope": Optional[str], - }, -) +class Token(TypedDict): + access_token: str + token_type: str + id_token: Optional[str] + refresh_token: Optional[str] + expires_in: int + scope: Optional[str] + #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but #: there is no real point of doing this in our case. JWK = Dict[str, str] + #: A JWK Set, as per RFC7517 sec 5. -JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class JWKS(TypedDict): + keys: List[JWK] class OidcHandler: @@ -255,7 +255,7 @@ class OidcError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error @@ -639,7 +639,7 @@ class OidcProvider: ) logger.warning(description) # Body was still valid JSON. Might be useful to log it for debugging. - logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + logger.warning("Code exchange response: %r", resp) raise OidcError("server_error", description) return resp @@ -1217,10 +1217,12 @@ class OidcSessionData: ui_auth_session_id = attr.ib(type=str) -UserAttributeDict = TypedDict( - "UserAttributeDict", - {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, -) +class UserAttributeDict(TypedDict): + localpart: Optional[str] + display_name: Optional[str] + emails: List[str] + + C = TypeVar("C") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 056fe5e89f..8cf614136e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -55,15 +55,12 @@ login_counter = Counter( ["guest", "auth_provider"], ) -LoginDict = TypedDict( - "LoginDict", - { - "device_id": str, - "access_token": str, - "valid_until_ms": Optional[int], - "refresh_token": Optional[str], - }, -) + +class LoginDict(TypedDict): + device_id: str + access_token: str + valid_until_ms: Optional[int] + refresh_token: Optional[str] class RegistrationHandler(BaseHandler): diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 72f54c9403..e6e71e9729 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -372,7 +372,7 @@ class SamlHandler(BaseHandler): DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) + "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),) ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 722c4ae670..150a4f291e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1601,7 +1601,7 @@ class SyncHandler: logger.debug( "Membership changes in %s: [%s]", room_id, - ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)), + ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events), ) non_joins = [e for e in events if e.membership != Membership.JOIN] diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7a6a1717de..f7193e60bd 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase): """ uri = uri.strip() if not _VALID_URI.match(uri): - raise ValueError("Invalid URI {!r}".format(uri)) + raise ValueError(f"Invalid URI {uri!r}") parsed_uri = URI.fromBytes(uri) pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) diff --git a/synapse/http/site.py b/synapse/http/site.py index 3b0a38124e..190084e8aa 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -384,7 +384,7 @@ class SynapseRequest(Request): # authenticated (e.g. and admin is puppetting a user) then we log both. requester, authenticated_entity = self.get_authenticated_entity() if authenticated_entity: - requester = "{}.{}".format(authenticated_entity, requester) + requester = f"{authenticated_entity}.{requester}" self.site.access_logger.log( log_level, diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 185844f188..ecd51f1b4a 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"): config = JaegerConfig( config=hs.config.jaeger_config, - service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + service_name=f"{hs.config.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), ) diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 7e49d0d02c..bb9bcb5592 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -34,7 +34,7 @@ from twisted.web.resource import Resource from synapse.util import caches -CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" INF = float("inf") @@ -55,8 +55,8 @@ def floatToGoString(d): # Go switches to exponents sooner than Python. # We only need to care about positive values for le/quantile. if d > 0 and dot > 6: - mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") - return "{0}e+0{1}".format(mantissa, dot - 1) + mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") + return f"{mantissa}e+0{dot - 1}" return s @@ -65,7 +65,7 @@ def sample_line(line, name): labelstr = "{{{0}}}".format( ",".join( [ - '{0}="{1}"'.format( + '{}="{}"'.format( k, v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), ) @@ -78,10 +78,8 @@ def sample_line(line, name): timestamp = "" if line.timestamp is not None: # Convert to milliseconds. - timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) - return "{0}{1} {2}{3}\n".format( - name, labelstr, floatToGoString(line.value), timestamp - ) + timestamp = f" {int(float(line.timestamp) * 1000):d}" + return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) def generate_latest(registry, emit_help=False): @@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False): # Output in the old format for compatibility. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mname, mtype)) + output.append(f"# TYPE {mname} {mtype}\n") om_samples: Dict[str, List[str]] = {} for s in metric.samples: @@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False): for suffix, lines in sorted(om_samples.items()): if emit_help: output.append( - "# HELP {0}{1} {2}\n".format( + "# HELP {}{} {}\n".format( metric.name, suffix, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.append(f"# TYPE {metric.name}{suffix} gauge\n") output.extend(lines) # Get rid of the weird colon things while we're at it @@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False): # Also output in the new format, if it's different. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mnewname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + output.append(f"# TYPE {mnewname} {mtype}\n") for s in metric.samples: # Get rid of the OpenMetrics specific samples (we should already have diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 4455fa71a8..3a14260752 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -137,8 +137,7 @@ class _Collector: _background_process_db_txn_duration, _background_process_db_sched_duration, ): - for r in m.collect(): - yield r + yield from m.collect() REGISTRY.register(_Collector()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 99d02cb355..11567bf32c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,19 +44,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -LoginResponse = TypedDict( - "LoginResponse", - { - "user_id": str, - "access_token": str, - "home_server": str, - "expires_in_ms": Optional[int], - "refresh_token": Optional[str], - "device_id": str, - "well_known": Optional[Dict[str, Any]], - }, - total=False, -) +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] class LoginRestServlet(RestServlet): @@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend( - ({"type": t} for t in self.auth_handler.get_supported_login_types()) - ) + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py index d20186bbd0..3dd16d4bb5 100644 --- a/synapse/rest/media/v1/__init__.py +++ b/synapse/rest/media/v1/__init__.py @@ -17,7 +17,7 @@ import PIL.Image # check for JPEG support. try: PIL.Image._getdecoder("rgb", "jpeg", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder jpeg not available"): raise Exception( "FATAL: jpeg codec not supported. Install pillow correctly! " @@ -32,7 +32,7 @@ except Exception: # check for PNG support. try: PIL.Image._getdecoder("rgb", "zip", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder zip not available"): raise Exception( "FATAL: zip codec not supported. Install pillow correctly! " diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f80d822c12..ccf9ac51ef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -907,7 +907,7 @@ class DatabasePool: # The sort is to ensure that we don't rely on dictionary iteration # order. keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) ) for k in keys: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 50e7ddd735..c55508867d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_messages_for_device", delete_messages_for_device_txn ) - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) + log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 66ad363bfb..e70d3649ff 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -27,8 +27,11 @@ from synapse.util import json_encoder _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" + # A room in a group. -_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool}) +class _RoomInGroup(TypedDict): + room_id: str + is_public: bool class GroupServerWorkerStore(SQLBaseStore): @@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore): "is_public": False # Whether this is a public room or not } """ + # TODO: Pagination def _get_rooms_in_group_txn(txn): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4d82c4c26d..68f1b40ea6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) - users_in_room.update((row for row in event_to_memberships.values() if row)) + users_in_room.update(row for row in event_to_memberships.values() if row) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 82a7686df0..61392b9639 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]: def executescript(txn: Cursor, schema_path: str) -> None: - with open(schema_path, "r") as f: + with open(schema_path) as f: execute_statements_from_stream(txn, f) diff --git a/synapse/types.py b/synapse/types.py index fad23c8700..429bb013d2 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -577,10 +577,10 @@ class RoomStreamToken: entries = [] for name, pos in self.instance_map.items(): instance_id = await store.get_id_for_instance(name) - entries.append("{}.{}".format(instance_id, pos)) + entries.append(f"{instance_id}.{pos}") encoded_map = "~".join(entries) - return "m{}~{}".format(self.stream, encoded_map) + return f"m{self.stream}~{encoded_map}" else: return "s%d" % (self.stream,) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index efeba0cb96..5c65d187b6 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -90,8 +90,7 @@ def enumerate_leaves(node, depth): yield node else: for n in node.values(): - for m in enumerate_leaves(n, depth - 1): - yield m + yield from enumerate_leaves(n, depth - 1) P = TypeVar("P") diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index a6df81ebff..4138931e7b 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d): """ if isinstance(d, TreeCacheNode): for value_d in d.values(): - for value in iterate_tree_cache_entry(value_d): - yield value + yield from iterate_tree_cache_entry(value_d) else: yield d diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index 31b24dd188..d8532411c2 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # If pidfile already exists, we should read pid from there; to overwrite it, if # locking will fail, because locking attempt somehow purges the file contents. if os.path.isfile(pid_file): - with open(pid_file, "r") as pid_fh: + with open(pid_file) as pid_fh: old_pid = pid_fh.read() # Create a lockfile so that only one instance of this daemon is running at any time. try: lock_fh = open(pid_file, "w") - except IOError: + except OSError: print("Unable to create the pidfile.") sys.exit(1) @@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # Try to get an exclusive lock on the file. This will fail if another process # has the file locked. fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) - except IOError: + except OSError: print("Unable to lock on the pidfile.") # We need to overwrite the pidfile if we got here. # @@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - try: lock_fh.write("%s" % (os.getpid())) lock_fh.flush() - except IOError: + except OSError: logger.error("Unable to write pid to the pidfile.") print("Unable to write pid to the pidfile.") sys.exit(1) diff --git a/synapse/visibility.py b/synapse/visibility.py index 1dc6b90275..17532059e9 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -96,7 +96,7 @@ async def filter_events_for_client( if isinstance(ignored_users_dict, dict): ignore_list = frozenset(ignored_users_dict.keys()) - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -353,7 +353,7 @@ async def filter_events_for_server( ) if not check_history_visibility_only: - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. -- cgit 1.4.1 From a743bf46949e851c9a10d8e01a138659f3af2484 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 20 Jul 2021 12:39:46 +0200 Subject: Port the ThirdPartyEventRules module interface to the new generic interface (#10386) Port the third-party event rules interface to the generic module interface introduced in v1.37.0 --- changelog.d/10386.removal | 1 + docs/modules.md | 62 ++++++- docs/sample_config.yaml | 13 -- docs/upgrade.md | 13 ++ synapse/app/_base.py | 2 + synapse/config/third_party_event_rules.py | 15 -- synapse/events/third_party_rules.py | 245 +++++++++++++++++++++++----- synapse/handlers/federation.py | 4 +- synapse/handlers/message.py | 8 +- synapse/handlers/room.py | 10 +- synapse/module_api/__init__.py | 6 + tests/rest/client/test_third_party_rules.py | 132 ++++++++++++--- 12 files changed, 403 insertions(+), 108 deletions(-) create mode 100644 changelog.d/10386.removal (limited to 'synapse/handlers/federation.py') diff --git a/changelog.d/10386.removal b/changelog.d/10386.removal new file mode 100644 index 0000000000..800a6143d7 --- /dev/null +++ b/changelog.d/10386.removal @@ -0,0 +1 @@ +The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information. diff --git a/docs/modules.md b/docs/modules.md index c4cb7018f7..9a430390a4 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -186,7 +186,7 @@ The arguments passed to this callback are: ```python async def check_media_file_for_spam( file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", - file_info: "synapse.rest.media.v1._base.FileInfo" + file_info: "synapse.rest.media.v1._base.FileInfo", ) -> bool ``` @@ -223,6 +223,66 @@ Called after successfully registering a user, in case the module needs to perfor operations to keep track of them. (e.g. add them to a database table). The user is represented by their Matrix user ID. +#### Third party rules callbacks + +Third party rules callbacks allow module developers to add extra checks to verify the +validity of incoming events. Third party event rules callbacks can be registered using +the module API's `register_third_party_rules_callbacks` method. + +The available third party rules callbacks are: + +```python +async def check_event_allowed( + event: "synapse.events.EventBase", + state_events: "synapse.types.StateMap", +) -> Tuple[bool, Optional[dict]] +``` + +** +This callback is very experimental and can and will break without notice. Module developers +are encouraged to implement `check_event_for_spam` from the spam checker category instead. +** + +Called when processing any incoming event, with the event and a `StateMap` +representing the current state of the room the event is being sent into. A `StateMap` is +a dictionary that maps tuples containing an event type and a state key to the +corresponding state event. For example retrieving the room's `m.room.create` event from +the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`. +The module must return a boolean indicating whether the event can be allowed. + +Note that this callback function processes incoming events coming via federation +traffic (on top of client traffic). This means denying an event might cause the local +copy of the room's history to diverge from that of remote servers. This may cause +federation issues in the room. It is strongly recommended to only deny events using this +callback function if the sender is a local user, or in a private federation in which all +servers are using the same module, with the same configuration. + +If the boolean returned by the module is `True`, it may also tell Synapse to replace the +event with new data by returning the new event's data as a dictionary. In order to do +that, it is recommended the module calls `event.get_dict()` to get the current event as a +dictionary, and modify the returned dictionary accordingly. + +Note that replacing the event only works for events sent by local users, not for events +received over federation. + +```python +async def on_create_room( + requester: "synapse.types.Requester", + request_content: dict, + is_requester_admin: bool, +) -> None +``` + +Called when processing a room creation request, with the `Requester` object for the user +performing the request, a dictionary representing the room creation request's JSON body +(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom) +for a list of possible parameters), and a boolean indicating whether the user performing +the request is a server admin. + +Modules can modify the `request_content` (by e.g. adding events to its `initial_state`), +or deny the room's creation by raising a `module_api.errors.SynapseError`. + + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f4845a5841..853c2f6899 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2654,19 +2654,6 @@ stats: # action: allow -# Server admins can define a Python module that implements extra rules for -# allowing or denying incoming events. In order to work, this module needs to -# override the methods defined in synapse/events/third_party_rules.py. -# -# This feature is designed to be used in closed federations only, where each -# participating server enforces the same rules. -# -#third_party_event_rules: -# module: "my_custom_project.SuperRulesSet" -# config: -# example_option: 'things' - - ## Opentracing ## # These settings enable opentracing, which implements distributed tracing. diff --git a/docs/upgrade.md b/docs/upgrade.md index db0450f563..c8f4a2c171 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,6 +86,19 @@ process, for example: ``` +# Upgrading to v1.39.0 + +## Deprecation of the current third-party rules module interface + +The current third-party rules module interface is deprecated in favour of the new generic +modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer +to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface) +to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules) +to update their configuration once the modules they are using have been updated. + +We plan to remove support for the current third-party rules interface in September 2021. + + # Upgrading to v1.38.0 ## Re-indexing of `events` table on Postgres databases diff --git a/synapse/app/_base.py b/synapse/app/_base.py index b30571fe49..50a02f51f5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -368,6 +369,7 @@ async def start(hs: "HomeServer"): module(config=config, api=module_api) load_legacy_spam_checkers(hs) + load_legacy_third_party_event_rules(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index f502ff539e..a3fae02420 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config): self.third_party_event_rules = load_module( provider, ("third_party_event_rules",) ) - - def generate_config_section(self, **kwargs): - return """\ - # Server admins can define a Python module that implements extra rules for - # allowing or denying incoming events. In order to work, this module needs to - # override the methods defined in synapse/events/third_party_rules.py. - # - # This feature is designed to be used in closed federations only, where each - # participating server enforces the same rules. - # - #third_party_event_rules: - # module: "my_custom_project.SuperRulesSet" - # config: - # example_option: 'things' - """ diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index f7944fd834..7a6eb3e516 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -11,16 +11,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from typing import TYPE_CHECKING, Union - +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + + +CHECK_EVENT_ALLOWED_CALLBACK = Callable[ + [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] +] +ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] +CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ + [str, str, StateMap[EventBase]], Awaitable[bool] +] +CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ + [str, StateMap[EventBase], str], Awaitable[bool] +] + + +def load_legacy_third_party_event_rules(hs: "HomeServer"): + """Wrapper that loads a third party event rules module configured using the old + configuration, and registers the hooks they implement. + """ + if hs.config.third_party_event_rules is None: + return + + module, config = hs.config.third_party_event_rules + + api = hs.get_module_api() + third_party_rules = module(config=config, module_api=api) + + # The known hooks. If a module implements a method which name appears in this set, + # we'll want to register it. + third_party_event_rules_methods = { + "check_event_allowed", + "on_create_room", + "check_threepid_can_be_invited", + "check_visibility_can_be_modified", + } + + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We return a separate wrapper for these methods because, in order to wrap them + # correctly, we need to await its result. Therefore it doesn't make a lot of + # sense to make it go through the run() wrapper. + if f.__name__ == "check_event_allowed": + + # We need to wrap check_event_allowed because its old form would return either + # a boolean or a dict, but now we want to return the dict separately from the + # boolean. + async def wrap_check_event_allowed( + event: EventBase, + state_events: StateMap[EventBase], + ) -> Tuple[bool, Optional[dict]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(event, state_events) + if isinstance(res, dict): + return True, res + else: + return res, None + + return wrap_check_event_allowed + + if f.__name__ == "on_create_room": + + # We need to wrap on_create_room because its old form would return a boolean + # if the room creation is denied, but now we just want it to raise an + # exception. + async def wrap_on_create_room( + requester: Requester, config: dict, is_requester_admin: bool + ) -> None: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(requester, config, is_requester_admin) + if res is False: + raise SynapseError( + 403, + "Room creation forbidden with these parameters", + ) + + return wrap_on_create_room + + def run(*args, **kwargs): + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None + + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(third_party_rules, hook, None)) + for hook in third_party_event_rules_methods + } + + api.register_third_party_rules_callbacks(**hooks) + class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra @@ -35,36 +143,65 @@ class ThirdPartyEventRules: self.store = hs.get_datastore() - module = None - config = None - if hs.config.third_party_event_rules: - module, config = hs.config.third_party_event_rules + self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] + self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] + self._check_threepid_can_be_invited_callbacks: List[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = [] + self._check_visibility_can_be_modified_callbacks: List[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = [] + + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + ): + """Register callbacks from modules for each hook.""" + if check_event_allowed is not None: + self._check_event_allowed_callbacks.append(check_event_allowed) + + if on_create_room is not None: + self._on_create_room_callbacks.append(on_create_room) + + if check_threepid_can_be_invited is not None: + self._check_threepid_can_be_invited_callbacks.append( + check_threepid_can_be_invited, + ) - if module is not None: - self.third_party_rules = module( - config=config, - module_api=hs.get_module_api(), + if check_visibility_can_be_modified is not None: + self._check_visibility_can_be_modified_callbacks.append( + check_visibility_can_be_modified, ) async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> Union[bool, dict]: + ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. The module can return: * True: the event is allowed. * False: the event is not allowed, and should be rejected with M_FORBIDDEN. - * a dict: replacement event data. + + If the event is allowed, the module can also return a dictionary to use as a + replacement for the event. Args: event: The event to be checked. context: The context of the event. Returns: - The result from the ThirdPartyRules module, as above + The result from the ThirdPartyRules module, as above. """ - if self.third_party_rules is None: - return True + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_event_allowed_callbacks) == 0: + return True, None prev_state_ids = await context.get_prev_state_ids() @@ -77,29 +214,46 @@ class ThirdPartyEventRules: # the hashes and signatures. event.freeze() - return await self.third_party_rules.check_event_allowed(event, state_events) + for callback in self._check_event_allowed_callbacks: + try: + res, replacement_data = await callback(event, state_events) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + # Return if the event shouldn't be allowed or if the module came up with a + # replacement dict for the event. + if res is False: + return res, None + elif isinstance(replacement_data, dict): + return True, replacement_data + + return True, None async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ) -> bool: - """Intercept requests to create room to allow, deny or update the - request config. + ) -> None: + """Intercept requests to create room to maybe deny it (via an exception) or + update the request config. Args: requester config: The creation config from the client. is_requester_admin: If the requester is an admin - - Returns: - Whether room creation is allowed or denied. """ - - if self.third_party_rules is None: - return True - - return await self.third_party_rules.on_create_room( - requester, config, is_requester_admin - ) + for callback in self._on_create_room_callbacks: + try: + await callback(requester, config, is_requester_admin) + except Exception as e: + # Don't silence the errors raised by this callback since we expect it to + # raise an exception to deny the creation of the room; instead make sure + # it's a SynapseError we can send to clients. + if not isinstance(e, SynapseError): + e = SynapseError( + 403, "Room creation forbidden with these parameters" + ) + + raise e async def check_threepid_can_be_invited( self, medium: str, address: str, room_id: str @@ -114,15 +268,20 @@ class ThirdPartyEventRules: Returns: True if the 3PID can be invited, False if not. """ - - if self.third_party_rules is None: + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_threepid_can_be_invited_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) + for callback in self._check_threepid_can_be_invited_callbacks: + try: + if await callback(medium, address, state_events) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str @@ -137,18 +296,20 @@ class ThirdPartyEventRules: Returns: True if the room's visibility can be modified, False if not. """ - if self.third_party_rules is None: - return True - - check_func = getattr( - self.third_party_rules, "check_visibility_can_be_modified", None - ) - if not check_func or not callable(check_func): + # Bail out early without hitting the store if we don't have any callback + if len(self._check_visibility_can_be_modified_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await check_func(room_id, state_events, new_visibility) + for callback in self._check_visibility_can_be_modified_callbacks: + try: + if await callback(room_id, state_events, new_visibility) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cf389be3e4..5728719909 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler): # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c7fe4ff89e..8a0024ce84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -949,10 +949,10 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service - third_party_result = await self.third_party_event_rules.check_event_allowed( + res, new_content = await self.third_party_event_rules.check_event_allowed( event, context ) - if not third_party_result: + if res is False: logger.info( "Event %s forbidden by third-party rules", event, @@ -960,11 +960,11 @@ class EventCreationHandler: raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - elif isinstance(third_party_result, dict): + elif new_content is not None: # the third-party rules want to replace the event. We'll need to build a new # event. event, context = await self._rebuild_event_after_third_party_rules( - third_party_result, event + new_content, event ) self.validator.validate_new(event, self.config) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 64656fda22..370561e549 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler): else: is_requester_admin = await self.auth.is_server_admin(requester.user) - # Check whether the third party rules allows/changes the room create - # request. - event_allowed = await self.third_party_event_rules.on_create_room( + # Let the third party rules modify the room creation config if needed, or abort + # the room creation entirely with an exception. + await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) - if not event_allowed: - raise SynapseError( - 403, "You are not permitted to create rooms", Codes.FORBIDDEN - ) if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 5df9349134..1259fc2d90 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -110,6 +110,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() + self._third_party_event_rules = hs.get_third_party_event_rules() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -124,6 +125,11 @@ class ModuleApi: """Registers callbacks for account validity capabilities.""" return self._account_validity_handler.register_account_validity_callbacks + @property + def register_third_party_rules_callbacks(self): + """Registers callbacks for third party event rules capabilities.""" + return self._third_party_event_rules.register_third_party_rules_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index c5e1c5458b..28dd47a28b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -16,17 +16,19 @@ from typing import Dict from unittest.mock import Mock from synapse.events import EventBase +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import Requester, StateMap +from synapse.util.frozenutils import unfreeze from tests import unittest thread_local = threading.local() -class ThirdPartyRulesTestModule: +class LegacyThirdPartyRulesTestModule: def __init__(self, config: Dict, module_api: ModuleApi): # keep a record of the "current" rules module, so that the test can patch # it if desired. @@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule: return config -def current_rules_module() -> ThirdPartyRulesTestModule: - return thread_local.rules_module +class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return False + + +class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + d = event.get_dict() + content = unfreeze(event.content) + content["foo"] = "bar" + d["content"] = content + return d class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): @@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def default_config(self): - config = super().default_config() - config["third_party_event_rules"] = { - "module": __name__ + ".ThirdPartyRulesTestModule", - "config": {}, - } - return config + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + load_legacy_third_party_event_rules(hs) + + return hs def prepare(self, reactor, clock, homeserver): # Create a user and room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.tok = self.login("kermit", "monkey") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + # Some tests might prevent room creation on purpose. + try: + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + except Exception: + pass def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one @@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # patch the rules module with a Mock which will return False for some event # types async def check(ev, state): - return ev.type != "foo.bar.forbidden" + return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) - current_rules_module().check_event_allowed = callback + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [ + callback + ] channel = self.make_request( "PUT", @@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # first patch the event checker so that it will try to modify the event async def check(ev: EventBase, state): ev.content = {"x": "y"} - return True + return True, None - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"500", channel.result) + # check_event_allowed has some error handling, so it shouldn't 500 just because a + # module did something bad. + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "x") def test_modify_event(self): """The module can return a modified version of the event""" @@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): async def check(ev: EventBase, state): d = ev.get_dict() d["content"] = {"x": "y"} - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "msgtype": "m.text", "body": d["content"]["body"].upper(), } - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # Send an event, then edit it. channel = self.make_request( @@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEqual(ev["content"]["body"], "EDITED BODY") def test_send_event(self): - """Tests that the module can send an event into a room via the module api""" + """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", "body": "Hello!", @@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "sender": self.user_id, } event: EventBase = self.get_success( - current_rules_module().module_api.create_and_send_event_into_room( - event_dict - ) + self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.type, "m.room.message") self.assertEquals(event.content, content) + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyChangeEvents", + "config": {}, + } + } + ) + def test_legacy_check_event_allowed(self): + """Tests that the wrapper for legacy check_event_allowed callbacks works + correctly. + """ + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + self.assertIn("foo", channel.json_body["content"].keys()) + self.assertEqual(channel.json_body["content"]["foo"], "bar") + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyDenyNewRooms", + "config": {}, + } + } + ) + def test_legacy_on_create_room(self): + """Tests that the wrapper for legacy on_create_room callbacks works + correctly. + """ + self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) -- cgit 1.4.1