From cdd308845ba22fef22a39ed5bf904b438e48b491 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Wed, 13 Oct 2021 12:21:52 +0100 Subject: Port the Password Auth Providers module interface to the new generic interface (#10548) Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier --- synapse/config/password_auth_providers.py | 53 ++++++++++++++----------------- 1 file changed, 23 insertions(+), 30 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83994df798..f980102b45 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): + """Parses the old password auth providers config. The config format looks like this: + + password_providers: + # Example config for an LDAP auth provider + - module: "ldap_auth_provider.LdapAuthProvider" + config: + enabled: true + uri: "ldap://ldap.example.com:389" + start_tls: true + base: "ou=users,dc=example,dc=com" + attributes: + uid: "cn" + mail: "email" + name: "givenName" + #bind_dn: + #bind_password: + #filter: "(objectClass=posixAccount)" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ + self.password_providers: List[Tuple[Type, Any]] = [] providers = [] @@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config): ) self.password_providers.append((provider_class, provider_config)) - - def generate_config_section(self, **kwargs): - return """\ - # Password providers allow homeserver administrators to integrate - # their Synapse installation with existing authentication methods - # ex. LDAP, external tokens, etc. - # - # For more information and known implementations, please see - # https://matrix-org.github.io/synapse/latest/password_auth_providers.html - # - # Note: instances wishing to use SAML or CAS authentication should - # instead use the `saml2_config` or `cas_config` options, - # respectively. - # - password_providers: - # # Example config for an LDAP auth provider - # - module: "ldap_auth_provider.LdapAuthProvider" - # config: - # enabled: true - # uri: "ldap://ldap.example.com:389" - # start_tls: true - # base: "ou=users,dc=example,dc=com" - # attributes: - # uid: "cn" - # mail: "email" - # name: "givenName" - # #bind_dn: - # #bind_password: - # #filter: "(objectClass=posixAccount)" - """ -- cgit 1.5.1 From 55731333488bfd53ece117938dde1cef710eef68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Oct 2021 10:30:48 -0400 Subject: Move experimental & retention config out of the server module. (#11070) --- changelog.d/11070.misc | 1 + docs/sample_config.yaml | 83 ++++++------ synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 3 + synapse/config/homeserver.py | 2 + synapse/config/retention.py | 226 +++++++++++++++++++++++++++++++++ synapse/config/server.py | 201 ----------------------------- synapse/events/utils.py | 6 +- synapse/handlers/pagination.py | 13 +- synapse/storage/databases/main/room.py | 8 +- 10 files changed, 290 insertions(+), 255 deletions(-) create mode 100644 changelog.d/11070.misc create mode 100644 synapse/config/retention.py (limited to 'synapse/config') diff --git a/changelog.d/11070.misc b/changelog.d/11070.misc new file mode 100644 index 0000000000..52b23f9671 --- /dev/null +++ b/changelog.d/11070.misc @@ -0,0 +1 @@ +Create a separate module for the retention configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7bfaed483b..b90ed62d61 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -472,6 +472,48 @@ limit_remote_rooms: # #user_ips_max_age: 14d +# Inhibits the /requestToken endpoints from returning an error that might leak +# information about whether an e-mail address is in use or not on this +# homeserver. +# Note that for some endpoints the error situation is the e-mail already being +# used, and for others the error is entering the e-mail being unused. +# If this option is enabled, instead of returning an error, these endpoints will +# act as if no error happened and return a fake session ID ('sid') to clients. +# +#request_token_inhibit_3pid_errors: true + +# A list of domains that the domain portion of 'next_link' parameters +# must match. +# +# This parameter is optionally provided by clients while requesting +# validation of an email or phone number, and maps to a link that +# users will be automatically redirected to after validation +# succeeds. Clients can make use this parameter to aid the validation +# process. +# +# The whitelist is applied whether the homeserver or an +# identity server is handling validation. +# +# The default value is no whitelist functionality; all domains are +# allowed. Setting this value to an empty list will instead disallow +# all domains. +# +#next_link_domain_whitelist: ["matrix.org"] + +# Templates to use when generating email or HTML page contents. +# +templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ + + # Message retention policy at the server level. # # Room admins and mods can define a retention period for their rooms using the @@ -541,47 +583,6 @@ retention: # - shortest_max_lifetime: 3d # interval: 1d -# Inhibits the /requestToken endpoints from returning an error that might leak -# information about whether an e-mail address is in use or not on this -# homeserver. -# Note that for some endpoints the error situation is the e-mail already being -# used, and for others the error is entering the e-mail being unused. -# If this option is enabled, instead of returning an error, these endpoints will -# act as if no error happened and return a fake session ID ('sid') to clients. -# -#request_token_inhibit_3pid_errors: true - -# A list of domains that the domain portion of 'next_link' parameters -# must match. -# -# This parameter is optionally provided by clients while requesting -# validation of an email or phone number, and maps to a link that -# users will be automatically redirected to after validation -# succeeds. Clients can make use this parameter to aid the validation -# process. -# -# The whitelist is applied whether the homeserver or an -# identity server is handling validation. -# -# The default value is no whitelist functionality; all domains are -# allowed. Setting this value to an empty list will instead disallow -# all domains. -# -#next_link_domain_whitelist: ["matrix.org"] - -# Templates to use when generating email or HTML page contents. -# -templates: - # Directory in which Synapse will try to find template files to use to generate - # email or HTML page contents. - # If not set, or a file is not found within the template directory, a default - # template from within the Synapse package will be used. - # - # See https://matrix-org.github.io/synapse/latest/templates.html for more - # information about using custom templates. - # - #custom_template_directory: /path/to/custom/templates/ - ## TLS ## diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 06fbd1166b..c1d9069798 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -26,6 +26,7 @@ from synapse.config import ( redis, registration, repository, + retention, room_directory, saml2, server, @@ -91,6 +92,7 @@ class RootConfig: modules: modules.ModulesConfig caches: cache.CacheConfig federation: federation.FederationConfig + retention: retention.RetentionConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 7b0381c06a..b013a3918c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -24,6 +24,9 @@ class ExperimentalConfig(Config): def read_config(self, config: JsonDict, **kwargs): experimental = config.get("experimental_features") or {} + # Whether to enable experimental MSC1849 (aka relations) support + self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True) + # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 442f1b9ac0..001605c265 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -38,6 +38,7 @@ from .ratelimiting import RatelimitConfig from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .retention import RetentionConfig from .room import RoomConfig from .room_directory import RoomDirectoryConfig from .saml2 import SAML2Config @@ -59,6 +60,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ModulesConfig, ServerConfig, + RetentionConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/config/retention.py b/synapse/config/retention.py new file mode 100644 index 0000000000..aed9bf458f --- /dev/null +++ b/synapse/config/retention.py @@ -0,0 +1,226 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Optional + +import attr + +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RetentionPurgeJob: + """Object describing the configuration of the manhole""" + + interval: int + shortest_max_lifetime: Optional[int] + longest_max_lifetime: Optional[int] + + +class RetentionConfig(Config): + section = "retention" + + def read_config(self, config, **kwargs): + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + if self.retention_enabled: + logger.info( + "Message retention policies support enabled with the following default" + " policy: min_lifetime = %s ; max_lifetime = %s", + self.retention_default_min_lifetime, + self.retention_default_max_lifetime, + ) + + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs: List[RetentionPurgeJob] = [] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append( + RetentionPurgeJob(interval, shortest_max_lifetime, longest_max_lifetime) + ) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [ + RetentionPurgeJob(self.parse_duration("1d"), None, None) + ] + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a more frequent basis than for the rest of the rooms + # (e.g. every 12h), but not want that purge to be performed by a job that's + # iterating over every room it knows, which could be heavy on the server. + # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # + #purge_jobs: + # - longest_max_lifetime: 3d + # interval: 12h + # - shortest_max_lifetime: 3d + # interval: 1d + """ diff --git a/synapse/config/server.py b/synapse/config/server.py index 818b806357..ed094bdc44 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -225,15 +225,6 @@ class ManholeConfig: pub_key: Optional[Key] -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RetentionConfig: - """Object describing the configuration of the manhole""" - - interval: int - shortest_max_lifetime: Optional[int] - longest_max_lifetime: Optional[int] - - @attr.s(frozen=True) class LimitRemoteRoomsConfig: enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) @@ -376,11 +367,6 @@ class ServerConfig(Config): # (other than those sent by local server admins) self.block_non_admin_invites = config.get("block_non_admin_invites", False) - # Whether to enable experimental MSC1849 (aka relations) support - self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", True - ) - # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 @@ -466,124 +452,6 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) - retention_config = config.get("retention") - if retention_config is None: - retention_config = {} - - self.retention_enabled = retention_config.get("enabled", False) - - retention_default_policy = retention_config.get("default_policy") - - if retention_default_policy is not None: - self.retention_default_min_lifetime = retention_default_policy.get( - "min_lifetime" - ) - if self.retention_default_min_lifetime is not None: - self.retention_default_min_lifetime = self.parse_duration( - self.retention_default_min_lifetime - ) - - self.retention_default_max_lifetime = retention_default_policy.get( - "max_lifetime" - ) - if self.retention_default_max_lifetime is not None: - self.retention_default_max_lifetime = self.parse_duration( - self.retention_default_max_lifetime - ) - - if ( - self.retention_default_min_lifetime is not None - and self.retention_default_max_lifetime is not None - and ( - self.retention_default_min_lifetime - > self.retention_default_max_lifetime - ) - ): - raise ConfigError( - "The default retention policy's 'min_lifetime' can not be greater" - " than its 'max_lifetime'" - ) - else: - self.retention_default_min_lifetime = None - self.retention_default_max_lifetime = None - - if self.retention_enabled: - logger.info( - "Message retention policies support enabled with the following default" - " policy: min_lifetime = %s ; max_lifetime = %s", - self.retention_default_min_lifetime, - self.retention_default_max_lifetime, - ) - - self.retention_allowed_lifetime_min = retention_config.get( - "allowed_lifetime_min" - ) - if self.retention_allowed_lifetime_min is not None: - self.retention_allowed_lifetime_min = self.parse_duration( - self.retention_allowed_lifetime_min - ) - - self.retention_allowed_lifetime_max = retention_config.get( - "allowed_lifetime_max" - ) - if self.retention_allowed_lifetime_max is not None: - self.retention_allowed_lifetime_max = self.parse_duration( - self.retention_allowed_lifetime_max - ) - - if ( - self.retention_allowed_lifetime_min is not None - and self.retention_allowed_lifetime_max is not None - and self.retention_allowed_lifetime_min - > self.retention_allowed_lifetime_max - ): - raise ConfigError( - "Invalid retention policy limits: 'allowed_lifetime_min' can not be" - " greater than 'allowed_lifetime_max'" - ) - - self.retention_purge_jobs: List[RetentionConfig] = [] - for purge_job_config in retention_config.get("purge_jobs", []): - interval_config = purge_job_config.get("interval") - - if interval_config is None: - raise ConfigError( - "A retention policy's purge jobs configuration must have the" - " 'interval' key set." - ) - - interval = self.parse_duration(interval_config) - - shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") - - if shortest_max_lifetime is not None: - shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) - - longest_max_lifetime = purge_job_config.get("longest_max_lifetime") - - if longest_max_lifetime is not None: - longest_max_lifetime = self.parse_duration(longest_max_lifetime) - - if ( - shortest_max_lifetime is not None - and longest_max_lifetime is not None - and shortest_max_lifetime > longest_max_lifetime - ): - raise ConfigError( - "A retention policy's purge jobs configuration's" - " 'shortest_max_lifetime' value can not be greater than its" - " 'longest_max_lifetime' value." - ) - - self.retention_purge_jobs.append( - RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime) - ) - - if not self.retention_purge_jobs: - self.retention_purge_jobs = [ - RetentionConfig(self.parse_duration("1d"), None, None) - ] - self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] # no_tls is not really supported any more, but let's grandfather it in @@ -1255,75 +1123,6 @@ class ServerConfig(Config): # #user_ips_max_age: 14d - # Message retention policy at the server level. - # - # Room admins and mods can define a retention period for their rooms using the - # 'm.room.retention' state event, and server admins can cap this period by setting - # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. - # - # If this feature is enabled, Synapse will regularly look for and purge events - # which are older than the room's maximum retention period. Synapse will also - # filter events received over federation so that events that should have been - # purged are ignored and not stored again. - # - retention: - # The message retention policies feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # Default retention policy. If set, Synapse will apply it to rooms that lack the - # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't - # matter much because Synapse doesn't take it into account yet. - # - #default_policy: - # min_lifetime: 1d - # max_lifetime: 1y - - # Retention policy limits. If set, and the state of a room contains a - # 'm.room.retention' event in its state which contains a 'min_lifetime' or a - # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy - # to these limits when running purge jobs. - # - #allowed_lifetime_min: 1d - #allowed_lifetime_max: 1y - - # Server admins can define the settings of the background jobs purging the - # events which lifetime has expired under the 'purge_jobs' section. - # - # If no configuration is provided, a single job will be set up to delete expired - # events in every room daily. - # - # Each job's configuration defines which range of message lifetimes the job - # takes care of. For example, if 'shortest_max_lifetime' is '2d' and - # 'longest_max_lifetime' is '3d', the job will handle purging expired events in - # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and - # lower than or equal to 3 days. Both the minimum and the maximum value of a - # range are optional, e.g. a job with no 'shortest_max_lifetime' and a - # 'longest_max_lifetime' of '3d' will handle every room with a retention policy - # which 'max_lifetime' is lower than or equal to three days. - # - # The rationale for this per-job configuration is that some rooms might have a - # retention policy with a low 'max_lifetime', where history needs to be purged - # of outdated messages on a more frequent basis than for the rest of the rooms - # (e.g. every 12h), but not want that purge to be performed by a job that's - # iterating over every room it knows, which could be heavy on the server. - # - # If any purge job is configured, it is strongly recommended to have at least - # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' - # set, or one job without 'shortest_max_lifetime' and one job without - # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if - # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a - # room's policy to these values is done after the policies are retrieved from - # Synapse's database (which is done using the range specified in a purge job's - # configuration). - # - #purge_jobs: - # - longest_max_lifetime: 3d - # interval: 12h - # - shortest_max_lifetime: 3d - # interval: 1d - # Inhibits the /requestToken endpoints from returning an error that might leak # information about whether an e-mail address is in use or not on this # homeserver. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 23bd24d963..3f3eba86a8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -385,9 +385,7 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.experimental_msc1849_support_enabled = ( - hs.config.server.experimental_msc1849_support_enabled - ) + self._msc1849_enabled = hs.config.experimental.msc1849_enabled async def serialize_event( self, @@ -418,7 +416,7 @@ class EventClientSerializer: # we need to bundle in with the event. # Do not bundle relations if the event has been redacted if not event.internal_metadata.is_redacted() and ( - self.experimental_msc1849_support_enabled and bundle_aggregations + self._msc1849_enabled and bundle_aggregations ): annotations = await self.store.get_aggregation_groups_for_event(event_id) references = await self.store.get_relations_for_event( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 176e4dfdd4..60ff896386 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -86,19 +86,22 @@ class PaginationHandler: self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( - hs.config.server.retention_default_max_lifetime + hs.config.retention.retention_default_max_lifetime ) self._retention_allowed_lifetime_min = ( - hs.config.server.retention_allowed_lifetime_min + hs.config.retention.retention_allowed_lifetime_min ) self._retention_allowed_lifetime_max = ( - hs.config.server.retention_allowed_lifetime_max + hs.config.retention.retention_allowed_lifetime_max ) - if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled: + if ( + hs.config.worker.run_background_tasks + and hs.config.retention.retention_enabled + ): # Run the purge jobs described in the configuration file. - for job in hs.config.server.retention_purge_jobs: + for job in hs.config.retention.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) self.clock.looping_call( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index d69eaf80ce..835d7889cb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore): # policy. if not ret: return { - "min_lifetime": self.config.server.retention_default_min_lifetime, - "max_lifetime": self.config.server.retention_default_max_lifetime, + "min_lifetime": self.config.retention.retention_default_min_lifetime, + "max_lifetime": self.config.retention.retention_default_max_lifetime, } row = ret[0] @@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore): # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.server.retention_default_min_lifetime + row["min_lifetime"] = self.config.retention.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.server.retention_default_max_lifetime + row["max_lifetime"] = self.config.retention.retention_default_max_lifetime return row -- cgit 1.5.1 From ba00e20234eadae66f105f8bda64e39beed9a92d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Oct 2021 14:39:16 -0400 Subject: Add a thread relation type per MSC3440. (#11088) Adds experimental support for MSC3440's `io.element.thread` relation type (and the aggregation for it). --- changelog.d/11088.feature | 1 + synapse/api/constants.py | 1 + synapse/config/experimental.py | 2 + synapse/events/utils.py | 17 +++++++++ synapse/rest/client/relations.py | 3 +- synapse/storage/databases/main/events.py | 4 ++ synapse/storage/databases/main/relations.py | 59 ++++++++++++++++++++++++++++- tests/rest/client/test_relations.py | 40 ++++++++++++++++--- 8 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11088.feature (limited to 'synapse/config') diff --git a/changelog.d/11088.feature b/changelog.d/11088.feature new file mode 100644 index 0000000000..76b0d28084 --- /dev/null +++ b/changelog.d/11088.feature @@ -0,0 +1 @@ +Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a31f037748..a33ac34161 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -176,6 +176,7 @@ class RelationTypes: ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + THREAD = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b013a3918c..8b098ad48d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -26,6 +26,8 @@ class ExperimentalConfig(Config): # Whether to enable experimental MSC1849 (aka relations) support self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True) + # MSC3440 (thread relation) + self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False) # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 3f3eba86a8..6fa631aa1d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -386,6 +386,7 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled async def serialize_event( self, @@ -462,6 +463,22 @@ class EventClientSerializer: "sender": edit.sender, } + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=False + ), + "count": thread_count, + } + return serialized_event async def serialize_events( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d695c18be2..58f6699073 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -128,9 +128,10 @@ class RelationSendServlet(RestServlet): content["m.relates_to"] = { "event_id": parent_id, - "key": aggregation_key, "rel_type": relation_type, } + if aggregation_key is not None: + content["m.relates_to"]["key"] = aggregation_key event_dict = { "type": event_type, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 37439f8562..8d9086ecf0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1710,6 +1710,7 @@ class PersistEventsStore: RelationTypes.ANNOTATION, RelationTypes.REFERENCE, RelationTypes.REPLACE, + RelationTypes.THREAD, ): # Unknown relation type return @@ -1740,6 +1741,9 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if rel_type == RelationTypes.THREAD: + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. Part of MSC2716. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2bbf6d6a95..40760fbd1b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple import attr @@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) + @cached() + async def get_thread_summary( + self, event_id: str + ) -> Tuple[int, Optional[EventBase]]: + """Get the number of threaded replies, the senders of those replies, and + the latest reply (if any) for the given event. + + Args: + event_id: The original event ID + + Returns: + The number of items in the thread and the most recent response, if any. + """ + + def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: + # Fetch the count of threaded events and the latest event ID. + # TODO Should this only allow m.room.message events. + sql = """ + SELECT event_id + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + + txn.execute(sql, (event_id, RelationTypes.THREAD)) + row = txn.fetchone() + if row is None: + return 0, None + + latest_event_id = row[0] + + sql = """ + SELECT COALESCE(COUNT(event_id), 0) + FROM event_relations + WHERE + relates_to_id = ? + AND relation_type = ? + """ + txn.execute(sql, (event_id, RelationTypes.THREAD)) + count = txn.fetchone()[0] + + return count, latest_event_id + + count, latest_event_id = await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + + latest_event = None + if latest_event_id: + latest_event = await self.get_event(latest_event_id, allow_none=True) + + return count, latest_event + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 3c7d49f0b4..78c2fb86b9 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] @@ -141,8 +141,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) @@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(400, channel.code, channel.json_body) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when + """Test that annotations, references, and threads get correctly bundled when getting the parent event. """ @@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_2 = channel.json_body["event_id"] + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), @@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): RelationTypes.REFERENCE: { "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] }, + RelationTypes.THREAD: { + "count": 2, + "latest_event": { + "age": 100, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "origin_server_ts": 1600, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "unsigned": {"age": 100}, + "user_id": self.user_id, + }, + }, }, ) @@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): { "m.relates_to": { "event_id": self.parent_id, - "key": None, "rel_type": "m.reference", } }, -- cgit 1.5.1 From b9ce53e8785d6f0dba6a3efcd708e4a185c32465 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Fri, 22 Oct 2021 13:00:52 +0300 Subject: Fix synapse.config module "read" command (#11145) `synapse.config.__main__` has the possibility to read a config item. This can be used to conveniently also validate the config is valid before trying to start Synapse. The "read" command broke in https://github.com/matrix-org/synapse/pull/10916 as it now requires passing in "server.server_name" for example. Also made the read command optional so one can just call this with just the confirm file reference and get a "Config parses OK" if things are ok. Signed-off-by: Jason Robinson Co-authored-by: Brendan Abolivier --- changelog.d/11145.bugfix | 1 + synapse/config/__main__.py | 46 ++++++++++++++++++++-------- tests/config/test___main__.py | 31 +++++++++++++++++++ tests/config/test_load.py | 70 ++++++++++--------------------------------- tests/config/utils.py | 58 +++++++++++++++++++++++++++++++++++ 5 files changed, 138 insertions(+), 68 deletions(-) create mode 100644 changelog.d/11145.bugfix create mode 100644 tests/config/test___main__.py create mode 100644 tests/config/utils.py (limited to 'synapse/config') diff --git a/changelog.d/11145.bugfix b/changelog.d/11145.bugfix new file mode 100644 index 0000000000..f369feac42 --- /dev/null +++ b/changelog.d/11145.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.45.0 breaking the configuration file parsing script. diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index b5b6735a8f..c555f5f914 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,25 +12,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys + from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig -if __name__ == "__main__": - import sys - from synapse.config.homeserver import HomeServerConfig +def main(args): + action = args[1] if len(args) > 1 and args[1] == "read" else None + # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]` + # will be the key to read. + # We'll want to rework this code if we want to support more actions than just `read`. + load_config_args = args[3:] if action else args[1:] - action = sys.argv[1] + try: + config = HomeServerConfig.load_config("", load_config_args) + except ConfigError as e: + sys.stderr.write("\n" + str(e) + "\n") + sys.exit(1) + + print("Config parses OK!") if action == "read": - key = sys.argv[2] + key = args[2] + key_parts = key.split(".") + + value = config try: - config = HomeServerConfig.load_config("", sys.argv[3:]) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") + while len(key_parts): + value = getattr(value, key_parts[0]) + key_parts.pop(0) + + print(f"\n{key}: {value}") + except AttributeError: + print( + f"\nNo '{key}' key could be found in the provided configuration file." + ) sys.exit(1) - print(getattr(config, key)) - sys.exit(0) - else: - sys.stderr.write("Unknown command %r\n" % (action,)) - sys.exit(1) + +if __name__ == "__main__": + main(sys.argv) diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py new file mode 100644 index 0000000000..b1c73d3612 --- /dev/null +++ b/tests/config/test___main__.py @@ -0,0 +1,31 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config.__main__ import main + +from tests.config.utils import ConfigFileTestCase + + +class ConfigMainFileTestCase(ConfigFileTestCase): + def test_executes_without_an_action(self): + self.generate_config() + main(["", "-c", self.config_file]) + + def test_read__error_if_key_not_found(self): + self.generate_config() + with self.assertRaises(SystemExit): + main(["", "read", "foo.bar.hello", "-c", self.config_file]) + + def test_read__passes_if_key_found(self): + self.generate_config() + main(["", "read", "server.server_name", "-c", self.config_file]) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 59635de205..765258c47a 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,43 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os.path -import shutil -import tempfile -from contextlib import redirect_stdout -from io import StringIO - import yaml from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig -from tests import unittest - - -class ConfigLoadingTestCase(unittest.TestCase): - def setUp(self): - self.dir = tempfile.mkdtemp() - self.file = os.path.join(self.dir, "homeserver.yaml") +from tests.config.utils import ConfigFileTestCase - def tearDown(self): - shutil.rmtree(self.dir) +class ConfigLoadingFileTestCase(ConfigFileTestCase): def test_load_fails_if_server_name_missing(self): self.generate_config_and_remove_lines_containing("server_name") with self.assertRaises(ConfigError): - HomeServerConfig.load_config("", ["-c", self.file]) + HomeServerConfig.load_config("", ["-c", self.config_file]) with self.assertRaises(ConfigError): - HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file) as f: + with open(self.config_file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertTrue( hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", @@ -58,7 +46,7 @@ class ConfigLoadingTestCase(unittest.TestCase): "was: %r" % (config.key.macaroon_secret_key,) ) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertTrue( hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", @@ -71,9 +59,9 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_load_succeeds_if_macaroon_secret_key_missing(self): self.generate_config_and_remove_lines_containing("macaroon") - config1 = HomeServerConfig.load_config("", ["-c", self.file]) - config2 = HomeServerConfig.load_config("", ["-c", self.file]) - config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) + config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) + config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertEqual( config1.key.macaroon_secret_key, config2.key.macaroon_secret_key ) @@ -87,15 +75,15 @@ class ConfigLoadingTestCase(unittest.TestCase): ["enable_registration: true", "disable_registration: true"] ) # Check that disable_registration clobbers enable_registration. - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( - "", ["-c", self.file, "--enable-registration"] + "", ["-c", self.config_file, "--enable-registration"] ) self.assertTrue(config.registration.enable_registration) @@ -104,33 +92,5 @@ class ConfigLoadingTestCase(unittest.TestCase): self.add_lines_to_config(["enable_metrics: true"]) # The default Metrics Flags are off by default. - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) - - def generate_config(self): - with redirect_stdout(StringIO()): - HomeServerConfig.load_or_generate_config( - "", - [ - "--generate-config", - "-c", - self.file, - "--report-stats=yes", - "-H", - "lemurs.win", - ], - ) - - def generate_config_and_remove_lines_containing(self, needle): - self.generate_config() - - with open(self.file) as f: - contents = f.readlines() - contents = [line for line in contents if needle not in line] - with open(self.file, "w") as f: - f.write("".join(contents)) - - def add_lines_to_config(self, lines): - with open(self.file, "a") as f: - for line in lines: - f.write(line + "\n") diff --git a/tests/config/utils.py b/tests/config/utils.py new file mode 100644 index 0000000000..94c18a052b --- /dev/null +++ b/tests/config/utils.py @@ -0,0 +1,58 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import tempfile +import unittest +from contextlib import redirect_stdout +from io import StringIO + +from synapse.config.homeserver import HomeServerConfig + + +class ConfigFileTestCase(unittest.TestCase): + def setUp(self): + self.dir = tempfile.mkdtemp() + self.config_file = os.path.join(self.dir, "homeserver.yaml") + + def tearDown(self): + shutil.rmtree(self.dir) + + def generate_config(self): + with redirect_stdout(StringIO()): + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.config_file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) + + def generate_config_and_remove_lines_containing(self, needle): + self.generate_config() + + with open(self.config_file) as f: + contents = f.readlines() + contents = [line for line in contents if needle not in line] + with open(self.config_file, "w") as f: + f.write("".join(contents)) + + def add_lines_to_config(self, lines): + with open(self.config_file, "a") as f: + for line in lines: + f.write(line + "\n") -- cgit 1.5.1 From 2b82ec425fccb0ef626242779f7ccd4d77a0685c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 22 Oct 2021 18:15:41 +0100 Subject: Add type hints for most `HomeServer` parameters (#11095) --- changelog.d/11095.misc | 1 + synapse/app/_base.py | 8 +++---- synapse/app/admin_cmd.py | 4 ++-- synapse/app/generic_worker.py | 4 ++-- synapse/app/homeserver.py | 2 +- synapse/app/phone_stats_home.py | 8 +++++-- synapse/appservice/api.py | 3 ++- synapse/config/logger.py | 9 ++++++- synapse/federation/federation_base.py | 7 +++++- synapse/federation/federation_server.py | 9 +++---- synapse/http/matrixfederationclient.py | 8 +++++-- synapse/http/server.py | 19 +++++++++------ synapse/replication/http/__init__.py | 9 +++++-- synapse/replication/http/_base.py | 8 ++++--- synapse/replication/http/account_data.py | 14 +++++++---- synapse/replication/http/devices.py | 8 +++++-- synapse/replication/http/federation.py | 16 ++++++++----- synapse/replication/http/login.py | 8 +++++-- synapse/replication/http/membership.py | 6 ++--- synapse/replication/http/presence.py | 2 +- synapse/replication/http/push.py | 2 +- synapse/replication/http/register.py | 10 +++++--- synapse/replication/http/send_event.py | 8 +++++-- synapse/replication/http/streams.py | 8 +++++-- synapse/replication/slave/storage/_base.py | 7 ++++-- synapse/replication/slave/storage/client_ips.py | 7 +++++- synapse/replication/slave/storage/devices.py | 7 +++++- synapse/replication/slave/storage/events.py | 6 ++++- synapse/replication/slave/storage/filtering.py | 7 +++++- synapse/replication/slave/storage/groups.py | 7 +++++- synapse/replication/tcp/external_cache.py | 9 ++++++- synapse/replication/tcp/handler.py | 6 ++++- synapse/replication/tcp/resource.py | 8 +++++-- synapse/replication/tcp/streams/_base.py | 20 ++++++++-------- synapse/rest/admin/devices.py | 2 +- synapse/server.py | 11 ++++++--- synapse/storage/database.py | 6 ++++- synapse/storage/databases/__init__.py | 28 +++++++++++++++++----- synapse/storage/databases/main/__init__.py | 7 ++++-- synapse/storage/databases/main/account_data.py | 7 ++++-- synapse/storage/databases/main/cache.py | 7 ++++-- synapse/storage/databases/main/deviceinbox.py | 9 ++++--- synapse/storage/databases/main/devices.py | 21 ++++++++++++---- synapse/storage/databases/main/event_federation.py | 9 ++++--- .../storage/databases/main/event_push_actions.py | 9 ++++--- .../storage/databases/main/events_bg_updates.py | 7 ++++-- synapse/storage/databases/main/media_repository.py | 9 ++++--- synapse/storage/databases/main/metrics.py | 7 ++++-- .../storage/databases/main/monthly_active_users.py | 9 ++++--- synapse/storage/databases/main/push_rule.py | 7 ++++-- synapse/storage/databases/main/receipts.py | 7 ++++-- synapse/storage/databases/main/room.py | 11 +++++---- synapse/storage/databases/main/roommember.py | 7 +++--- synapse/storage/databases/main/search.py | 9 ++++--- synapse/storage/databases/main/state.py | 11 +++++---- synapse/storage/databases/main/stats.py | 7 ++++-- synapse/storage/databases/main/transactions.py | 7 ++++-- synapse/storage/persist_events.py | 6 ++++- 58 files changed, 342 insertions(+), 143 deletions(-) create mode 100644 changelog.d/11095.misc (limited to 'synapse/config') diff --git a/changelog.d/11095.misc b/changelog.d/11095.misc new file mode 100644 index 0000000000..786e90b595 --- /dev/null +++ b/changelog.d/11095.misc @@ -0,0 +1 @@ +Add type hints to most `HomeServer` parameters. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index bb4d53d778..2ca2e051e4 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -294,7 +294,7 @@ def listen_ssl( return r -def refresh_certificate(hs): +def refresh_certificate(hs: "HomeServer"): """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -419,11 +419,11 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs): +def setup_sentry(hs: "HomeServer"): """Enable sentry integration, if enabled in configuration Args: - hs (synapse.server.HomeServer) + hs """ if not hs.config.metrics.sentry_enabled: @@ -449,7 +449,7 @@ def setup_sentry(hs): scope.set_tag("worker_name", name) -def setup_sdnotify(hs): +def setup_sdnotify(hs: "HomeServer"): """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b156b93bf3..2fc848596d 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer): DATASTORE_CLASS = AdminCmdSlavedStore -async def export_data_command(hs, args): +async def export_data_command(hs: HomeServer, args): """Export data for a user. Args: - hs (HomeServer) + hs args (argparse.Namespace) """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 7489f31d9a..51eadf122d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: HomeServer): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ super().__init__() self.auth = hs.get_auth() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 422f03cc04..93e2299266 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]: e = e.__cause__ -def run(hs): +def run(hs: HomeServer): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index fcd01e833c..126450e17a 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,11 +15,15 @@ import logging import math import resource import sys +from typing import TYPE_CHECKING from prometheus_client import Gauge from synapse.metrics.background_process_metrics import wrap_as_background_process +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger("synapse.app.homeserver") # Contains the list of processes we will be monitoring @@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs): +def start_phone_stats_home(hs: "HomeServer"): """ Start the background tasks which report phone home stats. """ diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 935f24263c..d08f6bbd7f 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: from synapse.appservice import ApplicationService + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient): pushing. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 0a08231e5a..5252e61a99 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -18,6 +18,7 @@ import os import sys import threading from string import Template +from typing import TYPE_CHECKING import yaml from zope.interface import implementer @@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + DEFAULT_LOG_CONFIG = Template( """\ # Log configuration for Synapse. @@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path): def setup_logging( - hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner + hs: "HomeServer", + config, + use_worker_options=False, + logBeginner: LogBeginner = globalLogBeginner, ) -> None: """ Set up the logging subsystem. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 0cd424e12a..f56344a3b9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging from collections import namedtuple +from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict from synapse.types import JsonDict, get_domain_from_id +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class FederationBase: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d8c0b86f23..0d66034f44 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -467,7 +467,7 @@ class FederationServer(FederationBase): async def on_room_state_request( self, origin: str, room_id: str, event_id: Optional[str] - ) -> Tuple[int, Dict[str, Any]]: + ) -> Tuple[int, JsonDict]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -481,7 +481,7 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = dict( + resp: JsonDict = dict( await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, @@ -1061,11 +1061,12 @@ class FederationServer(FederationBase): origin, event = next - lock = await self.store.try_acquire_lock( + new_lock = await self.store.try_acquire_lock( _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id ) - if not lock: + if not new_lock: return + lock = new_lock def __str__(self) -> str: return "" % self.server_name diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4f59224686..203d723d41 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -21,6 +21,7 @@ import typing import urllib.parse from io import BytesIO, StringIO from typing import ( + TYPE_CHECKING, Callable, Dict, Generic, @@ -73,6 +74,9 @@ from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter( @@ -319,7 +323,7 @@ class MatrixFederationHttpClient: requests. """ - def __init__(self, hs, tls_client_options_factory): + def __init__(self, hs: "HomeServer", tls_client_options_factory): self.hs = hs self.signing_key = hs.signing_key self.server_name = hs.hostname @@ -711,7 +715,7 @@ class MatrixFederationHttpClient: Returns: A list of headers to be added as "Authorization:" headers """ - request = { + request: JsonDict = { "method": method.decode("ascii"), "uri": url_bytes.decode("ascii"), "origin": self.server_name, diff --git a/synapse/http/server.py b/synapse/http/server.py index 897ba5e453..1af0d9a31d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,6 +22,7 @@ import urllib from http import HTTPStatus from inspect import isawaitable from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -61,6 +62,9 @@ from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """ @@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) +_PathEntry = collections.namedtuple( + "_PathEntry", ["pattern", "callback", "servlet_classname"] +) + + class JsonResource(DirectServeJsonResource): """This implements the HttpServer interface and provides JSON support for Resources. @@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - _PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] - ) - - def __init__(self, hs, canonical_json=True, extract_context=False): + def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() - self.path_regexs = {} + self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs def register_paths(self, method, path_patterns, callback, servlet_classname): @@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource): for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback, servlet_classname) + _PathEntry(path_pattern, callback, servlet_classname) ) def _get_handler_for_request( diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index ba8114ac9e..1457d9d59b 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import JsonResource from synapse.replication.http import ( account_data, @@ -26,16 +28,19 @@ from synapse.replication.http import ( streams, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # We enable extracting jaeger contexts here as these are internal APIs. super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) - def register_servlets(self, hs): + def register_servlets(self, hs: "HomeServer"): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index e047ec74d8..585332b244 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import logging import re import urllib from inspect import signature -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): pass @classmethod - def make_client(cls, hs): + def make_client(cls, hs: "HomeServer"): """Create a client that makes requests. Returns a callable that accepts the same parameters as @@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): url_args.append(txn_id) if cls.METHOD == "POST": - request_func = client.post_json_get_json + request_func: Callable[ + ..., Awaitable[Any] + ] = client.post_json_get_json elif cls.METHOD == "PUT": request_func = client.put_json elif cls.METHOD == "GET": diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 70e951af63..5f0f225aa9 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "tag") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): ) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 5a5818ef61..42dffb39cb 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,9 +13,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater @@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index a0b3145f4e..5ed535c90d 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): NAME = "fed_send_events" PATH_ARGS = () - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): NAME = "fed_send_edu" PATH_ARGS = ("edu_type",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): # This is a query, so let's not bother caching CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): NAME = "fed_cleanup_room" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 550bd5c95f..0db419ea57 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): NAME = "device_check_registered" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registration_handler = hs.get_registration_handler() @@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): return 200, res -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): RegisterDeviceReplicationServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 34206c5060..7371c240b2 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): NAME = "remote_join" PATH_ARGS = ("room_id", "user_id") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_handler = hs.get_federation_handler() @@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id", "change") CACHE = False # No point caching as should return instantly. - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() @@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index bb00247953..63143085d5 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index 139427cb1f..6c8db3061e 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemovePusherRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d6dd7242eb..7adfbb666f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): NAME = "register_user" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): NAME = "post_register" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRegisterServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fae5ffa451..9f6851d059 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): NAME = "send_event" PATH_ARGS = ("event_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 9afa147d00..3223bc2432 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -13,11 +13,15 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._instance_name = hs.get_instance_name() @@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index e460dd85cd..7ecb446e7c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -13,18 +13,21 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen: Optional[ diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 436d39c320..61cd7e5228 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.client_ip_last_seen: LruCache[tuple, int] = LruCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 26bdead565..0a58296089 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream @@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4d3f8c448..63ed50caa5 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_federation import EventFederationWorkerStore @@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -54,7 +58,7 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 37875bc973..90284c202d 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e9bdc38470..497e16c69e 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream @@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index b402f82810..aaf91e5e02 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.server import HomeServer set_counter = Counter( @@ -59,7 +61,12 @@ class ExternalCache: """ def __init__(self, hs: "HomeServer"): - self._redis_connection = hs.get_outbound_redis_connection() + if hs.config.redis.redis_enabled: + self._redis_connection: Optional[ + "RedisProtocol" + ] = hs.get_outbound_redis_connection() + else: + self._redis_connection = None def _get_redis_key(self, cache_name: str, key: str) -> str: return "cache_v1:%s:%s" % (cache_name, key) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6aa9318027..06fd06fdf3 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -294,7 +294,7 @@ class ReplicationCommandHandler: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs): + def start_replication(self, hs: "HomeServer"): """Helper method to start a replication connection to the remote server using TCP. """ @@ -321,6 +321,8 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_port, self._factory, + timeout=30, + bindAddress=None, ) else: client_name = hs.get_instance_name() @@ -331,6 +333,8 @@ class ReplicationCommandHandler: host, # type: ignore[arg-type] port, self._factory, + timeout=30, + bindAddress=None, ) def get_streams(self) -> Dict[str, Stream]: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 80f9b23bfd..55326877fd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,6 +16,7 @@ import logging import random +from typing import TYPE_CHECKING from prometheus_client import Counter @@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) @@ -37,7 +41,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server.server_name @@ -65,7 +69,7 @@ class ReplicationStreamer: data is available it will propagate to all connected clients. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9b905aba9d..c8b188ae4e 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -241,7 +241,7 @@ class BackfillStream(Stream): NAME = "backfill" ROW_TYPE = BackfillStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -363,7 +363,7 @@ class ReceiptsStream(Stream): NAME = "receipts" ROW_TYPE = ReceiptsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -380,7 +380,7 @@ class PushRulesStream(Stream): NAME = "push_rules" ROW_TYPE = PushRulesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( @@ -405,7 +405,7 @@ class PushersStream(Stream): NAME = "pushers" ROW_TYPE = PushersStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( @@ -438,7 +438,7 @@ class CachesStream(Stream): NAME = "caches" ROW_TYPE = CachesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -459,7 +459,7 @@ class DeviceListsStream(Stream): NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -476,7 +476,7 @@ class ToDeviceStream(Stream): NAME = "to_device" ROW_TYPE = ToDeviceStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -582,7 +582,7 @@ class GroupServerStream(Stream): NAME = "groups" ROW_TYPE = GroupsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -599,7 +599,7 @@ class UserSignatureStream(Stream): NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index a6fa03c90f..80fbf32f17 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/server.py b/synapse/server.py index a64c846d1c..0fbf36ba99 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta): return ExternalCache(self) @cache_in_self - def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: - if not self.config.redis.redis_enabled: - return None + def get_outbound_redis_connection(self) -> "RedisProtocol": + """ + The Redis connection used for replication. + + Raises: + AssertionError: if Redis is not enabled in the homeserver config. + """ + assert self.config.redis.redis_enabled # We only want to import redis module if we're using it, as we have # `txredisapi` as an optional dependency. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f5a8f90a0f..fa4e89d35c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -19,6 +19,7 @@ from collections import defaultdict from sys import intern from time import monotonic as monotonic_time from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.server import HomeServer + # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 @@ -392,7 +396,7 @@ class DatabasePool: def __init__( self, - hs, + hs: "HomeServer", database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, ): diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 20b755056b..cfe887b7f7 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -13,33 +13,49 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar +from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class Databases: +DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True) + + +class Databases(Generic[DataStoreT]): """The various databases. These are low level interfaces to physical databases. Attributes: - main (DataStore) + databases + main + state + persist_events """ - def __init__(self, main_store_class, hs): + databases: List[DatabasePool] + main: DataStoreT + state: StateGroupDataStore + persist_events: Optional[PersistEventsStore] + + def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main # store. self.databases = [] - main = None - state = None - persist_events = None + main: Optional[DataStoreT] = None + state: Optional[StateGroupDataStore] = None + persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: db_name = database_config.name diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5c21402dea..259cae5b37 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool @@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -126,7 +129,7 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 70ca3e09f7..f8bec266ac 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -28,6 +28,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore): `get_max_account_data_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c57ae5ef15..36e8422fc6 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -15,7 +15,7 @@ import itertools import logging -from typing import Any, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 3154906d45..8143168107 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -26,11 +26,14 @@ from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 6464520386..a01bf2c5b7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,17 @@ # limitations under the License. import abc import logging -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( @@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ba9f71a230..ef5d1ef01e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Gauge @@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + oldest_pdu_in_federation_staging = Gauge( "synapse_federation_server_oldest_inbound_pdu_in_staging", "The age in seconds since we received the oldest pdu in the federation staging area", @@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 97b3e92d3f..d957e770dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr @@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 1afc59fafb..fc49112063 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr @@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -76,7 +79,7 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 2fa945d171..717487be28 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +if TYPE_CHECKING: + from synapse.server import HomeServer + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( "media_repository_drop_index_wo_method" ) @@ -43,7 +46,7 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index dac3d14da8..d901933ae4 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,7 +14,7 @@ import calendar import logging import time -from typing import Dict +from typing import TYPE_CHECKING, Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Collect metrics on the number of forward extremities that exist. @@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index a14ac03d4b..b5284e4f67 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp @@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fc720f5947..fa782023d4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,7 +14,7 @@ # limitations under the License. import abc import logging -from typing import Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules @@ -33,6 +33,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -75,7 +78,7 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 01a4281301..c99f8aebdb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from twisted.internet import defer @@ -29,11 +29,14 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 835d7889cb..f879bbe7c7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -17,7 +17,7 @@ import collections import logging from abc import abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -32,6 +32,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.stringutils import MXC_REGEX +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -69,7 +72,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ddb162a4fc..4b288bb2e7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.state import _StateCacheEntry logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index c85383c975..7fe233767f 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -15,7 +15,7 @@ import logging import re from collections import namedtuple -from typing import Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SearchEntry = namedtuple( @@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a8e8dd4577..fa2c3b1feb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -15,7 +15,7 @@ import collections.abc import logging from collections import namedtuple -from typing import Iterable, Optional, Set +from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -30,6 +30,9 @@ from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -53,7 +56,7 @@ class _GetStateGroupDelta( class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index e20033bb28..5d7b59d861 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing_extensions import Counter @@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # these fields track absolutes (e.g. total number of rooms on the server) @@ -93,7 +96,7 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 860146cd1b..d7dc1f73ac 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + db_binary_type = memoryview logger = logging.getLogger(__name__) @@ -57,7 +60,7 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0e8270746d..402f134d89 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -18,6 +18,7 @@ import itertools import logging from collections import deque from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -56,6 +57,9 @@ from synapse.types import ( from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # The number of times we are recalculating the current state @@ -272,7 +276,7 @@ class EventsPersistenceStorage: current state and forward extremity changes. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. -- cgit 1.5.1