From 7695ca06187bb6742ed74c5ae060c48a08af99ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Jul 2021 10:35:46 +0100 Subject: Fix a number of logged errors caused by remote servers being down. (#10400) --- synapse/http/matrixfederationclient.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'synapse/http') diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b8849c0150..3bace2c965 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,6 +43,7 @@ from twisted.internet import defer from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorTime from twisted.internet.task import _EPSILON, Cooperator +from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -262,6 +263,15 @@ async def _handle_response( request.uri.decode("ascii"), ) raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( "{%s} [%s] Error reading response %s %s: %s", @@ -1137,6 +1147,24 @@ class MatrixFederationHttpClient: msg, ) raise SynapseError(502, msg, Codes.TOO_LARGE) + except defer.TimeoutError as e: + logger.warning( + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", -- cgit 1.5.1 From c7603af1d06d65932c420ae76002b6ed94dbf23c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 15 Jul 2021 11:37:08 +0200 Subject: Allow providing credentials to `http_proxy` (#10360) --- changelog.d/10360.feature | 1 + synapse/http/proxyagent.py | 12 +++++++- tests/http/test_proxyagent.py | 65 ++++++++++++++++++++++++++++++++++--------- 3 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 changelog.d/10360.feature (limited to 'synapse/http') diff --git a/changelog.d/10360.feature b/changelog.d/10360.feature new file mode 100644 index 0000000000..904221cb6d --- /dev/null +++ b/changelog.d/10360.feature @@ -0,0 +1 @@ +Allow providing credentials to `http_proxy`. 
\ No newline at end of file diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7dfae8b786..7a6a1717de 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - # Parse credentials from https proxy connection string if present + # Parse credentials from http and https proxy connection string if present + self.http_proxy_creds, http_proxy = parse_username_password(http_proxy) self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) self.http_proxy_endpoint = _http_proxy_endpoint( @@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase): and self.http_proxy_endpoint and not should_skip_proxy ): + # Determine whether we need to set Proxy-Authorization headers + if self.http_proxy_creds: + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + headers.addRawHeader( + b"Proxy-Authorization", + self.http_proxy_creds.as_proxy_authorization_value(), + ) # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index fefc8099c9..437113929a 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -205,6 +205,41 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) def test_http_request_via_proxy(self): + """ + Tests that requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, + ) + def test_http_request_via_proxy_with_auth(self): + """ + Tests that authenticated requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials="bob:pinkponies") + + @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) + def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_http_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): + """ + Tests that requests can be made through a proxy. 
+ """ agent = ProxyAgent(self.reactor, use_proxy=True) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -229,6 +264,23 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"http://test.com") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) @@ -241,19 +293,6 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): - """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials=None) - - @patch.dict( - os.environ, - {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, - ) - def test_https_request_via_proxy_with_auth(self): - """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") - def _do_https_request_via_proxy( self, auth_credentials: Optional[str] = None, -- cgit 1.5.1 From bf72d10dbf506f5ea486d67094b6003947d38fb7 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 15 Jul 2021 12:02:43 +0200 Subject: Use inline type hints in various other places (in `synapse/`) (#10380) --- changelog.d/10380.misc | 1 + synapse/api/auth.py | 4 +-- synapse/api/errors.py | 4 +-- synapse/api/filtering.py | 2 +- synapse/api/ratelimiting.py | 4 +-- synapse/api/room_versions.py | 4 +-- synapse/app/generic_worker.py | 2 +- synapse/appservice/api.py | 4 +-- synapse/config/appservice.py | 4 +-- synapse/config/cache.py | 4 +-- synapse/config/emailconfig.py | 4 +-- synapse/config/experimental.py | 6 ++-- synapse/config/federation.py | 2 +- synapse/config/oidc.py | 2 +- synapse/config/password_auth_providers.py | 2 +- synapse/config/repository.py | 4 +-- synapse/config/server.py | 16 +++++----- synapse/config/spam_checker.py | 2 +- synapse/config/sso.py | 2 +- synapse/config/tls.py | 6 ++-- synapse/crypto/keyring.py | 20 +++++++------ synapse/event_auth.py | 8 ++--- synapse/events/__init__.py | 26 ++++++++--------- synapse/events/builder.py | 16 +++++----- synapse/events/spamcheck.py | 4 +-- synapse/federation/federation_client.py | 10 +++---- synapse/federation/federation_server.py | 34 ++++++++++------------ synapse/federation/send_queue.py | 26 ++++++++--------- synapse/federation/sender/__init__.py | 14 ++++----- synapse/federation/sender/per_destination_queue.py | 18 ++++++------ synapse/federation/transport/client.py | 8 ++--- synapse/federation/transport/server.py | 24 +++++++-------- synapse/groups/groups_server.py | 12 ++++---- synapse/http/__init__.py | 2 +- synapse/http/client.py | 18 ++++++------ synapse/http/matrixfederationclient.py | 12 ++++---- 
synapse/http/server.py | 8 ++--- synapse/http/servlet.py | 2 +- synapse/http/site.py | 14 ++++----- synapse/logging/_remote.py | 14 ++++----- synapse/logging/_structured.py | 2 +- synapse/logging/context.py | 16 +++++----- synapse/logging/opentracing.py | 10 +++---- synapse/metrics/__init__.py | 6 ++-- synapse/metrics/_exposition.py | 2 +- synapse/metrics/background_process_metrics.py | 4 +-- synapse/module_api/__init__.py | 2 +- synapse/notifier.py | 18 ++++++------ synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/push/clientformat.py | 4 +-- synapse/push/emailpusher.py | 6 ++-- synapse/push/httppusher.py | 2 +- synapse/push/mailer.py | 12 ++++---- synapse/push/presentable_names.py | 2 +- synapse/push/push_rule_evaluator.py | 4 +-- synapse/push/pusher.py | 6 ++-- synapse/push/pusherpool.py | 2 +- synapse/python_dependencies.py | 4 +-- synapse/replication/http/_base.py | 10 +++---- synapse/replication/slave/storage/_base.py | 6 ++-- synapse/replication/slave/storage/client_ips.py | 4 +-- synapse/replication/tcp/client.py | 10 +++---- synapse/replication/tcp/commands.py | 6 ++-- synapse/replication/tcp/handler.py | 16 +++++----- synapse/replication/tcp/protocol.py | 14 ++++----- synapse/replication/tcp/redis.py | 8 ++--- synapse/replication/tcp/streams/_base.py | 14 ++++----- synapse/replication/tcp/streams/events.py | 28 +++++++++--------- synapse/replication/tcp/streams/federation.py | 6 ++-- synapse/server.py | 6 ++-- synapse/server_notices/consent_server_notices.py | 2 +- .../resource_limits_server_notices.py | 2 +- synapse/server_notices/server_notices_sender.py | 6 ++-- synapse/state/__init__.py | 20 +++++++------ synapse/state/v1.py | 2 +- synapse/state/v2.py | 18 ++++++------ synapse/streams/events.py | 4 +-- synapse/types.py | 6 ++-- synapse/visibility.py | 2 +- 79 files changed, 329 insertions(+), 336 deletions(-) create mode 100644 changelog.d/10380.misc (limited to 'synapse/http') diff --git a/changelog.d/10380.misc b/changelog.d/10380.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10380.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. 
\ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 42476a18e5..8916e6fa2f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -63,9 +63,9 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache( + self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" - ) # type: LruCache[str, Tuple[str, bool]] + ) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 4cb8bbaf70..054ab14ab6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -118,7 +118,7 @@ class RedirectException(CodeMessageException): super().__init__(code=http_code, msg=msg) self.location = location - self.cookies = [] # type: List[bytes] + self.cookies: List[bytes] = [] class SynapseError(CodeMessageException): @@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError): ): super().__init__(code, msg, errcode) if additional_fields is None: - self._additional_fields = {} # type: Dict + self._additional_fields: Dict = {} else: self._additional_fields = dict(additional_fields) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index ce49a0ad58..ad1ff6a9df 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -289,7 +289,7 @@ class Filter: room_id = None ev_type = "m.presence" contains_url = False - labels = [] # type: List[str] + labels: List[str] = [] else: sender = event.get("sender", None) if not sender: diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index b9a10283f4..3e3d09bbd2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,9 +46,7 @@ class Ratelimiter: # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions = ( - OrderedDict() - ) # type: OrderedDict[Hashable, Tuple[float, int, float]] + self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() async def can_do_action( self, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f6c1c97b40..a20abc5a65 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -195,7 +195,7 @@ class RoomVersions: ) -KNOWN_ROOM_VERSIONS = { +KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { v.identifier: v for v in ( RoomVersions.V1, @@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V7, ) # Note that we do not include MSC2043 here unless it is enabled in the config. -} # type: Dict[str, RoomVersion] +} diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 5b041fcaad..b43d858f59 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -270,7 +270,7 @@ class GenericWorkerServer(HomeServer): site_tag = port # We always include a health resource. 
- resources = {"/health": HealthResource()} # type: Dict[str, IResource] + resources: Dict[str, IResource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 61152b2c46..935f24263c 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -88,9 +88,9 @@ class ApplicationServiceApi(SimpleHttpClient): super().__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache( + self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS - ) # type: ResponseCache[Tuple[str, str]] + ) async def query_user(self, service, user_id): if service.url is None: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 746fc3cc02..a39d457c56 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -57,8 +57,8 @@ def load_appservices(hostname, config_files): return [] # Dicts of value -> filename - seen_as_tokens = {} # type: Dict[str, str] - seen_ids = {} # type: Dict[str, str] + seen_as_tokens: Dict[str, str] = {} + seen_ids: Dict[str, str] = {} appservices = [] diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 7789b40323..8d5f38b5d9 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -25,7 +25,7 @@ from ._base import Config, ConfigError _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" # Map from canonicalised cache name to cache. -_CACHES = {} # type: Dict[str, Callable[[float], None]] +_CACHES: Dict[str, Callable[[float], None]] = {} # a lock on the contents of _CACHES _CACHES_LOCK = threading.Lock() @@ -157,7 +157,7 @@ class CacheConfig(Config): self.event_cache_size = self.parse_size( config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) ) - self.cache_factors = {} # type: Dict[str, float] + self.cache_factors: Dict[str, float] = {} cache_config = config.get("caches") or {} self.global_factor = cache_config.get( diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 5564d7d097..bcecbfec03 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -134,9 +134,9 @@ class EmailConfig(Config): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. 
Presume https - self.account_threepid_delegate_email = ( + self.account_threepid_delegate_email: Optional[str] = ( "https://" + first_trusted_identity_server - ) # type: Optional[str] + ) self.using_identity_server_from_trusted_list = True else: raise ConfigError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 7fb1f7021f..e25ccba9ac 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -25,10 +25,10 @@ class ExperimentalConfig(Config): experimental = config.get("experimental_features") or {} # MSC2858 (multiple SSO identity providers) - self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool + self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False) # MSC3026 (busy presence state) - self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool + self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) # MSC2716 (backfill existing history) - self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool + self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index cdd7a1ef05..7d64993e22 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -22,7 +22,7 @@ class FederationConfig(Config): def read_config(self, config, **kwargs): # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] + self.federation_domain_whitelist: Optional[dict] = None federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 942e2672a9..ba89d11cf0 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -460,7 +460,7 @@ def _parse_oidc_config_dict( ) from e client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") - client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] + client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None if client_secret_jwt_key_config is not None: keyfile = client_secret_jwt_key_config.get("key_file") if keyfile: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index fd90b79772..0f5b2b3977 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): - self.password_providers = [] # type: List[Any] + self.password_providers: List[Any] = [] providers = [] # We want to be backwards compatible with the old `ldap_config` diff --git a/synapse/config/repository.py b/synapse/config/repository.py index a7a82742ac..0dfb3a227a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): Dictionary mapping from media type string to list of ThumbnailRequirement tuples. """ - requirements = {} # type: Dict[str, List] + requirements: Dict[str, List] = {} for size in thumbnail_sizes: width = size["width"] height = size["height"] @@ -141,7 +141,7 @@ class ContentRepositoryConfig(Config): # # We don't create the storage providers here as not all workers need # them to be started. 
- self.media_storage_providers = [] # type: List[tuple] + self.media_storage_providers: List[tuple] = [] for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to diff --git a/synapse/config/server.py b/synapse/config/server.py index 6bff715230..b9e0c0b300 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -505,7 +505,7 @@ class ServerConfig(Config): " greater than 'allowed_lifetime_max'" ) - self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] + self.retention_purge_jobs: List[Dict[str, Optional[int]]] = [] for purge_job_config in retention_config.get("purge_jobs", []): interval_config = purge_job_config.get("interval") @@ -688,23 +688,21 @@ class ServerConfig(Config): # not included in the sample configuration file on purpose as it's a temporary # hack, so that some users can trial the new defaults without impacting every # user on the homeserver. - users_new_default_push_rules = ( + users_new_default_push_rules: list = ( config.get("users_new_default_push_rules") or [] - ) # type: list + ) if not isinstance(users_new_default_push_rules, list): raise ConfigError("'users_new_default_push_rules' must be a list") # Turn the list into a set to improve lookup speed. - self.users_new_default_push_rules = set( - users_new_default_push_rules - ) # type: set + self.users_new_default_push_rules: set = set(users_new_default_push_rules) # Whitelist of domain names that given next_link parameters must have - next_link_domain_whitelist = config.get( + next_link_domain_whitelist: Optional[List[str]] = config.get( "next_link_domain_whitelist" - ) # type: Optional[List[str]] + ) - self.next_link_domain_whitelist = None # type: Optional[Set[str]] + self.next_link_domain_whitelist: Optional[Set[str]] = None if next_link_domain_whitelist is not None: if not isinstance(next_link_domain_whitelist, list): raise ConfigError("'next_link_domain_whitelist' must be a list") diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index cb7716c837..a233a9ce03 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -34,7 +34,7 @@ class SpamCheckerConfig(Config): section = "spamchecker" def read_config(self, config, **kwargs): - self.spam_checkers = [] # type: List[Tuple[Any, Dict]] + self.spam_checkers: List[Tuple[Any, Dict]] = [] spam_checkers = config.get("spam_checker") or [] if isinstance(spam_checkers, dict): diff --git a/synapse/config/sso.py b/synapse/config/sso.py index e4346e02aa..d0f04cf8e6 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,7 +39,7 @@ class SSOConfig(Config): section = "sso" def read_config(self, config, **kwargs): - sso_config = config.get("sso") or {} # type: Dict[str, Any] + sso_config: Dict[str, Any] = config.get("sso") or {} # The sso-specific template_dir self.sso_template_dir = sso_config.get("template_dir") diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 9a16a8fbae..fed05ac7be 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -80,7 +80,7 @@ class TlsConfig(Config): fed_whitelist_entries = [] # Support globs (*) in whitelist values - self.federation_certificate_verification_whitelist = [] # type: List[Pattern] + self.federation_certificate_verification_whitelist: List[Pattern] = [] for entry in fed_whitelist_entries: try: entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) @@ -132,8 +132,8 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) 
- self.tls_certificate = None # type: Optional[crypto.X509] - self.tls_private_key = None # type: Optional[crypto.PKey] + self.tls_certificate: Optional[crypto.X509] = None + self.tls_private_key: Optional[crypto.PKey] = None def is_disk_cert_valid(self, allow_self_signed=True): """ diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e5a4685ed4..9e9b1c1c86 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -170,11 +170,13 @@ class Keyring: ) self._key_fetchers = key_fetchers - self._server_queue = BatchingQueue( + self._server_queue: BatchingQueue[ + _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] + ] = BatchingQueue( "keyring_server", clock=hs.get_clock(), process_batch_callback=self._inner_fetch_key_requests, - ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] + ) async def verify_json_for_server( self, @@ -330,7 +332,7 @@ class Keyring: # First we need to deduplicate requests for the same key. We do this by # taking the *maximum* requested `minimum_valid_until_ts` for each pair # of server name/key ID. - server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]] + server_to_key_to_ts: Dict[str, Dict[str, int]] = {} for request in requests: by_server = server_to_key_to_ts.setdefault(request.server_name, {}) for key_id in request.key_ids: @@ -355,7 +357,7 @@ class Keyring: # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. - to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] + to_return: Dict[str, Dict[str, FetchKeyResult]] = {} for (request, results) in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): @@ -455,7 +457,7 @@ class StoreKeyFetcher(KeyFetcher): ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) - keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + keys: Dict[str, Dict[str, FetchKeyResult]] = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys @@ -603,7 +605,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ).addErrback(unwrapFirstError) ) - union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {} for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) @@ -656,8 +658,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] - added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] + keys: Dict[str, Dict[str, FetchKeyResult]] = {} + added_keys: List[Tuple[str, str, FetchKeyResult]] = [] time_now_ms = self.clock.time_msec() @@ -805,7 +807,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: Dict[str, FetchKeyResult] + keys: Dict[str, FetchKeyResult] = {} for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py index a3df6cfcc1..137dff2513 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -531,7 +531,7 @@ def _check_power_levels( user_level = get_user_power_level(event.user_id, auth_events) # Check other levels: - levels_to_check = [ + levels_to_check: List[Tuple[str, Optional[str]]] = [ ("users_default", None), ("events_default", None), ("state_default", None), @@ -539,7 +539,7 @@ def _check_power_levels( ("redact", None), ("kick", None), ("invite", None), - ] # type: List[Tuple[str, Optional[str]]] + ] old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): @@ -569,12 +569,12 @@ def _check_power_levels( new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) # type: Optional[int] + old_level: Optional[int] = int(old_loc[level_to_check]) else: old_level = None if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) # type: Optional[int] + new_level: Optional[int] = int(new_loc[level_to_check]) else: new_level = None diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 6286ad999a..65dc7a4ed0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -105,28 +105,28 @@ class _EventInternalMetadata: self._dict = dict(internal_metadata_dict) # the stream ordering of this event. None, until it has been persisted. - self.stream_ordering = None # type: Optional[int] + self.stream_ordering: Optional[int] = None # whether this event is an outlier (ie, whether we have the state at that point # in the DAG) self.outlier = False - out_of_band_membership = DictProperty("out_of_band_membership") # type: bool - send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str - recheck_redaction = DictProperty("recheck_redaction") # type: bool - soft_failed = DictProperty("soft_failed") # type: bool - proactively_send = DictProperty("proactively_send") # type: bool - redacted = DictProperty("redacted") # type: bool - txn_id = DictProperty("txn_id") # type: str - token_id = DictProperty("token_id") # type: int - historical = DictProperty("historical") # type: bool + out_of_band_membership: bool = DictProperty("out_of_band_membership") + send_on_behalf_of: str = DictProperty("send_on_behalf_of") + recheck_redaction: bool = DictProperty("recheck_redaction") + soft_failed: bool = DictProperty("soft_failed") + proactively_send: bool = DictProperty("proactively_send") + redacted: bool = DictProperty("redacted") + txn_id: str = DictProperty("txn_id") + token_id: int = DictProperty("token_id") + historical: bool = DictProperty("historical") # XXX: These are set by StreamWorkerStore._set_before_and_after. 
# I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: RoomStreamToken - after = DictProperty("after") # type: RoomStreamToken - order = DictProperty("order") # type: Tuple[int, int] + before: RoomStreamToken = DictProperty("before") + after: RoomStreamToken = DictProperty("after") + order: Tuple[int, int] = DictProperty("order") def get_dict(self) -> JsonDict: return dict(self._dict) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 26e3950859..87e2bb123b 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,12 +132,12 @@ class EventBuilder: format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: # The types of auth/prev events changes between event versions. - auth_events = await self._store.add_event_hashes( - auth_event_ids - ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] - prev_events = await self._store.add_event_hashes( - prev_event_ids - ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + auth_events: Union[ + List[str], List[Tuple[str, Dict[str, str]]] + ] = await self._store.add_event_hashes(auth_event_ids) + prev_events: Union[ + List[str], List[Tuple[str, Dict[str, str]]] + ] = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_event_ids prev_events = prev_event_ids @@ -156,7 +156,7 @@ class EventBuilder: # the db) depth = min(depth, MAX_DEPTH) - event_dict = { + event_dict: Dict[str, Any] = { "auth_events": auth_events, "prev_events": prev_events, "type": self.type, @@ -166,7 +166,7 @@ class EventBuilder: "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } # type: Dict[str, Any] + } if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index efec16c226..57f1d53fa8 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): """Wrapper that loads spam checkers configured using the old configuration, and registers the spam checker hooks they implement. """ - spam_checkers = [] # type: List[Any] + spam_checkers: List[Any] = [] api = hs.get_module_api() for module, config in hs.config.spam_checkers: # Older spam checkers don't accept the `api` argument, so we @@ -239,7 +239,7 @@ class SpamChecker: will be used as the error message returned to the user. 
""" for callback in self._check_event_for_spam_callbacks: - res = await callback(event) # type: Union[bool, str] + res: Union[bool, str] = await callback(event) if res: return res diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ed09c6af1f..c767d30627 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -86,7 +86,7 @@ class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] + self.pdu_destination_tried: Dict[str, Dict[str, int]] = {} self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -94,13 +94,13 @@ class FederationClient(FederationBase): self.hostname = hs.hostname self.signing_key = hs.signing_key - self._get_pdu_cache = ExpiringCache( + self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( cache_name="get_pdu_cache", clock=self._clock, max_len=1000, expiry_ms=120 * 1000, reset_expiry_on_get=False, - ) # type: ExpiringCache[str, EventBase] + ) def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -293,10 +293,10 @@ class FederationClient(FederationBase): transaction_data, ) - pdu_list = [ + pdu_list: List[EventBase] = [ event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] - ] # type: List[EventBase] + ] if pdu_list and pdu_list[0]: pdu = pdu_list[0] diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ac0f2ccfb3..d91f0ff32f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -122,12 +122,12 @@ class FederationServer(FederationBase): # origins that we are currently processing a transaction from. # a dict from origin to txn id. - self._active_transactions = {} # type: Dict[str, str] + self._active_transactions: Dict[str, str] = {} # We cache results for transaction with the same ID - self._transaction_resp_cache = ResponseCache( + self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "fed_txn_handler", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) self.transaction_actions = TransactionActions(self.store) @@ -135,12 +135,12 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. 
- self._state_resp_cache = ResponseCache( - hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, Optional[str]]] - self._state_ids_resp_cache = ResponseCache( + self._state_resp_cache: ResponseCache[ + Tuple[str, Optional[str]] + ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) + self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) self._federation_metrics_domains = ( hs.config.federation.federation_metrics_domains @@ -337,7 +337,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) - pdus_by_room = {} # type: Dict[str, List[EventBase]] + pdus_by_room: Dict[str, List[EventBase]] = {} newest_pdu_ts = 0 @@ -516,9 +516,9 @@ class FederationServer(FederationBase): self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: - pdus = await self.handler.get_state_for_pdu( + pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( room_id, event_id - ) # type: Iterable[EventBase] + ) else: pdus = (await self.state.get_current_state(room_id)).values() @@ -791,7 +791,7 @@ class FederationServer(FederationBase): log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) results = await self.store.claim_e2e_one_time_keys(query) - json_result = {} # type: Dict[str, Dict[str, dict]] + json_result: Dict[str, Dict[str, dict]] = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): @@ -1119,17 +1119,13 @@ class FederationHandlerRegistry: self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - self.edu_handlers = ( - {} - ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] - self.query_handlers = ( - {} - ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] + self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} # Map from type to instance names that we should route EDU handling to. # We randomly choose one instance from the list to route to for each new # EDU received. - self._edu_type_to_instance = {} # type: Dict[str, List[str]] + self._edu_type_to_instance: Dict[str, List[str]] = {} def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 65d76ea974..1fbf325fdc 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -71,34 +71,32 @@ class FederationRemoteSendQueue(AbstractFederationSender): # We may have multiple federation sender instances, so we need to track # their positions separately. self._sender_instances = hs.config.worker.federation_shard_config.instances - self._sender_positions = {} # type: Dict[str, int] + self._sender_positions: Dict[str, int] = {} # Pending presence map user_id -> UserPresenceState - self.presence_map = {} # type: Dict[str, UserPresenceState] + self.presence_map: Dict[str, UserPresenceState] = {} # Stores the destinations we need to explicitly send presence to about a # given user. 
# Stream position -> (user_id, destinations) - self.presence_destinations = ( - SortedDict() - ) # type: SortedDict[int, Tuple[str, Iterable[str]]] + self.presence_destinations: SortedDict[ + int, Tuple[str, Iterable[str]] + ] = SortedDict() # (destination, key) -> EDU - self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] + self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {} # stream position -> (destination, key) - self.keyed_edu_changed = ( - SortedDict() - ) # type: SortedDict[int, Tuple[str, tuple]] + self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict() - self.edus = SortedDict() # type: SortedDict[int, Edu] + self.edus: SortedDict[int, Edu] = SortedDict() # stream ID for the next entry into keyed_edu_changed/edus. self.pos = 1 # map from stream ID to the time that stream entry was generated, so that we # can clear out entries after a while - self.pos_time = SortedDict() # type: SortedDict[int, int] + self.pos_time: SortedDict[int, int] = SortedDict() # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -291,7 +289,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. - rows = [] # type: List[Tuple[int, BaseFederationRow]] + rows: List[Tuple[int, BaseFederationRow]] = [] # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) @@ -445,11 +443,11 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu buff.edus.setdefault(self.edu.destination, []).append(self.edu) -_rowtypes = ( +_rowtypes: Tuple[Type[BaseFederationRow], ...] = ( PresenceDestinationsRow, KeyedEduRow, EduRow, -) # type: Tuple[Type[BaseFederationRow], ...] +) TypeToRow = {Row.TypeId: Row for Row in _rowtypes} diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index deb40f4610..0960f033bc 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -148,14 +148,14 @@ class FederationSender(AbstractFederationSender): self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id - self._presence_router = None # type: Optional[PresenceRouter] + self._presence_router: Optional["PresenceRouter"] = None self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] + self._per_destination_queues: Dict[str, PerDestinationQueue] = {} LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -192,9 +192,7 @@ class FederationSender(AbstractFederationSender): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. 
- self._queues_awaiting_rr_flush_by_room = ( - {} - ) # type: Dict[str, Set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {} self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.config.federation_rr_transactions_per_room_per_second @@ -265,7 +263,7 @@ class FederationSender(AbstractFederationSender): if not event.internal_metadata.should_proactively_send(): return - destinations = None # type: Optional[Set[str]] + destinations: Optional[Set[str]] = None if not event.prev_event_ids(): # If there are no prev event IDs then the state is empty # and so no remote servers in the room @@ -331,7 +329,7 @@ class FederationSender(AbstractFederationSender): for event in events: await handle_event(event) - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -628,7 +626,7 @@ class FederationSender(AbstractFederationSender): In order to reduce load spikes, adds a delay between each destination. """ - last_processed = None # type: Optional[str] + last_processed: Optional[str] = None while True: destinations_to_wake = ( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3a2efd56ee..d06a3aff19 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -105,34 +105,34 @@ class PerDestinationQueue: # catch-up at startup. # New events will only be sent once this is finished, at which point # _catching_up is flipped to False. - self._catching_up = True # type: bool + self._catching_up: bool = True # The stream_ordering of the most recent PDU that was discarded due to # being in catch-up mode. - self._catchup_last_skipped = 0 # type: int + self._catchup_last_skipped: int = 0 # Cache of the last successfully-transmitted stream ordering for this # destination (we are the only updater so this is safe) - self._last_successful_stream_ordering = None # type: Optional[int] + self._last_successful_stream_ordering: Optional[int] = None # a queue of pending PDUs - self._pending_pdus = [] # type: List[EventBase] + self._pending_pdus: List[EventBase] = [] # XXX this is never actually used: see # https://github.com/matrix-org/synapse/issues/7549 - self._pending_edus = [] # type: List[Edu] + self._pending_edus: List[Edu] = [] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] + self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {} # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: Dict[str, UserPresenceState] + self._pending_presence: Dict[str, UserPresenceState] = {} # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] + self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. 
@@ -243,7 +243,7 @@ class PerDestinationQueue: ) async def _transaction_transmission_loop(self) -> None: - pending_pdus = [] # type: List[EventBase] + pending_pdus: List[EventBase] = [] try: self.transmission_loop_running = True diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index c9e7c57461..98b1bf77fd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -395,9 +395,9 @@ class TransportLayerClient: # this uses MSC2197 (Search Filtering over Federation) path = _create_v1_path("/publicRooms") - data = { + data: Dict[str, Any] = { "include_all_networks": "true" if include_all_networks else "false" - } # type: Dict[str, Any] + } if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -423,9 +423,9 @@ class TransportLayerClient: else: path = _create_v1_path("/publicRooms") - args = { + args: Dict[str, Any] = { "include_all_networks": "true" if include_all_networks else "false" - } # type: Dict[str, Any] + } if third_party_instance_id: args["third_party_instance_id"] = (third_party_instance_id,) if limit: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0b21b375ee..2974d4d0cc 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1013,7 +1013,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -1991,7 +1991,7 @@ class RoomComplexityServlet(BaseFederationServlet): return 200, complexity -FEDERATION_SERVLET_CLASSES = ( +FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, FederationStateV1Servlet, @@ -2019,15 +2019,13 @@ FEDERATION_SERVLET_CLASSES = ( FederationSpaceSummaryServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) # type: Tuple[Type[BaseFederationServlet], ...] +OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] +ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,) -GROUP_SERVER_SERVLET_CLASSES = ( +GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsProfileServlet, FederationGroupsSummaryServlet, FederationGroupsRoomsServlet, @@ -2046,19 +2044,19 @@ GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsConfigServlet, FederationGroupsSettingJoinPolicyServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -GROUP_LOCAL_SERVLET_CLASSES = ( +GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsLocalInviteServlet, FederationGroupsRemoveLocalUserServlet, FederationGroupsBulkPublicisedServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] +) -GROUP_ATTESTATION_SERVLET_CLASSES = ( +GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationGroupsRenewAttestaionServlet, -) # type: Tuple[Type[BaseFederationServlet], ...] 
+) DEFAULT_SERVLET_GROUPS = ( diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index a06d060ebf..3dc55ab861 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -707,9 +707,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): See accept_invite, join_group. """ if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation( - group_id, user_id - ) # type: Optional[JsonDict] + local_attestation: Optional[ + JsonDict + ] = self.attestations.create_attestation(group_id, user_id) remote_attestation = content["attestation"] @@ -868,9 +868,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): remote_attestation, user_id=requester_user_id, group_id=group_id ) - local_attestation = self.attestations.create_attestation( - group_id, requester_user_id - ) # type: Optional[JsonDict] + local_attestation: Optional[ + JsonDict + ] = self.attestations.create_attestation(group_id, requester_user_id) else: local_attestation = None remote_attestation = None diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index ed4671b7de..578fc48ef4 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes: return hostname # no Host header, use the address/port that the request arrived on - host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] + host: Union[address.IPv4Address, address.IPv6Address] = request.getHost() hostname = host.host.encode("ascii") diff --git a/synapse/http/client.py b/synapse/http/client.py index 1ca6624fd5..2ac76b15c2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -160,7 +160,7 @@ class _IPBlacklistingResolver: def resolveHostName( self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 ) -> IResolutionReceiver: - addresses = [] # type: List[IAddress] + addresses: List[IAddress] = [] def _callback() -> None: has_bad_ip = False @@ -333,9 +333,9 @@ class SimpleHttpClient: if self._ip_blacklist: # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. - self.reactor = BlacklistingReactorWrapper( + self.reactor: ISynapseReactor = BlacklistingReactorWrapper( hs.get_reactor(), self._ip_whitelist, self._ip_blacklist - ) # type: ISynapseReactor + ) else: self.reactor = hs.get_reactor() @@ -349,14 +349,14 @@ class SimpleHttpClient: pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 - self.agent = ProxyAgent( + self.agent: IAgent = ProxyAgent( self.reactor, hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, use_proxy=use_proxy, - ) # type: IAgent + ) if self._ip_blacklist: # If we have an IP blacklist, we then install the blacklisting Agent @@ -411,7 +411,7 @@ class SimpleHttpClient: cooperator=self._cooperator, ) - request_deferred = treq.request( + request_deferred: defer.Deferred = treq.request( method, uri, agent=self.agent, @@ -421,7 +421,7 @@ class SimpleHttpClient: # response bodies. unbuffered=True, **self._extra_treq_args, - ) # type: defer.Deferred + ) # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. 
@@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" - transport = None # type: Optional[ITCPTransport] + transport: Optional[ITCPTransport] = None def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" - transport = None # type: Optional[ITCPTransport] + transport: Optional[ITCPTransport] = None def __init__( self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3bace2c965..2efa15bf04 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -106,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): the parsed data. """ - CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore + CONTENT_TYPE: str = abc.abstractproperty() # type: ignore """The expected content type of the response, e.g. `application/json`. If the content type doesn't match we fail the request. """ @@ -327,11 +327,11 @@ class MatrixFederationHttpClient: # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. - self.reactor = BlacklistingReactorWrapper( + self.reactor: ISynapseReactor = BlacklistingReactorWrapper( hs.get_reactor(), hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_blacklist, - ) # type: ISynapseReactor + ) user_agent = hs.version_string if hs.config.user_agent_suffix: @@ -504,7 +504,7 @@ class MatrixFederationHttpClient: ) # Inject the span into the headers - headers_dict = {} # type: Dict[bytes, List[bytes]] + headers_dict: Dict[bytes, List[bytes]] = {} opentracing.inject_header_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -533,9 +533,9 @@ class MatrixFederationHttpClient: destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) - producer = QuieterFileBodyProducer( + producer: Optional[IBodyProducer] = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator - ) # type: Optional[IBodyProducer] + ) else: producer = None auth_headers = self.build_auth_headers( diff --git a/synapse/http/server.py b/synapse/http/server.py index efbc6d5b25..b79fa722e9 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: if f.check(SynapseError): # mypy doesn't understand that f.check asserts the type. - exc = f.value # type: SynapseError # type: ignore + exc: SynapseError = f.value # type: ignore error_code = exc.code error_dict = exc.error_dict() @@ -132,7 +132,7 @@ def return_html_error( """ if f.check(CodeMessageException): # mypy doesn't understand that f.check asserts the type. - cme = f.value # type: CodeMessageException # type: ignore + cme: CodeMessageException = f.value # type: ignore code = cme.code msg = cme.msg @@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource): key word arguments to pass to the callback """ # At this point the path must be bytes. 
- request_path_bytes = request.path # type: bytes # type: ignore + request_path_bytes: bytes = request.path # type: ignore request_path = request_path_bytes.decode("ascii") # Treat HEAD requests as GET requests. request_method = request.method @@ -557,7 +557,7 @@ class _ByteProducer: request: Request, iterator: Iterator[bytes], ): - self._request = request # type: Optional[Request] + self._request: Optional[Request] = request self._iterator = iterator self._paused = False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6ba2ce1e53..04560fb589 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -205,7 +205,7 @@ def parse_string( parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore return parse_string_from_args( args, name, diff --git a/synapse/http/site.py b/synapse/http/site.py index 40754b7bea..3b0a38124e 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -64,16 +64,16 @@ class SynapseRequest(Request): def __init__(self, channel, *args, max_request_body_size=1024, **kw): Request.__init__(self, channel, *args, **kw) self._max_request_body_size = max_request_body_size - self.site = channel.site # type: SynapseSite + self.site: SynapseSite = channel.site self._channel = channel # this is used by the tests self.start_time = 0.0 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. - self._requester = None # type: Optional[Union[Requester, str]] + self._requester: Optional[Union[Requester, str]] = None # we can't yet create the logcontext, as we don't know the method. - self.logcontext = None # type: Optional[LoggingContext] + self.logcontext: Optional[LoggingContext] = None global _next_request_seq self.request_seq = _next_request_seq @@ -152,7 +152,7 @@ class SynapseRequest(Request): Returns: The redacted URI as a string. """ - uri = self.uri # type: Union[bytes, str] + uri: Union[bytes, str] = self.uri if isinstance(uri, bytes): uri = uri.decode("ascii", errors="replace") return redact_uri(uri) @@ -167,7 +167,7 @@ class SynapseRequest(Request): Returns: The request method as a string. """ - method = self.method # type: Union[bytes, str] + method: Union[bytes, str] = self.method if isinstance(method, bytes): return self.method.decode("ascii") return method @@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest): """ # the client IP and ssl flag, as extracted from the headers. 
- _forwarded_for = None # type: Optional[_XForwardedForAddress] - _forwarded_https = False # type: bool + _forwarded_for: "Optional[_XForwardedForAddress]" = None + _forwarded_https: bool = False def requestReceived(self, command, path, version): # this method is called by the Channel once the full request has been diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index c515690b38..8202d0494d 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -110,9 +110,9 @@ class RemoteHandler(logging.Handler): self.port = port self.maximum_buffer = maximum_buffer - self._buffer = deque() # type: Deque[logging.LogRecord] - self._connection_waiter = None # type: Optional[Deferred] - self._producer = None # type: Optional[LogProducer] + self._buffer: Deque[logging.LogRecord] = deque() + self._connection_waiter: Optional[Deferred] = None + self._producer: Optional[LogProducer] = None # Connect without DNS lookups if it's a direct IP. if _reactor is None: @@ -123,9 +123,9 @@ class RemoteHandler(logging.Handler): try: ip = ip_address(self.host) if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( + endpoint: IStreamClientEndpoint = TCP4ClientEndpoint( _reactor, self.host, self.port - ) # type: IStreamClientEndpoint + ) elif isinstance(ip, IPv6Address): endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) else: @@ -165,7 +165,7 @@ class RemoteHandler(logging.Handler): def writer(result: Protocol) -> None: # Force recognising transport as a Connection and not the more # generic ITransport. - transport = result.transport # type: Connection # type: ignore + transport: Connection = result.transport # type: ignore # We have a connection. If we already have a producer, and its # transport is the same, just trigger a resumeProducing. 
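The `_forwarded_for` hunk above quotes its annotation because `_XForwardedForAddress` is defined later in the module. A minimal sketch of that forward-reference pattern (class names here are placeholders):

    from typing import Optional

    class Request:
        # Quoted: _Address is only defined further down the module, so the
        # name cannot be evaluated yet when this class body runs.
        _forwarded_for: "Optional[_Address]" = None
        _forwarded_https: bool = False

    class _Address:
        pass

    r = Request()
    r._forwarded_for = _Address()
    print(r._forwarded_for, r._forwarded_https)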
@@ -188,7 +188,7 @@ class RemoteHandler(logging.Handler): self._producer.resumeProducing() self._connection_waiter = None - deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred + deferred: Deferred = self._service.whenConnected(failAfterFailures=1) deferred.addCallbacks(writer, fail) self._connection_waiter = deferred diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index c7a971a9d6..b9933a1528 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -63,7 +63,7 @@ def parse_drain_configs( DrainType.CONSOLE_JSON, DrainType.FILE_JSON, ): - formatter = "json" # type: Optional[str] + formatter: Optional[str] = "json" elif logging_type in ( DrainType.CONSOLE_JSON_TERSE, DrainType.NETWORK_JSON_TERSE, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 7fc11a9ac2..18ac507802 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -113,13 +113,13 @@ class ContextResourceUsage: self.reset() else: # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now - self.ru_utime = copy_from.ru_utime # type: float - self.ru_stime = copy_from.ru_stime # type: float - self.db_txn_count = copy_from.db_txn_count # type: int + self.ru_utime: float = copy_from.ru_utime + self.ru_stime: float = copy_from.ru_stime + self.db_txn_count: int = copy_from.db_txn_count - self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float - self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float - self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int + self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec + self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec + self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count def copy(self) -> "ContextResourceUsage": return ContextResourceUsage(copy_from=self) @@ -289,12 +289,12 @@ class LoggingContext: # The thread resource usage when the logcontext became active. None # if the context is not currently active. - self.usage_start = None # type: Optional[resource._RUsage] + self.usage_start: Optional[resource._RUsage] = None self.main_thread = get_thread_id() self.request = None self.tag = "" - self.scope = None # type: Optional[_LogContextScope] + self.scope: Optional["_LogContextScope"] = None # keep track of whether we have hit the __exit__ block for this context # (suggesting that the the thing that created the context thinks it should diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 140ed711e3..185844f188 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -251,7 +251,7 @@ try: except Exception: logger.exception("Failed to report span") - RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]] + RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter except ImportError: RustReporter = None @@ -286,7 +286,7 @@ class SynapseBaggage: # Block everything by default # A regex which matches the server_names to expose traces for. # None means 'block everything'. 
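Several hunks (the `endpoint` and `formatter` changes among them) annotate only the first assignment in an if/elif chain. A small runnable sketch, loosely modeled on the `_structured.py` hunk:

    from typing import Optional

    def pick_formatter(drain_type: str) -> Optional[str]:
        # Annotating the first assignment declares the variable's type for
        # every later branch, so the plain `formatter = None` also checks.
        if drain_type in ("console_json", "file_json"):
            formatter: Optional[str] = "json"
        elif drain_type in ("console_json_terse", "network_json_terse"):
            formatter = "terse"
        else:
            formatter = None
        return formatter

    print(pick_formatter("console_json"))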
-_homeserver_whitelist = None # type: Optional[Pattern[str]] +_homeserver_whitelist: Optional[Pattern[str]] = None # Util methods @@ -662,7 +662,7 @@ def inject_header_dict( span = opentracing.tracer.active_span - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None): if destination and not whitelisted_homeserver(destination): return {} - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) @@ -718,7 +718,7 @@ def active_span_context_as_string(): Returns: The active span context encoded as a string. """ - carrier = {} # type: Dict[str, str] + carrier: Dict[str, str] = {} if opentracing: opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index fef2846669..f237b8a236 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] +all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {} HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -130,7 +130,7 @@ class InFlightGauge: ) # Counts number of in flight blocks for a given set of label values - self._registrations = {} # type: Dict + self._registrations: Dict = {} # Protects access to _registrations self._lock = threading.Lock() @@ -248,7 +248,7 @@ class GaugeBucketCollector: # We initially set this to None. We won't report metrics until # this has been initialised after a successful data update - self._metric = None # type: Optional[GaugeHistogramMetricFamily] + self._metric: Optional[GaugeHistogramMetricFamily] = None registry.register(self) diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 8002be56e0..7e49d0d02c 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -125,7 +125,7 @@ def generate_latest(registry, emit_help=False): ) output.append("# TYPE {0} {1}\n".format(mname, mtype)) - om_samples = {} # type: Dict[str, List[str]] + om_samples: Dict[str, List[str]] = {} for s in metric.samples: for suffix in ["_created", "_gsum", "_gcount"]: if s.name == metric.name + suffix: diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index de96ca0821..4455fa71a8 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -93,7 +93,7 @@ _background_process_db_sched_duration = Counter( # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = {} # type: Dict[str, int] +_background_process_counts: Dict[str, int] = {} # Set of all running background processes that became active active since the # last time metrics were scraped (i.e. 
background processes that performed some @@ -103,7 +103,7 @@ _background_process_counts = {} # type: Dict[str, int] # background processes stacking up behind a lock or linearizer, where we then # only need to iterate over and update metrics for the process that have # actually been active and can ignore the idle ones. -_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] +_background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set() # A lock that covers the above set and dict _bg_metrics_lock = threading.Lock() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 721c45abac..308f045700 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -54,7 +54,7 @@ class ModuleApi: self._state = hs.get_state_handler() # We expose these as properties below in order to attach a helpful docstring. - self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient + self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) self._spam_checker = hs.get_spam_checker() diff --git a/synapse/notifier.py b/synapse/notifier.py index 3c3cc47631..c5fbebc17d 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -203,21 +203,21 @@ class Notifier: UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] - self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] + self.user_to_user_stream: Dict[str, _NotifierUserStream] = {} + self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] + self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication - self.replication_callbacks = [] # type: List[Callable[[], None]] + self.replication_callbacks: List[Callable[[], None]] = [] # Called when remote servers have come back online after having been # down. - self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] + self.remote_server_up_callbacks: List[Callable[[str], None]] = [] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -237,7 +237,7 @@ class Notifier: # when rendering the metrics page, which is likely once per minute at # most when scraping it. 
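Many of the converted lines are module-level empty containers, as in the `_background_process_counts` hunk. An empty literal tells the checker nothing about element types, so the annotation must supply them, as in this self-contained sketch:

    from collections import deque
    from typing import Deque, Dict, Set

    # Empty literals give the checker nothing to infer from, so the element
    # types are spelled out at the point of assignment.
    _counts: Dict[str, int] = {}
    _active: Set[str] = set()
    _buffer: Deque[str] = deque()

    _counts["demo"] = _counts.get("demo", 0) + 1
    _active.add("demo")
    _buffer.append("demo")
    print(_counts, _active, _buffer)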
def count_listeners(): - all_user_streams = set() # type: Set[_NotifierUserStream] + all_user_streams: Set[_NotifierUserStream] = set() for streams in list(self.room_to_user_streams.values()): all_user_streams |= streams @@ -329,8 +329,8 @@ class Notifier: pending = self.pending_new_room_events self.pending_new_room_events = [] - users = set() # type: Set[UserID] - rooms = set() # type: Set[str] + users: Set[UserID] = set() + rooms: Set[str] = set() for entry in pending: if entry.event_pos.persisted_after(max_room_stream_token): @@ -580,7 +580,7 @@ class Notifier: if after_token == before_token: return EventStreamResult([], (from_token, from_token)) - events = [] # type: List[EventBase] + events: List[EventBase] = [] end_token = from_token for name, source in self.event_sources.sources.items(): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 669ea462e2..c337e530d3 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -194,7 +194,7 @@ class BulkPushRuleEvaluator: count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event, context) - actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] + actions_by_user: Dict[str, List[Union[dict, str]]] = {} room_members = await self.store.get_joined_users_from_context(event, context) @@ -207,7 +207,7 @@ class BulkPushRuleEvaluator: event, len(room_members), sender_power_level, power_levels ) - condition_cache = {} # type: Dict[str, bool] + condition_cache: Dict[str, bool] = {} # If the event is not a state event check if any users ignore the sender. if not event.is_state(): diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 2ee0ccd58a..1fc9716a34 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = { + rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { "global": {}, "device": {}, - } # type: Dict[str, Dict[str, List[Dict[str, Any]]]] + } rules["global"] = _add_empty_priority_class_arrays(rules["global"]) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 99a18874d1..e08e125cb8 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,8 +66,8 @@ class EmailPusher(Pusher): self.store = self.hs.get_datastore() self.email = pusher_config.pushkey - self.timed_call = None # type: Optional[IDelayedCall] - self.throttle_params = {} # type: Dict[str, ThrottleParams] + self.timed_call: Optional[IDelayedCall] = None + self.throttle_params: Dict[str, ThrottleParams] = {} self._inited = False self._is_processing = False @@ -168,7 +168,7 @@ class EmailPusher(Pusher): ) ) - soonest_due_at = None # type: Optional[int] + soonest_due_at: Optional[int] = None if not unprocessed: await self.save_last_stream_ordering_and_success(self.max_stream_ordering) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 06bf5f8ada..36aabd8422 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -71,7 +71,7 @@ class HttpPusher(Pusher): self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusher_config.failing_since - self.timed_call = None # type: Optional[IDelayedCall] + self.timed_call: Optional[IDelayedCall] = None 
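The pusher hunks follow the same recipe for instance attributes: annotate at the first assignment in `__init__`, declaring initially-unset attributes as `Optional`. A sketch (with `str` standing in for `IDelayedCall`):

    from typing import Dict, List, Optional

    class Pusher:
        def __init__(self) -> None:
            # Each attribute is annotated at its first assignment; attributes
            # that start unset are declared Optional with a None default.
            self.timed_call: Optional[str] = None  # stand-in for IDelayedCall
            self.throttle_params: Dict[str, int] = {}
            self.pending: List[str] = []

    p = Pusher()
    p.timed_call = "scheduled"
    print(p.timed_call, p.throttle_params, p.pending)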
self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._pusherpool = hs.get_pusherpool() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5f9ea5003a..7be5fe1e9b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -110,7 +110,7 @@ class Mailer: self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name - self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig + self.email_subjects: EmailSubjectConfig = hs.config.email_subjects logger.info("Created Mailer for app_name %s" % app_name) @@ -230,7 +230,7 @@ class Mailer: [pa["event_id"] for pa in push_actions] ) - notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]] + notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -356,13 +356,13 @@ class Mailer: room_name = await calculate_room_name(self.store, room_state_ids, user_id) - room_vars = { + room_vars: Dict[str, Any] = { "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], "invite": is_invite, "link": self._make_room_link(room_id), - } # type: Dict[str, Any] + } if not is_invite: for n in notifs: @@ -460,9 +460,9 @@ class Mailer: type_state_key = ("m.room.member", event.sender) sender_state_event_id = room_state_ids.get(type_state_key) if sender_state_event_id: - sender_state_event = await self.store.get_event( + sender_state_event: Optional[EventBase] = await self.store.get_event( sender_state_event_id - ) # type: Optional[EventBase] + ) else: # Attempt to check the historical state for the room. historical_state = await self.state_store.get_state_for_event( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 412941393f..0510c1cbd5 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str: def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]: - ret = {} # type: Dict[str, Dict[str, str]] + ret: Dict[str, Dict[str, str]] = {} for k, v in state.items(): ret.setdefault(k[0], {})[k[1]] = v return ret diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 98b90a4f51..7a8dc63976 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -195,9 +195,9 @@ class PushRuleEvaluatorForEvent: # Caches (string, is_glob, word_boundary) -> regex for push. 
See _glob_matches -regex_cache = LruCache( +regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( 50000, "regex_push_cache" -) # type: LruCache[Tuple[str, bool, bool], Pattern] +) def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index c51938b8cf..021275437c 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -31,13 +31,13 @@ class PusherFactory: self.hs = hs self.config = hs.config - self.pusher_types = { + self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = { "http": HttpPusher - } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] + } logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: - self.mailers = {} # type: Dict[str, Mailer] + self.mailers: Dict[str, Mailer] = {} self._notif_template_html = hs.config.email_notif_template_html self._notif_template_text = hs.config.email_notif_template_text diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 579fcdf472..2519ad76db 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -87,7 +87,7 @@ class PusherPool: self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() # map from user id to app_id:pushkey to pusher - self.pushers = {} # type: Dict[str, Dict[str, Pusher]] + self.pushers: Dict[str, Dict[str, Pusher]] = {} def start(self) -> None: """Starts the pushers off in a background process.""" diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 271c17c226..cdcbdd772b 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -115,7 +115,7 @@ CONDITIONAL_REQUIREMENTS = { "cache_memory": ["pympler"], } -ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] +ALL_OPTIONAL_REQUIREMENTS: Set[str] = set() for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. @@ -193,7 +193,7 @@ def check_requirements(for_feature=None): if not for_feature: # Check the optional dependencies are up to date. We allow them to not be # installed. - OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] + OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) for dependency in OPTS: try: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index f13a7c23b4..25589b0042 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): is received. """ - NAME = abc.abstractproperty() # type: str # type: ignore - PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore + NAME: str = abc.abstractproperty() # type: ignore + PATH_ARGS: Tuple[str, ...] = abc.abstractproperty() # type: ignore METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True def __init__(self, hs: "HomeServer"): if self.CACHE: - self.response_cache = ResponseCache( + self.response_cache: ResponseCache[str] = ResponseCache( hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) # type: ResponseCache[str] + ) # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. @@ -232,7 +232,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # have a good idea that the request has either succeeded or failed on # the master, and so whether we should clean up or not. 
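The `regex_cache` hunk above shows the pattern for module-level globals of a generic class: the subscripted type moves from a trailing comment onto the assignment target. A hedged sketch with a toy `Cache` standing in for Synapse's `LruCache`:

    import re
    from typing import Dict, Generic, Tuple, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")

    class Cache(Generic[K, V]):
        # A stand-in for LruCache, just enough to be subscriptable.
        def __init__(self, name: str) -> None:
            self.name = name
            self._data: Dict[K, V] = {}

        def set(self, key: K, value: V) -> None:
            self._data[key] = value

    # The subscripted type annotates the target instead of trailing the call.
    regex_cache: Cache[Tuple[str, bool], re.Pattern] = Cache("regex_push_cache")
    regex_cache.set(("a*", True), re.compile("a*"))
    print(regex_cache.name)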
while True: - headers = {} # type: Dict[bytes, List[bytes]] + headers: Dict[bytes, List[bytes]] = {} # Add an authorization header, if configured. if replication_secret: headers[b"Authorization"] = [b"Bearer " + replication_secret] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index faa99387a7..e460dd85cd 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = MultiWriterIdGenerator( + self._cache_id_gen: Optional[ + MultiWriterIdGenerator + ] = MultiWriterIdGenerator( db_conn, database, stream_name="caches", @@ -41,7 +43,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): ], sequence_name="cache_invalidation_stream_seq", writers=[], - ) # type: Optional[MultiWriterIdGenerator] + ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 13ed87adc4..436d39c320 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = LruCache( + self.client_ip_last_seen: LruCache[tuple, int] = LruCache( cache_name="client_ip_last_seen", max_size=50000 - ) # type: LruCache[tuple, int] + ) async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 62d7809175..9d4859798b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -121,13 +121,13 @@ class ReplicationDataHandler: self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() - self.send_handler = None # type: Optional[FederationSenderHandler] + self.send_handler: Optional[FederationSenderHandler] = None if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]] + self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -173,7 +173,7 @@ class ReplicationDataHandler: if entities: self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == DeviceListsStream.NAME: - all_room_ids = set() # type: Set[str] + all_room_ids: Set[str] = set() for row in rows: if row.entity.startswith("@"): room_ids = await self.store.get_rooms_for_user(row.entity) @@ -201,7 +201,7 @@ class ReplicationDataHandler: if row.data.rejected: continue - extra_users = () # type: Tuple[UserID, ...] + extra_users: Tuple[UserID, ...] = () if row.data.type == EventTypes.Member and row.data.state_key: extra_users = (UserID.from_string(row.data.state_key),) @@ -348,7 +348,7 @@ class FederationSenderHandler: # Stores the latest position in the federation stream we've gotten up # to. This is always set before we use it. 
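Where the annotated call spans several lines, as in the `_cache_id_gen` hunk that follows, the annotation itself wraps inside brackets rather than trailing the closing parenthesis. A minimal sketch of that layout (names are placeholders):

    from typing import Optional

    class IdGenerator:
        def __init__(self, stream_name: str) -> None:
            self.stream_name = stream_name

    class Store:
        def __init__(self, is_postgres: bool) -> None:
            if is_postgres:
                # A long annotation wraps in its own brackets across lines.
                self._cache_id_gen: Optional[
                    IdGenerator
                ] = IdGenerator("caches")
            else:
                self._cache_id_gen = None

    store = Store(True)
    if store._cache_id_gen is not None:
        print(store._cache_id_gen.stream_name)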
- self.federation_position = None # type: Optional[int] + self.federation_position: Optional[int] = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 505d450e19..1311b013da 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta): A full command line on the wire is constructed from `NAME + " " + to_line()` """ - NAME = None # type: str + NAME: str @classmethod @abc.abstractmethod @@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand): NAME = "REMOTE_SERVER_UP" -_COMMANDS = ( +_COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, PositionCommand, @@ -393,7 +393,7 @@ _COMMANDS = ( UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, -) # type: Tuple[Type[Command], ...] +) # Map of command name to command type. COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2ad7a200bb..eae4515363 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -105,12 +105,12 @@ class ReplicationCommandHandler: hs.get_instance_name() in hs.config.worker.writers.presence ) - self._streams = { + self._streams: Dict[str, Stream] = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] + } # List of streams that this instance is the source of - self._streams_to_replicate = [] # type: List[Stream] + self._streams_to_replicate: List[Stream] = [] for stream in self._streams.values(): if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: @@ -180,14 +180,14 @@ class ReplicationCommandHandler: # Map of stream name to batched updates. See RdataCommand for info on # how batching works. - self._pending_batches = {} # type: Dict[str, List[Any]] + self._pending_batches: Dict[str, List[Any]] = {} # The factory used to create connections. - self._factory = None # type: Optional[ReconnectingClientFactory] + self._factory: Optional[ReconnectingClientFactory] = None # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) - self._connections = [] # type: List[IReplicationConnection] + self._connections: List[IReplicationConnection] = [] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -200,7 +200,7 @@ class ReplicationCommandHandler: # them in order in a separate background process. # the streams which are currently being processed by _unsafe_process_queue - self._processing_streams = set() # type: Set[str] + self._processing_streams: Set[str] = set() # for each stream, a queue of commands that are awaiting processing, and the # connection that they arrived on. @@ -210,7 +210,7 @@ class ReplicationCommandHandler: # For each connection, the incoming stream names that have received a POSITION # from that connection. - self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] + self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} LaterGauge( "synapse_replication_tcp_command_queue", diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 6e3705364f..8c80153ab6 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter( # A list of all connected protocols. 
This allows us to send metrics about the # connections. -connected_connections = [] # type: List[BaseReplicationStreamProtocol] +connected_connections: "List[BaseReplicationStreamProtocol]" = [] logger = logging.getLogger(__name__) @@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The transport is going to be an ITCPTransport, but that doesn't have the # (un)registerProducer methods, those are only on the implementation. - transport = None # type: Connection + transport: Connection delimiter = b"\n" # Valid commands we expect to receive - VALID_INBOUND_COMMANDS = [] # type: Collection[str] + VALID_INBOUND_COMMANDS: Collection[str] = [] # Valid commands we can send - VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] + VALID_OUTBOUND_COMMANDS: Collection[str] = [] max_line_buffer = 10000 @@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 # When we requested the connection be closed - self.time_we_closed = None # type: Optional[int] + self.time_we_closed: Optional[int] = None self.received_ping = False # Have we received a ping from the other side @@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.conn_id = random_string(5) # To dedupe in case of name clashes. # List of pending commands to send once we've established the connection - self.pending_commands = [] # type: List[Command] + self.pending_commands: List[Command] = [] # The LoopingCall for sending pings. - self._send_ping_loop = None # type: Optional[task.LoopingCall] + self._send_ping_loop: Optional[task.LoopingCall] = None # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 6a2c2655e4..8c0df627c8 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]): it. """ - constant = attr.ib() # type: V + constant: V = attr.ib() def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant @@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): commands. """ - synapse_handler = None # type: ReplicationCommandHandler - synapse_stream_name = None # type: str - synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol + synapse_handler: "ReplicationCommandHandler" + synapse_stream_name: str + synapse_outbound_redis_connection: txredisapi.RedisProtocol def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b03824925a..3716c41bea 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -85,9 +85,9 @@ class Stream: time it was called. """ - NAME = None # type: str # The name of the stream + NAME: str # The name of the stream # The type of the row. Used by the default impl of parse_row. 
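The protocol, redis, and command hunks in this stretch replace `NAME = None  # type: str` with a bare `NAME: str`: the attribute is declared for the checker without planting a bogus `None` default on the base class. A runnable sketch of the difference:

    class Command:
        # A bare annotation declares the attribute without assigning it, so
        # subclasses must supply the value themselves.
        NAME: str

        def to_line(self) -> str:
            return self.NAME

    class ServerCommand(Command):
        NAME = "SERVER"

    print(ServerCommand().to_line())
    print(hasattr(Command, "NAME"))  # False: nothing was actually assigned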
- ROW_TYPE = None # type: Any + ROW_TYPE: Any = None @classmethod def parse_row(cls, row: StreamRow): @@ -283,9 +283,7 @@ class PresenceStream(Stream): assert isinstance(presence_handler, PresenceHandler) - update_function = ( - presence_handler.get_all_presence_updates - ) # type: UpdateFunction + update_function: UpdateFunction = presence_handler.get_all_presence_updates else: # Query presence writer process update_function = make_http_update_function(hs, self.NAME) @@ -334,9 +332,9 @@ class TypingStream(Stream): if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler typing_writer_handler = hs.get_typing_writer_handler() - update_function = ( - typing_writer_handler.get_all_typing_updates - ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + update_function: Callable[ + [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] + ] = typing_writer_handler.get_all_typing_updates current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e7e87bac92..a030e9299e 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -65,7 +65,7 @@ class BaseEventsStreamRow: """ # Unique string that ids the type. Must be overridden in sub classes. - TypeId = None # type: str + TypeId: str @classmethod def from_data(cls, data): @@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -_EventRows = ( +_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( EventsStreamEventRow, EventsStreamCurrentStateRow, -) # type: Tuple[Type[BaseEventsStreamRow], ...] +) TypeToRow = {Row.TypeId: Row for Row in _EventRows} @@ -157,9 +157,9 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table - event_rows = await self._store.get_all_new_forward_event_rows( + event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count - ) # type: List[Tuple] + ) # we rely on get_all_new_forward_event_rows strictly honouring the limit, so # that we know it is safe to just take upper_limit = event_rows[-1][0]. @@ -172,7 +172,7 @@ class EventsStream(Stream): if len(event_rows) == target_row_count: limited = True - upper_limit = event_rows[-1][0] # type: int + upper_limit: int = event_rows[-1][0] else: limited = False upper_limit = current_token @@ -191,30 +191,30 @@ class EventsStream(Stream): # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. - ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit - ) # type: List[Tuple] + ) # we now need to turn the raw database rows returned into tuples suitable # for the replication protocol (basically, we add an identifier to # distinguish the row type). At the same time, we can limit the event_rows # to the max stream_id from state_rows. 
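The `update_function` hunks show why the conversion helps most on long signatures: a `Callable` type this wide no longer fits a trailing comment, but as an inline annotation it can wrap inside its own brackets. A self-contained async sketch of the shape:

    import asyncio
    from typing import Awaitable, Callable, List, Tuple

    async def get_all_updates(token: int, limit: int) -> Tuple[List[int], int]:
        return list(range(token, token + limit)), token + limit

    # The wide signature wraps inside the annotation's brackets.
    update_function: Callable[
        [int, int], Awaitable[Tuple[List[int], int]]
    ] = get_all_updates

    rows, next_token = asyncio.run(update_function(0, 3))
    print(rows, next_token)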
- event_updates = ( + event_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamEventRow.TypeId, rest)) for (stream_id, *rest) in event_rows if stream_id <= upper_limit - ) # type: Iterable[Tuple[int, Tuple]] + ) - state_updates = ( + state_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) for (stream_id, *rest) in state_rows - ) # type: Iterable[Tuple[int, Tuple]] + ) - ex_outliers_updates = ( + ex_outliers_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamEventRow.TypeId, rest)) for (stream_id, *rest) in ex_outliers_rows - ) # type: Iterable[Tuple[int, Tuple]] + ) # we need to return a sorted list, so merge them together. updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 096a85d363..c445af9bd9 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -51,9 +51,9 @@ class FederationStream(Stream): current_token = current_token_without_instance( federation_sender.get_current_token ) - update_function = ( - federation_sender.get_replication_rows - ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + update_function: Callable[ + [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] + ] = federation_sender.get_replication_rows elif hs.should_send_federation(): # federation sender: Query master process diff --git a/synapse/server.py b/synapse/server.py index 2c27d2a7e8..095dba9ad0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -247,15 +247,15 @@ class HomeServer(metaclass=abc.ABCMeta): # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._listening_services = [] # type: List[twisted.internet.tcp.Port] - self.start_time = None # type: Optional[int] + self._listening_services: List[twisted.internet.tcp.Port] = [] + self.start_time: Optional[int] = None self._instance_id = random_string(5) self._instance_name = config.worker.instance_name self.version_string = version_string - self.datastores = None # type: Optional[Databases] + self.datastores: Optional[Databases] = None self._module_web_resources: Dict[str, IResource] = {} self._module_web_resources_consumed = False diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e65f6f88fe..4e0f814035 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -34,7 +34,7 @@ class ConsentServerNotices: self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() - self._users_in_progress = set() # type: Set[str] + self._users_in_progress: Set[str] = set() self._current_consent_version = hs.config.user_consent_version self._server_notice_content = hs.config.user_consent_server_notice_content diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index e4b0bc5c72..073b0d754f 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -205,7 +205,7 @@ class ResourceLimitsServerNotices: # The user has yet to join the server notices room pass - referenced_events = [] # type: List[str] + referenced_events: List[str] = [] if pinned_state_event is not 
None: referenced_events = list(pinned_state_event.content.get("pinned", [])) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index c875b15b32..cdf0973d05 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._server_notices = ( + self._server_notices: Iterable[ + Union[ConsentServerNotices, ResourceLimitsServerNotices] + ] = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), - ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]] + ) async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1770f620e..6223daf522 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -309,9 +309,9 @@ class StateHandler: if old_state: # if we're given the state before the event, then we use that - state_ids_before_event = { + state_ids_before_event: StateMap[str] = { (s.type, s.state_key): s.event_id for s in old_state - } # type: StateMap[str] + } state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -513,23 +513,25 @@ class StateResolutionHandler: self.resolve_linearizer = Linearizer(name="state_resolve_lock") # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = ExpiringCache( + self._state_cache: ExpiringCache[ + FrozenSet[int], _StateCacheEntry + ] = ExpiringCache( cache_name="state_cache", clock=self.clock, max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, - ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry] + ) # # stuff for tracking time spent on state-res by room # # tracks the amount of work done on state res per room - self._state_res_metrics = defaultdict( + self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict( _StateResMetrics - ) # type: DefaultDict[str, _StateResMetrics] + ) self.clock.looping_call(self._report_metrics, 120 * 1000) @@ -700,9 +702,9 @@ class StateResolutionHandler: items = self._state_res_metrics.items() # log the N biggest rooms - biggest = heapq.nlargest( + biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest( n_to_log, items, key=lambda i: extract_key(i[1]) - ) # type: List[Tuple[str, _StateResMetrics]] + ) metrics_logger.debug( "%i biggest rooms for state-res by %s: %s", len(biggest), @@ -754,7 +756,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. 
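The `_server_notices` hunk above annotates a heterogeneous tuple with a bracket-wrapped `Iterable[Union[...]]` placed on the target. A hedged sketch with placeholder notice classes:

    from typing import Iterable, Union

    class ConsentNotices:
        kind = "consent"

    class ResourceLimitNotices:
        kind = "resource_limits"

    # The union annotation sits above the tuple literal, wrapped in
    # brackets, instead of in a comment after the closing parenthesis.
    notices: Iterable[
        Union[ConsentNotices, ResourceLimitNotices]
    ] = (
        ConsentNotices(),
        ResourceLimitNotices(),
    )

    print([n.kind for n in notices])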
prev_group = None - delta_ids = None # type: Optional[StateMap[str]] + delta_ids: Optional[StateMap[str]] = None for old_group, old_state in state_groups_ids.items(): n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 318e998813..267193cedf 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -159,7 +159,7 @@ def _seperate( """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} # type: MutableStateMap[Set[str]] + conflicted_state: MutableStateMap[Set[str]] = {} for state_set in state_set_iterator: for key, value in state_set.items(): diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 008644cd98..e66e6571c8 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -276,7 +276,7 @@ async def _get_auth_chain_difference( # event IDs if they appear in the `event_map`. This is the intersection of # the event's auth chain with the events in the `event_map` *plus* their # auth event IDs. - events_to_auth_chain = {} # type: Dict[str, Set[str]] + events_to_auth_chain: Dict[str, Set[str]] = {} for event in event_map.values(): chain = {event.event_id} events_to_auth_chain[event.event_id] = chain @@ -301,17 +301,17 @@ async def _get_auth_chain_difference( # ((type, state_key)->event_id) mappings; and (b) we have stripped out # unpersisted events and replaced them with the persisted events in # their auth chain. - state_sets_ids = [] # type: List[Set[str]] + state_sets_ids: List[Set[str]] = [] # For each state set, the unpersisted event IDs reachable (by their auth # chain) from the events in that set. - unpersisted_set_ids = [] # type: List[Set[str]] + unpersisted_set_ids: List[Set[str]] = [] for state_set in state_sets: - set_ids = set() # type: Set[str] + set_ids: Set[str] = set() state_sets_ids.append(set_ids) - unpersisted_ids = set() # type: Set[str] + unpersisted_ids: Set[str] = set() unpersisted_set_ids.append(unpersisted_ids) for event_id in state_set.values(): @@ -334,7 +334,7 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - difference_from_event_map = union - intersection # type: Collection[str] + difference_from_event_map: Collection[str] = union - intersection else: difference_from_event_map = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] @@ -458,7 +458,7 @@ async def _reverse_topological_power_sort( The sorted list """ - graph = {} # type: Dict[str, Set[str]] + graph: Dict[str, Set[str]] = {} for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff @@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event( """ room_id = event.room_id - tmp_event = event # type: Optional[EventBase] + tmp_event: Optional[EventBase] = event # We do an iterative search, replacing `event with the power level in its # auth events (if any) @@ -767,7 +767,7 @@ def lexicographical_topological_sort( # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} # type: Dict[str, Set[str]] + reverse_graph: Dict[str, Set[str]] = {} # Lists of nodes with zero out degree. 
Is actually a tuple of # `(key(node), node)` so that sorting does the right thing diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 20fceaa935..99b0aac2fb 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -32,9 +32,9 @@ class EventSources: } def __init__(self, hs): - self.sources = { + self.sources: Dict[str, Any] = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } # type: Dict[str, Any] + } self.store = hs.get_datastore() def get_current_token(self) -> StreamToken: diff --git a/synapse/types.py b/synapse/types.py index 64c442bd0f..fad23c8700 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): 'domain' : The domain part of the name """ - SIGIL = abc.abstractproperty() # type: str # type: ignore + SIGIL: str = abc.abstractproperty() # type: ignore localpart = attr.ib(type=str) domain = attr.ib(type=str) @@ -304,7 +304,7 @@ class GroupID(DomainSpecificString): @classmethod def from_string(cls: Type[DS], s: str) -> DS: - group_id = super().from_string(s) # type: DS # type: ignore + group_id: DS = super().from_string(s) # type: ignore if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) @@ -600,7 +600,7 @@ class StreamToken: groups_key = attr.ib(type=int) _SEPARATOR = "_" - START = None # type: StreamToken + START: "StreamToken" @classmethod async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": diff --git a/synapse/visibility.py b/synapse/visibility.py index 490fb26e81..1dc6b90275 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -90,7 +90,7 @@ async def filter_events_for_client( AccountDataTypes.IGNORED_USER_LIST, user_id ) - ignore_list = frozenset() # type: FrozenSet[str] + ignore_list: FrozenSet[str] = frozenset() if ignore_dict_content: ignored_users_dict = ignore_dict_content.get("ignored_users", {}) if isinstance(ignored_users_dict, dict): -- cgit 1.5.1 From bdfde6dca11a9468372b3c9b327ad3327cbdbe4a Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 15 Jul 2021 18:46:54 +0200 Subject: Use inline type hints in `http/federation/`, `storage/` and `util/` (#10381) --- changelog.d/10381.misc | 1 + synapse/http/federation/well_known_resolver.py | 13 ++++---- synapse/storage/background_updates.py | 16 ++++----- synapse/storage/database.py | 14 ++++---- synapse/storage/databases/main/appservice.py | 4 +-- synapse/storage/databases/main/end_to_end_keys.py | 2 +- synapse/storage/databases/main/event_federation.py | 26 +++++++-------- .../storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/events.py | 38 ++++++++++------------ .../storage/databases/main/events_bg_updates.py | 8 ++--- synapse/storage/databases/main/events_worker.py | 6 ++-- synapse/storage/databases/main/purge_events.py | 2 +- synapse/storage/databases/main/push_rule.py | 6 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/stream.py | 6 ++-- synapse/storage/databases/main/tags.py | 2 +- synapse/storage/databases/main/ui_auth.py | 4 +-- synapse/storage/persist_events.py | 16 ++++----- synapse/storage/prepare_database.py | 6 ++-- synapse/storage/state.py | 4 +-- synapse/storage/util/id_generators.py | 12 +++---- synapse/storage/util/sequence.py | 6 ++-- synapse/util/async_helpers.py | 8 ++--- synapse/util/batching_queue.py | 8 ++--- synapse/util/caches/__init__.py | 4 +-- synapse/util/caches/cached_call.py | 6 ++-- 
synapse/util/caches/deferred_cache.py | 12 +++---- synapse/util/caches/descriptors.py | 36 ++++++++++---------- synapse/util/caches/dictionary_cache.py | 6 ++-- synapse/util/caches/expiringcache.py | 4 +-- synapse/util/caches/lrucache.py | 8 ++--- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 6 ++-- synapse/util/caches/ttlcache.py | 6 ++-- synapse/util/iterutils.py | 2 +- synapse/util/macaroons.py | 2 +- synapse/util/metrics.py | 2 +- synapse/util/patch_inline_callbacks.py | 4 +-- 38 files changed, 150 insertions(+), 162 deletions(-) create mode 100644 changelog.d/10381.misc (limited to 'synapse/http') diff --git a/changelog.d/10381.misc b/changelog.d/10381.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10381.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. \ No newline at end of file diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 20d39a4ea6..43f2140429 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3 logger = logging.getLogger(__name__) -_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]] -_had_valid_well_known_cache = TTLCache( - "had-valid-well-known" -) # type: TTLCache[bytes, bool] +_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") +_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known") @attr.s(slots=True, frozen=True) @@ -130,9 +128,10 @@ class WellKnownResolver: # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = await self._fetch_well_known( - server_name - ) # type: Optional[bytes], float + result: Optional[bytes] + cache_period: float + + result, cache_period = await self._fetch_well_known(server_name) except _FetchWellKnownFailure as e: if prev_result and e.temporary: diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 142787fdfd..82b31d24f1 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -92,14 +92,12 @@ class BackgroundUpdater: self.db_pool = database # if a background update is currently running, its name. 
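The `well_known_resolver.py` hunk just below uses a less common PEP 526 form: annotation targets of a tuple unpacking cannot be annotated inline, so each name gets a bare annotation on its own line before the unpacking assignment. A minimal sketch, with `fetch_well_known` standing in for the real network fetch:

    from typing import Optional, Tuple

    def fetch_well_known() -> Tuple[Optional[bytes], float]:
        # Stand-in for the real network fetch.
        return b"matrix.example.com:8448", 3600.0

    # PEP 526 cannot annotate tuple-unpacking targets directly, so each
    # name is declared separately first.
    result: Optional[bytes]
    cache_period: float

    result, cache_period = fetch_well_known()
    print(result, cache_period)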
- self._current_background_update = None # type: Optional[str] - - self._background_update_performance = ( - {} - ) # type: Dict[str, BackgroundUpdatePerformance] - self._background_update_handlers = ( - {} - ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]] + self._current_background_update: Optional[str] = None + + self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} + self._background_update_handlers: Dict[ + str, Callable[[JsonDict, int], Awaitable[int]] + ] = {} self._all_done = False def start_doing_background_updates(self) -> None: @@ -411,7 +409,7 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner = create_index_psql # type: Optional[Callable[[Connection], None]] + runner: Optional[Callable[[Connection], None]] = create_index_psql elif psql_only: runner = None else: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 33c42cf95a..f80d822c12 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -670,8 +670,8 @@ class DatabasePool: Returns: The result of func """ - after_callbacks = [] # type: List[_CallbackListEntry] - exception_callbacks = [] # type: List[_CallbackListEntry] + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -1090,7 +1090,7 @@ class DatabasePool: return False # We didn't find any existing rows, so insert a new one - allvalues = {} # type: Dict[str, Any] + allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -1121,7 +1121,7 @@ class DatabasePool: values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting """ - allvalues = {} # type: Dict[str, Any] + allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) allvalues.update(insertion_values or {}) @@ -1257,7 +1257,7 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ - allnames = [] # type: List[str] + allnames: List[str] = [] allnames.extend(key_names) allnames.extend(value_names) @@ -1566,7 +1566,7 @@ class DatabasePool: """ keyvalues = keyvalues or {} - results = [] # type: List[Dict[str, Any]] + results: List[Dict[str, Any]] = [] if not iterable: return results @@ -1978,7 +1978,7 @@ class DatabasePool: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else "" - arg_list = [] # type: List[Any] + arg_list: List[Any] = [] if filters: where_clause += " AND ".join("%s LIKE ?" 
% (k,) for k in filters) arg_list += list(filters.values()) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 9f182c2a89..e2d1b758bd 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -48,9 +48,7 @@ def _make_exclusive_regex( ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_pattern = re.compile( - exclusive_user_regex - ) # type: Optional[Pattern] + exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex) else: # We handle this case specially otherwise the constructed regex # will always match diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 0e3dd4e9ca..78ae68ec68 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): txn.execute(sql, query_params) - result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] + result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {} for (user_id, device_id, display_name, key_json) in txn: if include_deleted_devices: deleted_devices.remove((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4e06938849..d39368c20e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) # Cache of event ID to list of auth event IDs and their depths. - self._event_auth_cache = LruCache( + self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache( 500000, "_event_auth_cache", size_callback=len - ) # type: LruCache[str, List[Tuple[str, int]]] + ) self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) @@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas initial_events = set(event_ids) # All the events that we've found that are reachable from the events. - seen_events = set() # type: Set[str] + seen_events: Set[str] = set() # A map from chain ID to max sequence number of the given events. - event_chains = {} # type: Dict[int, int] + event_chains: Dict[int, int] = {} sql = """ SELECT event_id, chain_id, sequence_number @@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ # A map from chain ID to max sequence number *reachable* from any event ID. - chains = {} # type: Dict[int, int] + chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. for batch in batch_iter(event_chains, 1000): @@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas initial_events = set(state_sets[0]).union(*state_sets[1:]) # Map from event_id -> (chain ID, seq no) - chain_info = {} # type: Dict[str, Tuple[int, int]] + chain_info: Dict[str, Tuple[int, int]] = {} # Map from chain ID -> seq no -> event Id - chain_to_event = {} # type: Dict[int, Dict[int, str]] + chain_to_event: Dict[int, Dict[int, str]] = {} # All the chains that we've found that are reachable from the state # sets. 
- seen_chains = set() # type: Set[int] + seen_chains: Set[int] = set() sql = """ SELECT event_id, chain_id, sequence_number @@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Corresponds to `state_sets`, except as a map from chain ID to max # sequence number reachable from the state set. - set_to_chain = [] # type: List[Dict[int, int]] + set_to_chain: List[Dict[int, int]] = [] for state_set in state_sets: - chains = {} # type: Dict[int, int] + chains: Dict[int, int] = {} set_to_chain.append(chains) for event_id in state_set: @@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Mapping from chain ID to the range of sequence numbers that should be # pulled from the database. - chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + chain_to_gap: Dict[int, Tuple[int, int]] = {} for chain_id in seen_chains: min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) @@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas } # The sorted list of events whose auth chains we should walk. - search = [] # type: List[Tuple[int, str]] + search: List[Tuple[int, str]] = [] # We need to get the depth of the initial events for sorting purposes. sql = """ @@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas search.sort() # Map from event to its auth events - event_to_auth_events = {} # type: Dict[str, Set[str]] + event_to_auth_events: Dict[str, Set[str]] = {} base_sql = """ SELECT a.event_id, auth_id, depth diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d1237c65cc..55caa6bbe7 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + summaries: Dict[Tuple[str, str], _EventPushSummary] = {} for row in txn: summaries[(row[0], row[1])] = _EventPushSummary( unread_count=row[2], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 08c580b0dc..ec8579b9ad 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -109,10 +109,8 @@ class PersistEventsStore: # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = ( - self.store._backfill_id_gen - ) # type: MultiWriterIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen # This should only exist on instances that are configured to write assert ( @@ -221,7 +219,7 @@ class PersistEventsStore: Returns: Filtered event ids """ - results = [] # type: List[str] + results: List[str] = [] def _get_events_which_are_prevs_txn(txn, batch): sql = """ @@ -508,7 +506,7 @@ class PersistEventsStore: """ # Map from event ID to chain ID/sequence number. 
- chain_map = {} # type: Dict[str, Tuple[int, int]] + chain_map: Dict[str, Tuple[int, int]] = {} # Set of event IDs to calculate chain ID/seq numbers for. events_to_calc_chain_id_for = set(event_to_room_id) @@ -817,8 +815,8 @@ class PersistEventsStore: # new chain if the sequence number has already been allocated. # - existing_chains = set() # type: Set[int] - tree = [] # type: List[Tuple[str, Optional[str]]] + existing_chains: Set[int] = set() + tree: List[Tuple[str, Optional[str]]] = [] # We need to do this in a topologically sorted order as we want to # generate chain IDs/sequence numbers of an event's auth events before @@ -848,7 +846,7 @@ class PersistEventsStore: ) txn.execute(sql % (clause,), args) - chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn} # Allocate the new events chain ID/sequence numbers. # @@ -858,8 +856,8 @@ class PersistEventsStore: # number of new chain IDs in one call, replacing all temporary # objects with real allocated chain IDs. - unallocated_chain_ids = set() # type: Set[object] - new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + unallocated_chain_ids: Set[object] = set() + new_chain_tuples: Dict[str, Tuple[Any, int]] = {} for event_id, auth_event_id in tree: # If we reference an auth_event_id we fetch the allocated chain ID, # either from the existing `chain_map` or the newly generated @@ -870,7 +868,7 @@ class PersistEventsStore: if not existing_chain_id: existing_chain_id = chain_map[auth_event_id] - new_chain_tuple = None # type: Optional[Tuple[Any, int]] + new_chain_tuple: Optional[Tuple[Any, int]] = None if existing_chain_id: # We found a chain ID/sequence number candidate, check its # not already taken. @@ -897,9 +895,9 @@ class PersistEventsStore: ) # Map from potentially temporary chain ID to real chain ID - chain_id_to_allocated_map = dict( + chain_id_to_allocated_map: Dict[Any, int] = dict( zip(unallocated_chain_ids, newly_allocated_chain_ids) - ) # type: Dict[Any, int] + ) chain_id_to_allocated_map.update((c, c) for c in existing_chains) return { @@ -1175,9 +1173,9 @@ class PersistEventsStore: Returns: list[(EventBase, EventContext)]: filtered list """ - new_events_and_contexts = ( - OrderedDict() - ) # type: OrderedDict[str, Tuple[EventBase, EventContext]] + new_events_and_contexts: OrderedDict[ + str, Tuple[EventBase, EventContext] + ] = OrderedDict() for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: @@ -1205,7 +1203,7 @@ class PersistEventsStore: we are persisting backfilled (bool): True if the events were backfilled """ - depth_updates = {} # type: Dict[str, int] + depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) @@ -1885,7 +1883,7 @@ class PersistEventsStore: ), ) - room_to_event_ids = {} # type: Dict[str, List[str]] + room_to_event_ids: Dict[str, List[str]] = {} for e, _ in events_and_contexts: room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) @@ -2012,7 +2010,7 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. 
""" - events_by_room = {} # type: Dict[str, List[EventBase]] + events_by_room: Dict[str, List[EventBase]] = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 29f33bac55..6fcb2b8353 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): event_to_types = {row[0]: (row[1], row[2]) for row in rows} # Calculate the new last position we've processed up to. - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - new_last_room_id = rows[-1][5] if rows else "" # type: str + new_last_depth: int = rows[-1][3] if rows else last_depth + new_last_stream: int = rows[-1][4] if rows else last_stream + new_last_room_id: str = rows[-1][5] if rows else "" # Map from room_id to last depth/stream_ordering processed for the room, # excluding the last room (which we're likely still processing). We also @@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): retcols=("event_id", "auth_id"), ) - event_to_auth_chain = {} # type: Dict[str, List[str]] + event_to_auth_chain: Dict[str, List[str]] = {} for row in auth_events: event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 403a5ddaba..3c86adab56 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore): # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, - ) # type: List[Tuple] + ) # if we've got fewer rows than the limit, we're good if len(rows) < target_row_count: @@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore): """ mapping = {} - txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] + txn_id_to_event: Dict[Tuple[str, int, str], str] = {} for event in events: token_id = getattr(event.internal_metadata, "token_id", None) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index eb4841830d..664c65dac5 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): logger.info("[purge] looking for events to delete") should_delete_expr = "state_key IS NULL" - should_delete_params = () # type: Tuple[Any, ...] + should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index db52176337..a7fb8cd848 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -79,9 +79,9 @@ class PushRulesWorkerStore( super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) # type: Union[StreamIdGenerator, SlavedIdTracker] + self._push_rules_stream_id_gen: Union[ + StreamIdGenerator, SlavedIdTracker + ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e31c5864ac..6ad1a0cf7f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] # type: List[Union[str, int]] + values: List[Union[str, int]] = [v for _, v in items] # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where # clause and values before we handle that. This seems to be only used in the "set password" handler. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 7581c7d3ff..959f13de47 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and # then filtering the results. 
if from_token.topological is not None: - from_bound = ( - from_token.as_historical_tuple() - ) # type: Tuple[Optional[int], int] + from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() elif direction == "b": from_bound = ( None, @@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): from_token.stream, ) - to_bound = None # type: Optional[Tuple[Optional[int], int]] + to_bound: Optional[Tuple[Optional[int], int]] = None if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 1d62c6140f..f93ff0a545 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore): "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]] + tags_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags[row["tag"]] = db_to_json(row["content"]) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 22c05cdde7..38bfdf5dad 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ): # Get the current value. - result = self.db_pool.simple_select_one_txn( + result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) # type: Dict[str, Any] # type: ignore + ) # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 051095fea9..a39877f0d5 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -307,7 +307,7 @@ class EventsPersistenceStorage: matched the transaction ID; the existing event is returned in such a case. """ - partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] + partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) @@ -384,7 +384,7 @@ class EventsPersistenceStorage: A dictionary of event ID to event ID we didn't persist as we already had another event persisted with the same TXN ID. """ - replaced_events = {} # type: Dict[str, str] + replaced_events: Dict[str, str] = {} if not events_and_contexts: return replaced_events @@ -440,16 +440,14 @@ class EventsPersistenceStorage: # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their # device lists as stale. - potentially_left_users = set() # type: Set[str] + potentially_left_users: Set[str] = set() if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. # We do this by working out what the new extremities are and then # calculating the state from that.
- events_by_room = ( - {} - ) # type: Dict[str, List[Tuple[EventBase, EventContext]]] + events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, context in chunk: events_by_room.setdefault(event.room_id, []).append( (event, context) @@ -622,9 +620,9 @@ class EventsPersistenceStorage: ) # Remove any events which are prev_events of any existing events. - existing_prevs = await self.persist_events_store._get_events_which_are_prevs( - result - ) # type: Collection[str] + existing_prevs: Collection[ + str + ] = await self.persist_events_store._get_events_which_are_prevs(result) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 683e5e3b90..82a7686df0 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -256,7 +256,7 @@ def _setup_new_database( for database in databases ) - directory_entries = [] # type: List[_DirectoryListing] + directory_entries: List[_DirectoryListing] = [] for directory in directories: directory_entries.extend( _DirectoryListing(file_name, os.path.join(directory, file_name)) @@ -424,10 +424,10 @@ def _upgrade_existing_database( directories.append(os.path.join(schema_path, database, "delta", str(v))) # Used to check if we have any duplicate file names - file_name_counter = Counter() # type: CounterType[str] + file_name_counter: CounterType[str] = Counter() # Now find which directories have anything of interest. - directory_entries = [] # type: List[_DirectoryListing] + directory_entries: List[_DirectoryListing] = [] for directory in directories: logger.debug("Looking for schema deltas in %s", directory) try: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c9dce726cb..f8fbba9d38 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -91,7 +91,7 @@ class StateFilter: Returns: The new state filter. """ - type_dict = {} # type: Dict[str, Optional[Set[str]]] + type_dict: Dict[str, Optional[Set[str]]] = {} for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -194,7 +194,7 @@ class StateFilter: """ where_clause = "" - where_args = [] # type: List[str] + where_args: List[str] = [] if self.is_full(): return where_clause, where_args diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f1e62f9e85..c768fdea56 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator: # insertion ordering will ensure its in the correct ordering. # # The key and values are the same, but we never look at the values. - self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] + self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def get_next(self): """ @@ -236,15 +236,15 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = {} # type: Dict[str, int] + self._current_positions: Dict[str, int] = {} # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). 
- self._unfinished_ids = set() # type: Set[int] + self._unfinished_ids: Set[int] = set() # Set of local IDs that we've processed that are larger than the current # position, due to there being smaller unpersisted IDs. - self._finished_ids = set() # type: Set[int] + self._finished_ids: Set[int] = set() # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances @@ -265,7 +265,7 @@ class MultiWriterIdGenerator: self._persisted_upto_position = ( min(self._current_positions.values()) if self._current_positions else 1 ) - self._known_persisted_positions = [] # type: List[int] + self._known_persisted_positions: List[int] = [] self._sequence_gen = PostgresSequenceGenerator(sequence_name) @@ -465,7 +465,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None # type: Optional[int] + new_cur: Optional[int] = None if self._unfinished_ids: # If there are unfinished IDs then the new position will be the diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 30b6b8e0ca..bb33e04fb1 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator): get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. - self._callback = get_first_callback # type: Optional[GetFirstCallbackType] + self._callback: Optional[GetFirstCallbackType] = get_first_callback # The current max value, or None if we haven't looked in the DB yet. - self._current_max_id = None # type: Optional[int] + self._current_max_id: Optional[int] = None self._lock = threading.Lock() def get_next_id_txn(self, txn: Cursor) -> int: @@ -274,7 +274,7 @@ def build_sequence_generator( `check_consistency` details. """ if isinstance(database_engine, PostgresEngine): - seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator + seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name) else: seq = LocalSequenceGenerator(get_first_callback) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 061102c3c8..014db1355b 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -257,7 +257,7 @@ class Linearizer: max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) # type: Union[str, int] + self.name: Union[str, int] = id(self) else: self.name = name @@ -269,7 +269,7 @@ class Linearizer: self.max_count = max_count # key_to_defer is a map from the key to a _LinearizerEntry.
- self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] + self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {} def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting""" @@ -409,10 +409,10 @@ class ReadWriteLock: def __init__(self): # Latest readers queued - self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] + self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} # Latest writer queued - self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] + self.key_to_current_writer: Dict[str, defer.Deferred] = {} async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 8fd5bfb69b..274cea7eb7 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]): self._clock = clock # The set of keys currently being processed. - self._processing_keys = set() # type: Set[Hashable] + self._processing_keys: Set[Hashable] = set() # The currently pending batch of values by key, with a Deferred to call # with the result of the corresponding `_process_batch_callback` call. - self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] + self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {} # The function to call with batches of values. self._process_batch_callback = process_batch_callback @@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]): number_of_keys.labels(self._name).set_function(lambda: len(self._next_values)) - self._number_in_flight_metric = number_in_flight.labels( - self._name - ) # type: Gauge + self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name) async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index ca36f07c20..9012034b7a 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -29,8 +29,8 @@ logger = logging.getLogger(__name__) TRACK_MEMORY_USAGE = False -caches_by_name = {} # type: Dict[str, Sized] -collectors_by_name = {} # type: Dict[str, CacheMetric] +caches_by_name: Dict[str, Sized] = {} +collectors_by_name: Dict[str, "CacheMetric"] = {} cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index a301c9e89b..891bee0b33 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -63,9 +63,9 @@ class CachedCall(Generic[TV]): f: The underlying function. 
Only one call to this function will be alive at once (per instance of CachedCall) """ - self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] - self._deferred = None # type: Optional[Deferred] - self._result = None # type: Union[None, Failure, TV] + self._callable: Optional[Callable[[], Awaitable[TV]]] = f + self._deferred: Optional[Deferred] = None + self._result: Union[None, Failure, TV] = None async def get(self) -> TV: """Kick off the call if necessary, and return the result""" diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1044139119..8c6fafc677 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]): cache_type = TreeCache if tree else dict # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] + self._pending_deferred_cache: Union[ + TreeCache, "MutableMapping[KT, CacheEntry]" + ] = cache_type() def metrics_cb(): cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than # a Deferred. - self.cache = LruCache( + self.cache: LruCache[KT, VT] = LruCache( max_size=max_entries, cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, - ) # type: LruCache[KT, VT] + ) - self.thread = None # type: Optional[threading.Thread] + self.thread: Optional[threading.Thread] = None @property def max_entries(self): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d77e8edeea..1e8e6b1d01 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any]) class _CachedFunction(Generic[F]): - invalidate = None # type: Any - invalidate_all = None # type: Any - prefill = None # type: Any - cache = None # type: Any - num_args = None # type: Any + invalidate: Any = None + invalidate_all: Any = None + prefill: Any = None + cache: Any = None + num_args: Any = None - __name__ = None # type: str + __name__: str # Note: This function signature is actually fiddled with by the synapse mypy # plugin to a) make it a bound method, and b) remove any `cache_context` arg. 
- __call__ = None # type: F + __call__: F class _CacheDescriptorBase: @@ -115,8 +115,8 @@ class _CacheDescriptorBase: class _LruCachedFunction(Generic[F]): - cache = None # type: LruCache[CacheKey, Any] - __call__ = None # type: F + cache: LruCache[CacheKey, Any] + __call__: F def lru_cache( @@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase): self.max_entries = max_entries def __get__(self, obj, owner): - cache = LruCache( + cache: LruCache[CacheKey, Any] = LruCache( cache_name=self.orig.__name__, max_size=self.max_entries, - ) # type: LruCache[CacheKey, Any] + ) get_cache_key = self.cache_key_builder sentinel = LruCacheDescriptor._Sentinel.sentinel @@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = DeferredCache( + cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, - ) # type: DeferredCache[CacheKey, Any] + ) get_cache_key = self.cache_key_builder @@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache # type: DeferredCache[CacheKey, Any] + cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) @@ -472,15 +472,15 @@ class _CacheContext: Cache = Union[DeferredCache, LruCache] - _cache_context_objects = ( - WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] + _cache_context_objects: """WeakValueDictionary[ + Tuple["_CacheContext.Cache", CacheKey], "_CacheContext" + ]""" = WeakValueDictionary() def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key - def invalidate(self): # type: () -> None + def invalidate(self) -> None: """Invalidates the cache entry referred to by the context.""" self._cache.invalidate(self._cache_key) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 56d94d96ce..3f852edd7f 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]): """ def __init__(self, name: str, max_entries: int = 1000): - self.cache = LruCache( + self.cache: LruCache[KT, DictionaryEntry] = LruCache( max_size=max_entries, cache_name=name, size_callback=len - ) # type: LruCache[KT, DictionaryEntry] + ) self.name = name self.sequence = 0 - self.thread = None # type: Optional[threading.Thread] + self.thread: Optional[threading.Thread] = None def check_thread(self) -> None: expected_thread = self.thread diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index ac47a31cd7..bde16b8577 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -27,7 +27,7 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() # type: Any +SENTINEL: Any = object() T = TypeVar("T") @@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]): self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] + self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict() self.iterable = iterable diff --git a/synapse/util/caches/lrucache.py 
b/synapse/util/caches/lrucache.py index 4b9d0433ff..efeba0cb96 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -226,7 +226,7 @@ class _Node: # footprint down. Storing `None` is free as it's a singleton, while empty # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive # thing and used sets). - self.callbacks = None # type: Optional[List[Callable[[], None]]] + self.callbacks: Optional[List[Callable[[], None]]] = None self.add_callbacks(callbacks) @@ -362,15 +362,15 @@ class LruCache(Generic[KT, VT]): # register_cache might call our "set_cache_factor" callback; there's nothing to # do yet when we get resized. - self._on_resize = None # type: Optional[Callable[[],None]] + self._on_resize: Optional[Callable[[], None]] = None if cache_name is not None: - metrics = register_cache( + metrics: Optional[CacheMetric] = register_cache( "lru_cache", cache_name, self, collect_callback=metrics_collection_callback, - ) # type: Optional[CacheMetric] + ) else: metrics = None diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 34c662c4db..ed7204336f 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]): # This is poorly-named: it includes both complete and incomplete results. # We keep complete results rather than switching to absolute values because # that makes it easier to cache Failure results. - self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred] + self.pending_result_cache: Dict[KV, ObservableDeferred] = {} self.clock = clock self.timeout_sec = timeout_ms / 1000.0 diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e81e468899..3a41a8baa6 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -45,10 +45,10 @@ class StreamChangeCache: ): self._original_max_size = max_size self._max_size = math.floor(max_size) - self._entity_to_key = {} # type: Dict[EntityType, int] + self._entity_to_key: Dict[EntityType, int] = {} # map from stream id to the set of entities which changed at that stream id. - self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] + self._cache: SortedDict[int, Set[EntityType]] = SortedDict() # the earliest stream_pos for which we can reliably answer
In other words, one less than the earliest @@ -155,7 +155,7 @@ class StreamChangeCache: if stream_pos < self._earliest_known_stream_pos: return None - changed_entities = [] # type: List[EntityType] + changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index c276107d56..46afe3f934 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -23,7 +23,7 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() # type: Any +SENTINEL: Any = object() T = TypeVar("T") KT = TypeVar("KT") @@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data = {} # type: Dict[KT, _CacheEntry] + self._data: Dict[KT, _CacheEntry] = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() # type: SortedList[_CacheEntry] + self._expiry_list: SortedList[_CacheEntry] = SortedList() self._timer = timer diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 886afa9d19..8ac3eab2f5 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -68,7 +68,7 @@ def sorted_topologically( # This is implemented by Kahn's algorithm. degree_map = {node: 0 for node in nodes} - reverse_graph = {} # type: Dict[T, Set[T]] + reverse_graph: Dict[T, Set[T]] = {} for node, edges in graph.items(): if node not in degree_map: diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index f6ebfd7e7d..d1f76e3dc5 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: caveat in the macaroon, or if the caveat was not found in the macaroon. 
""" prefix = key + " = " - result = None # type: Optional[str] + result: Optional[str] = None for caveat in macaroon.caveats: if not caveat.caveat_id.startswith(prefix): continue diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 45353d41c5..1b82dca81b 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -124,7 +124,7 @@ class Measure: assert isinstance(curr_context, LoggingContext) parent_context = curr_context self._logging_context = LoggingContext(str(curr_context), parent_context) - self.start = None # type: Optional[int] + self.start: Optional[int] = None def __enter__(self) -> "Measure": if self.start is not None: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index eed0291cae..99f01e325c 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -41,7 +41,7 @@ def do_patch(): @functools.wraps(f) def wrapped(*args, **kwargs): start_context = current_context() - changes = [] # type: List[str] + changes: List[str] = [] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: @@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]): gen = f(*args, **kwargs) last_yield_line_no = gen.gi_frame.f_lineno - result = None # type: Any + result: Any = None while True: expected_context = current_context() -- cgit 1.5.1 From 95e47b2e782b5e7afa5fd2afd1d0ea7745eaac36 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 19 Jul 2021 16:28:05 +0200 Subject: [pyupgrade] `synapse/` (#10348) This PR is tantamount to running ``` pyupgrade --py36-plus --keep-percent-format `find synapse/ -type f -name "*.py"` ``` Part of #9744 --- changelog.d/10348.misc | 1 + synapse/app/generic_worker.py | 6 ++-- synapse/app/homeserver.py | 6 ++-- synapse/config/appservice.py | 2 +- synapse/config/tls.py | 6 ++-- synapse/handlers/cas.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 4 +-- synapse/handlers/oidc.py | 38 ++++++++++++++------------ synapse/handlers/register.py | 15 ++++------ synapse/handlers/saml.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/proxyagent.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/opentracing.py | 2 +- synapse/metrics/_exposition.py | 26 ++++++++---------- synapse/metrics/background_process_metrics.py | 3 +- synapse/rest/client/v1/login.py | 25 ++++++----------- synapse/rest/media/v1/__init__.py | 4 +-- synapse/storage/database.py | 2 +- synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/group_server.py | 6 +++- synapse/storage/databases/main/roommember.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/types.py | 4 +-- synapse/util/caches/lrucache.py | 3 +- synapse/util/caches/treecache.py | 3 +- synapse/util/daemonize.py | 8 +++--- synapse/visibility.py | 4 +-- 29 files changed, 86 insertions(+), 102 deletions(-) create mode 100644 changelog.d/10348.misc (limited to 'synapse/http') diff --git a/changelog.d/10348.misc b/changelog.d/10348.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10348.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. 
\ No newline at end of file diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b43d858f59..c3d4992518 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7af56ac136..920b34d97b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer): elif listener.type == "metrics": if not self.config.enable_metrics: logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) + "Metrics listener configured, but " + "enable_metrics is not True!" ) else: _base.listen_metrics(listener.bind_addresses, listener.port) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index a39d457c56..1ebea88db2 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -64,7 +64,7 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, "r") as f: + with open(config_file) as f: appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fed05ac7be..5679f05e42 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -66,10 +66,8 @@ class TlsConfig(Config): if self.federation_client_minimum_tls_version == "1.3": if getattr(SSL, "OP_NO_TLSv1_3", None) is None: raise ConfigError( - ( - "federation_client_minimum_tls_version cannot be 1.3, " - "your OpenSSL does not support it" - ) + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" ) # Whitelist of domains to not verify certificates for diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index b681d208bc..0325f86e20 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -40,7 +40,7 @@ class CasError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5c4463583e..cf389be3e4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -735,7 +735,7 @@ class FederationHandler(BaseHandler): # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - (await self.store.get_events(missing_desired_events, allow_rejected=True)) + await self.store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. 
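The hunks in this commit are the standard `pyupgrade --py36-plus` rewrites: `str.format()` calls become f-strings (`%`-formatting is preserved because of `--keep-percent-format`), `IOError` is normalised to `OSError` (they have been aliases since Python 3.3), the redundant `"r"` mode is dropped from `open()`, superfluous parentheses around sole generator-expression arguments are removed, and re-yielding loops collapse to `yield from`. A small illustrative sketch of the same transformations (not Synapse code):

```
# str.format() becomes an f-string:
error, description = "server_error", "code exchange failed"
message = f"{error}: {description}"   # was "{}: {}".format(error, description)

# IOError is an alias of OSError on Python 3, so the older name is rewritten,
# and the default "r" mode is dropped from open():
try:
    with open("example.conf") as f:   # was open("example.conf", "r")
        config = f.read()
except OSError:                       # was: except IOError
    config = ""

# A loop that only re-yields becomes `yield from`:
def iterate_nested(values):
    for inner in values:
        yield from inner              # was: for v in inner: yield v
```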
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 33d16fbf9c..0961dec5ab 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler): ) url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") + url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" content = { "mxid": mxid, @@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler): return data["mxid"] except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - except IOError as e: + except OSError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) return None diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index a330c48fa7..eca8f16040 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -72,26 +72,26 @@ _SESSION_COOKIES = [ (b"oidc_session_no_samesite", b"HttpOnly"), ] + #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. -Token = TypedDict( - "Token", - { - "access_token": str, - "token_type": str, - "id_token": Optional[str], - "refresh_token": Optional[str], - "expires_in": int, - "scope": Optional[str], - }, -) +class Token(TypedDict): + access_token: str + token_type: str + id_token: Optional[str] + refresh_token: Optional[str] + expires_in: int + scope: Optional[str] + #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but #: there is no real point of doing this in our case. JWK = Dict[str, str] + #: A JWK Set, as per RFC7517 sec 5. -JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class JWKS(TypedDict): + keys: List[JWK] class OidcHandler: @@ -255,7 +255,7 @@ class OidcError(Exception): def __str__(self): if self.error_description: - return "{}: {}".format(self.error, self.error_description) + return f"{self.error}: {self.error_description}" return self.error @@ -639,7 +639,7 @@ class OidcProvider: ) logger.warning(description) # Body was still valid JSON. Might be useful to log it for debugging. 
- logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + logger.warning("Code exchange response: %r", resp) raise OidcError("server_error", description) return resp @@ -1217,10 +1217,12 @@ class OidcSessionData: ui_auth_session_id = attr.ib(type=str) -UserAttributeDict = TypedDict( - "UserAttributeDict", - {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, -) +class UserAttributeDict(TypedDict): + localpart: Optional[str] + display_name: Optional[str] + emails: List[str] + + C = TypeVar("C") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 056fe5e89f..8cf614136e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -55,15 +55,12 @@ login_counter = Counter( ["guest", "auth_provider"], ) -LoginDict = TypedDict( - "LoginDict", - { - "device_id": str, - "access_token": str, - "valid_until_ms": Optional[int], - "refresh_token": Optional[str], - }, -) + +class LoginDict(TypedDict): + device_id: str + access_token: str + valid_until_ms: Optional[int] + refresh_token: Optional[str] class RegistrationHandler(BaseHandler): diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 72f54c9403..e6e71e9729 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -372,7 +372,7 @@ class SamlHandler(BaseHandler): DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) + "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),) ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 722c4ae670..150a4f291e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1601,7 +1601,7 @@ class SyncHandler: logger.debug( "Membership changes in %s: [%s]", room_id, - ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)), + ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events), ) non_joins = [e for e in events if e.membership != Membership.JOIN] diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7a6a1717de..f7193e60bd 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase): """ uri = uri.strip() if not _VALID_URI.match(uri): - raise ValueError("Invalid URI {!r}".format(uri)) + raise ValueError(f"Invalid URI {uri!r}") parsed_uri = URI.fromBytes(uri) pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) diff --git a/synapse/http/site.py b/synapse/http/site.py index 3b0a38124e..190084e8aa 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -384,7 +384,7 @@ class SynapseRequest(Request): # authenticated (e.g. and admin is puppetting a user) then we log both. 
requester, authenticated_entity = self.get_authenticated_entity() if authenticated_entity: - requester = "{}.{}".format(authenticated_entity, requester) + requester = f"{authenticated_entity}.{requester}" self.site.access_logger.log( log_level, diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 185844f188..ecd51f1b4a 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"): config = JaegerConfig( config=hs.config.jaeger_config, - service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + service_name=f"{hs.config.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), ) diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 7e49d0d02c..bb9bcb5592 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -34,7 +34,7 @@ from twisted.web.resource import Resource from synapse.util import caches -CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" INF = float("inf") @@ -55,8 +55,8 @@ def floatToGoString(d): # Go switches to exponents sooner than Python. # We only need to care about positive values for le/quantile. if d > 0 and dot > 6: - mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") - return "{0}e+0{1}".format(mantissa, dot - 1) + mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") + return f"{mantissa}e+0{dot - 1}" return s @@ -65,7 +65,7 @@ def sample_line(line, name): labelstr = "{{{0}}}".format( ",".join( [ - '{0}="{1}"'.format( + '{}="{}"'.format( k, v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), ) @@ -78,10 +78,8 @@ def sample_line(line, name): timestamp = "" if line.timestamp is not None: # Convert to milliseconds. - timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) - return "{0}{1} {2}{3}\n".format( - name, labelstr, floatToGoString(line.value), timestamp - ) + timestamp = f" {int(float(line.timestamp) * 1000):d}" + return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) def generate_latest(registry, emit_help=False): @@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False): # Output in the old format for compatibility. if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mname, mtype)) + output.append(f"# TYPE {mname} {mtype}\n") om_samples: Dict[str, List[str]] = {} for s in metric.samples: @@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False): for suffix, lines in sorted(om_samples.items()): if emit_help: output.append( - "# HELP {0}{1} {2}\n".format( + "# HELP {}{} {}\n".format( metric.name, suffix, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.append(f"# TYPE {metric.name}{suffix} gauge\n") output.extend(lines) # Get rid of the weird colon things while we're at it @@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False): # Also output in the new format, if it's different. 
if emit_help: output.append( - "# HELP {0} {1}\n".format( + "# HELP {} {}\n".format( mnewname, metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), ) ) - output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + output.append(f"# TYPE {mnewname} {mtype}\n") for s in metric.samples: # Get rid of the OpenMetrics specific samples (we should already have diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 4455fa71a8..3a14260752 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -137,8 +137,7 @@ class _Collector: _background_process_db_txn_duration, _background_process_db_sched_duration, ): - for r in m.collect(): - yield r + yield from m.collect() REGISTRY.register(_Collector()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 99d02cb355..11567bf32c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,19 +44,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -LoginResponse = TypedDict( - "LoginResponse", - { - "user_id": str, - "access_token": str, - "home_server": str, - "expires_in_ms": Optional[int], - "refresh_token": Optional[str], - "device_id": str, - "well_known": Optional[Dict[str, Any]], - }, - total=False, -) +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] class LoginRestServlet(RestServlet): @@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend( - ({"type": t} for t in self.auth_handler.get_supported_login_types()) - ) + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py index d20186bbd0..3dd16d4bb5 100644 --- a/synapse/rest/media/v1/__init__.py +++ b/synapse/rest/media/v1/__init__.py @@ -17,7 +17,7 @@ import PIL.Image # check for JPEG support. try: PIL.Image._getdecoder("rgb", "jpeg", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder jpeg not available"): raise Exception( "FATAL: jpeg codec not supported. Install pillow correctly! " @@ -32,7 +32,7 @@ except Exception: # check for PNG support. try: PIL.Image._getdecoder("rgb", "zip", None) -except IOError as e: +except OSError as e: if str(e).startswith("decoder zip not available"): raise Exception( "FATAL: zip codec not supported. Install pillow correctly! " diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f80d822c12..ccf9ac51ef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -907,7 +907,7 @@ class DatabasePool: # The sort is to ensure that we don't rely on dictionary iteration # order. 
keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) ) for k in keys: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 50e7ddd735..c55508867d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_messages_for_device", delete_messages_for_device_txn ) - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) + log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 66ad363bfb..e70d3649ff 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -27,8 +27,11 @@ from synapse.util import json_encoder _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" + # A room in a group. -_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool}) +class _RoomInGroup(TypedDict): + room_id: str + is_public: bool class GroupServerWorkerStore(SQLBaseStore): @@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore): "is_public": False # Whether this is a public room or not } """ + # TODO: Pagination def _get_rooms_in_group_txn(txn): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4d82c4c26d..68f1b40ea6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) - users_in_room.update((row for row in event_to_memberships.values() if row)) + users_in_room.update(row for row in event_to_memberships.values() if row) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 82a7686df0..61392b9639 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]: def executescript(txn: Cursor, schema_path: str) -> None: - with open(schema_path, "r") as f: + with open(schema_path) as f: execute_statements_from_stream(txn, f) diff --git a/synapse/types.py b/synapse/types.py index fad23c8700..429bb013d2 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -577,10 +577,10 @@ class RoomStreamToken: entries = [] for name, pos in self.instance_map.items(): instance_id = await store.get_id_for_instance(name) - entries.append("{}.{}".format(instance_id, pos)) + entries.append(f"{instance_id}.{pos}") encoded_map = "~".join(entries) - return "m{}~{}".format(self.stream, encoded_map) + return f"m{self.stream}~{encoded_map}" else: return "s%d" % (self.stream,) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index efeba0cb96..5c65d187b6 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -90,8 +90,7 @@ def enumerate_leaves(node, depth): yield node else: for n in node.values(): - for m in 
enumerate_leaves(n, depth - 1): - yield m + yield from enumerate_leaves(n, depth - 1) P = TypeVar("P") diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index a6df81ebff..4138931e7b 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d): """ if isinstance(d, TreeCacheNode): for value_d in d.values(): - for value in iterate_tree_cache_entry(value_d): - yield value + yield from iterate_tree_cache_entry(value_d) else: yield d diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index 31b24dd188..d8532411c2 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # If pidfile already exists, we should read pid from there; to overwrite it, if # locking will fail, because locking attempt somehow purges the file contents. if os.path.isfile(pid_file): - with open(pid_file, "r") as pid_fh: + with open(pid_file) as pid_fh: old_pid = pid_fh.read() # Create a lockfile so that only one instance of this daemon is running at any time. try: lock_fh = open(pid_file, "w") - except IOError: + except OSError: print("Unable to create the pidfile.") sys.exit(1) @@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # Try to get an exclusive lock on the file. This will fail if another process # has the file locked. fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) - except IOError: + except OSError: print("Unable to lock on the pidfile.") # We need to overwrite the pidfile if we got here. # @@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - try: lock_fh.write("%s" % (os.getpid())) lock_fh.flush() - except IOError: + except OSError: logger.error("Unable to write pid to the pidfile.") print("Unable to write pid to the pidfile.") sys.exit(1) diff --git a/synapse/visibility.py b/synapse/visibility.py index 1dc6b90275..17532059e9 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -96,7 +96,7 @@ async def filter_events_for_client( if isinstance(ignored_users_dict, dict): ignore_list = frozenset(ignored_users_dict.keys()) - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -353,7 +353,7 @@ async def filter_events_for_server( ) if not check_history_visibility_only: - erased_senders = await storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. -- cgit 1.5.1
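A closing note on the least mechanical rewrite in the `pyupgrade` commit: the functional `TypedDict("Name", {...})` declarations in `oidc.py`, `register.py`, `login.py` and `group_server.py` become class-based definitions, which read and diff more cleanly and allow per-field comments. The sketch below is reconstructed from the `login.py` hunk above; note that Synapse itself imports `TypedDict` from `typing_extensions` for older Pythons, while plain `typing` works on 3.8+:

```
from typing import Any, Dict, Optional, TypedDict  # typing_extensions on <3.8

# Was: LoginResponse = TypedDict("LoginResponse", {"user_id": str, ...}, total=False)
class LoginResponse(TypedDict, total=False):
    user_id: str
    access_token: str
    home_server: str
    expires_in_ms: Optional[int]
    refresh_token: Optional[str]
    device_id: str
    well_known: Optional[Dict[str, Any]]
```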