diff --git a/synapse/__init__.py b/synapse/__init__.py
index aa9a3269c0..5ecce24eee 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.38.0rc1"
+__version__ = "1.38.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 307f5f9a94..8916e6fa2f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -63,9 +63,9 @@ class Auth:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.token_cache = LruCache(
+ self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
- ) # type: LruCache[str, Tuple[str, bool]]
+ )
self._auth_blocking = AuthBlocking(self.hs)
@@ -240,6 +240,37 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
+ async def validate_appservice_can_control_user_id(
+ self, app_service: ApplicationService, user_id: str
+ ):
+ """Validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ app_service: The app service that controls the user
+ user_id: The author MXID that the app service is controlling
+
+ Raises:
+ AuthError: If the application service is not allowed to control the user
+ (user namespace regex does not match, wrong homeserver, etc)
+ or if the user has not been registered yet.
+ """
+
+ # It's ok if the app service is trying to use the sender from their registration
+ if app_service.sender == user_id:
+ pass
+ # Check to make sure the app service is allowed to control the user
+ elif not app_service.is_interested_in_user(user_id):
+ raise AuthError(
+ 403,
+ "Application service cannot masquerade as this user (%s)." % user_id,
+ )
+ # Check to make sure the user is already registered on the homeserver
+ elif not (await self.store.get_user_by_id(user_id)):
+ raise AuthError(
+ 403, "Application service has not registered this user (%s)" % user_id
+ )
+
async def _get_appservice_user_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
@@ -261,13 +292,11 @@ class Auth:
return app_service.sender, app_service
user_id = request.args[b"user_id"][0].decode("utf8")
+ await self.validate_appservice_can_control_user_id(app_service, user_id)
+
if app_service.sender == user_id:
return app_service.sender, app_service
- if not app_service.is_interested_in_user(user_id):
- raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (await self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
async def get_user_by_access_token(
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 4cb8bbaf70..054ab14ab6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -118,7 +118,7 @@ class RedirectException(CodeMessageException):
super().__init__(code=http_code, msg=msg)
self.location = location
- self.cookies = [] # type: List[bytes]
+ self.cookies: List[bytes] = []
class SynapseError(CodeMessageException):
@@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError):
):
super().__init__(code, msg, errcode)
if additional_fields is None:
- self._additional_fields = {} # type: Dict
+ self._additional_fields: Dict = {}
else:
self._additional_fields = dict(additional_fields)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ce49a0ad58..ad1ff6a9df 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -289,7 +289,7 @@ class Filter:
room_id = None
ev_type = "m.presence"
contains_url = False
- labels = [] # type: List[str]
+ labels: List[str] = []
else:
sender = event.get("sender", None)
if not sender:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index b9a10283f4..3e3d09bbd2 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -46,9 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
- self.actions = (
- OrderedDict()
- ) # type: OrderedDict[Hashable, Tuple[float, int, float]]
+ self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
async def can_do_action(
self,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f6c1c97b40..a20abc5a65 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -195,7 +195,7 @@ class RoomVersions:
)
-KNOWN_ROOM_VERSIONS = {
+KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
v.identifier: v
for v in (
RoomVersions.V1,
@@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V7,
)
# Note that we do not include MSC2043 here unless it is enabled in the config.
-} # type: Dict[str, RoomVersion]
+}
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5b041fcaad..b43d858f59 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -270,7 +270,7 @@ class GenericWorkerServer(HomeServer):
site_tag = port
# We always include a health resource.
- resources = {"/health": HealthResource()} # type: Dict[str, IResource]
+ resources: Dict[str, IResource] = {"/health": HealthResource()}
for res in listener_config.http_options.resources:
for name in res.names:
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 61152b2c46..935f24263c 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -88,9 +88,9 @@ class ApplicationServiceApi(SimpleHttpClient):
super().__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(
+ self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
- ) # type: ResponseCache[Tuple[str, str]]
+ )
async def query_user(self, service, user_id):
if service.url is None:
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 746fc3cc02..a39d457c56 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -57,8 +57,8 @@ def load_appservices(hostname, config_files):
return []
# Dicts of value -> filename
- seen_as_tokens = {} # type: Dict[str, str]
- seen_ids = {} # type: Dict[str, str]
+ seen_as_tokens: Dict[str, str] = {}
+ seen_ids: Dict[str, str] = {}
appservices = []
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 7789b40323..8d5f38b5d9 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -25,7 +25,7 @@ from ._base import Config, ConfigError
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
-_CACHES = {} # type: Dict[str, Callable[[float], None]]
+_CACHES: Dict[str, Callable[[float], None]] = {}
# a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock()
@@ -157,7 +157,7 @@ class CacheConfig(Config):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
- self.cache_factors = {} # type: Dict[str, float]
+ self.cache_factors: Dict[str, float] = {}
cache_config = config.get("caches") or {}
self.global_factor = cache_config.get(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 5564d7d097..bcecbfec03 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -134,9 +134,9 @@ class EmailConfig(Config):
# trusted_third_party_id_servers does not contain a scheme whereas
# account_threepid_delegate_email is expected to. Presume https
- self.account_threepid_delegate_email = (
+ self.account_threepid_delegate_email: Optional[str] = (
"https://" + first_trusted_identity_server
- ) # type: Optional[str]
+ )
self.using_identity_server_from_trusted_list = True
else:
raise ConfigError(
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7fb1f7021f..e25ccba9ac 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -25,10 +25,10 @@ class ExperimentalConfig(Config):
experimental = config.get("experimental_features") or {}
# MSC2858 (multiple SSO identity providers)
- self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
+ self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False)
# MSC3026 (busy presence state)
- self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
+ self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
# MSC2716 (backfill existing history)
- self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool
+ self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index cdd7a1ef05..7d64993e22 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -22,7 +22,7 @@ class FederationConfig(Config):
def read_config(self, config, **kwargs):
# FIXME: federation_domain_whitelist needs sytests
- self.federation_domain_whitelist = None # type: Optional[dict]
+ self.federation_domain_whitelist: Optional[dict] = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 942e2672a9..ba89d11cf0 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -460,7 +460,7 @@ def _parse_oidc_config_dict(
) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
- client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
+ client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None
if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index fd90b79772..0f5b2b3977 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
- self.password_providers = [] # type: List[Any]
+ self.password_providers: List[Any] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index a7a82742ac..0dfb3a227a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
- requirements = {} # type: Dict[str, List]
+ requirements: Dict[str, List] = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@@ -141,7 +141,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
- self.media_storage_providers = [] # type: List[tuple]
+ self.media_storage_providers: List[tuple] = []
for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 6bff715230..b9e0c0b300 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -505,7 +505,7 @@ class ServerConfig(Config):
" greater than 'allowed_lifetime_max'"
)
- self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]]
+ self.retention_purge_jobs: List[Dict[str, Optional[int]]] = []
for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval")
@@ -688,23 +688,21 @@ class ServerConfig(Config):
# not included in the sample configuration file on purpose as it's a temporary
# hack, so that some users can trial the new defaults without impacting every
# user on the homeserver.
- users_new_default_push_rules = (
+ users_new_default_push_rules: list = (
config.get("users_new_default_push_rules") or []
- ) # type: list
+ )
if not isinstance(users_new_default_push_rules, list):
raise ConfigError("'users_new_default_push_rules' must be a list")
# Turn the list into a set to improve lookup speed.
- self.users_new_default_push_rules = set(
- users_new_default_push_rules
- ) # type: set
+ self.users_new_default_push_rules: set = set(users_new_default_push_rules)
# Whitelist of domain names that given next_link parameters must have
- next_link_domain_whitelist = config.get(
+ next_link_domain_whitelist: Optional[List[str]] = config.get(
"next_link_domain_whitelist"
- ) # type: Optional[List[str]]
+ )
- self.next_link_domain_whitelist = None # type: Optional[Set[str]]
+ self.next_link_domain_whitelist: Optional[Set[str]] = None
if next_link_domain_whitelist is not None:
if not isinstance(next_link_domain_whitelist, list):
raise ConfigError("'next_link_domain_whitelist' must be a list")
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index cb7716c837..a233a9ce03 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -34,7 +34,7 @@ class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
- self.spam_checkers = [] # type: List[Tuple[Any, Dict]]
+ self.spam_checkers: List[Tuple[Any, Dict]] = []
spam_checkers = config.get("spam_checker") or []
if isinstance(spam_checkers, dict):
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index e4346e02aa..d0f04cf8e6 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -39,7 +39,7 @@ class SSOConfig(Config):
section = "sso"
def read_config(self, config, **kwargs):
- sso_config = config.get("sso") or {} # type: Dict[str, Any]
+ sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir
self.sso_template_dir = sso_config.get("template_dir")
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 78f61fe9da..6f253e00c0 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -38,13 +38,9 @@ class StatsConfig(Config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
- self.stats_bucket_size = 86400 * 1000
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
- self.stats_bucket_size = self.parse_duration(
- stats_config.get("bucket_size", "1d")
- )
if not self.stats_enabled:
logger.warning(ROOM_STATS_DISABLED_WARN)
@@ -59,9 +55,4 @@ class StatsConfig(Config):
# correctly.
#
#enabled: false
-
- # The size of each timeslice in the room_stats_historical and
- # user_stats_historical tables, as a time period. Defaults to "1d".
- #
- #bucket_size: 1h
"""
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9a16a8fbae..fed05ac7be 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -80,7 +80,7 @@ class TlsConfig(Config):
fed_whitelist_entries = []
# Support globs (*) in whitelist values
- self.federation_certificate_verification_whitelist = [] # type: List[Pattern]
+ self.federation_certificate_verification_whitelist: List[Pattern] = []
for entry in fed_whitelist_entries:
try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
@@ -132,8 +132,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
- self.tls_certificate = None # type: Optional[crypto.X509]
- self.tls_private_key = None # type: Optional[crypto.PKey]
+ self.tls_certificate: Optional[crypto.X509] = None
+ self.tls_private_key: Optional[crypto.PKey] = None
def is_disk_cert_valid(self, allow_self_signed=True):
"""
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e5a4685ed4..9e9b1c1c86 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -170,11 +170,13 @@ class Keyring:
)
self._key_fetchers = key_fetchers
- self._server_queue = BatchingQueue(
+ self._server_queue: BatchingQueue[
+ _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
+ ] = BatchingQueue(
"keyring_server",
clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests,
- ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
+ )
async def verify_json_for_server(
self,
@@ -330,7 +332,7 @@ class Keyring:
# First we need to deduplicate requests for the same key. We do this by
# taking the *maximum* requested `minimum_valid_until_ts` for each pair
# of server name/key ID.
- server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]]
+ server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
for request in requests:
by_server = server_to_key_to_ts.setdefault(request.server_name, {})
for key_id in request.key_ids:
@@ -355,7 +357,7 @@ class Keyring:
# We now convert the returned list of results into a map from server
# name to key ID to FetchKeyResult, to return.
- to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
for (request, results) in zip(deduped_requests, results_per_request):
to_return_by_server = to_return.setdefault(request.server_name, {})
for key_id, key_result in results.items():
@@ -455,7 +457,7 @@ class StoreKeyFetcher(KeyFetcher):
)
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
- keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
@@ -603,7 +605,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
- union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
@@ -656,8 +658,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
- keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
- added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
+ keys: Dict[str, Dict[str, FetchKeyResult]] = {}
+ added_keys: List[Tuple[str, str, FetchKeyResult]] = []
time_now_ms = self.clock.time_msec()
@@ -805,7 +807,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys = {} # type: Dict[str, FetchKeyResult]
+ keys: Dict[str, FetchKeyResult] = {}
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 89bcf81515..137dff2513 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -48,6 +48,9 @@ def check(
room_version_obj: the version of the room
event: the event being checked.
auth_events: the existing room state.
+ do_sig_check: True if it should be verified that the sending server
+ signed the event.
+ do_size_check: True if the size of the event fields should be verified.
Raises:
AuthError if the checks fail
@@ -528,7 +531,7 @@ def _check_power_levels(
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
- levels_to_check = [
+ levels_to_check: List[Tuple[str, Optional[str]]] = [
("users_default", None),
("events_default", None),
("state_default", None),
@@ -536,7 +539,7 @@ def _check_power_levels(
("redact", None),
("kick", None),
("invite", None),
- ] # type: List[Tuple[str, Optional[str]]]
+ ]
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
@@ -566,12 +569,12 @@ def _check_power_levels(
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
- old_level = int(old_loc[level_to_check]) # type: Optional[int]
+ old_level: Optional[int] = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
- new_level = int(new_loc[level_to_check]) # type: Optional[int]
+ new_level: Optional[int] = int(new_loc[level_to_check])
else:
new_level = None
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 6286ad999a..65dc7a4ed0 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -105,28 +105,28 @@ class _EventInternalMetadata:
self._dict = dict(internal_metadata_dict)
# the stream ordering of this event. None, until it has been persisted.
- self.stream_ordering = None # type: Optional[int]
+ self.stream_ordering: Optional[int] = None
# whether this event is an outlier (ie, whether we have the state at that point
# in the DAG)
self.outlier = False
- out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
- send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
- recheck_redaction = DictProperty("recheck_redaction") # type: bool
- soft_failed = DictProperty("soft_failed") # type: bool
- proactively_send = DictProperty("proactively_send") # type: bool
- redacted = DictProperty("redacted") # type: bool
- txn_id = DictProperty("txn_id") # type: str
- token_id = DictProperty("token_id") # type: int
- historical = DictProperty("historical") # type: bool
+ out_of_band_membership: bool = DictProperty("out_of_band_membership")
+ send_on_behalf_of: str = DictProperty("send_on_behalf_of")
+ recheck_redaction: bool = DictProperty("recheck_redaction")
+ soft_failed: bool = DictProperty("soft_failed")
+ proactively_send: bool = DictProperty("proactively_send")
+ redacted: bool = DictProperty("redacted")
+ txn_id: str = DictProperty("txn_id")
+ token_id: int = DictProperty("token_id")
+ historical: bool = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
- before = DictProperty("before") # type: RoomStreamToken
- after = DictProperty("after") # type: RoomStreamToken
- order = DictProperty("order") # type: Tuple[int, int]
+ before: RoomStreamToken = DictProperty("before")
+ after: RoomStreamToken = DictProperty("after")
+ order: Tuple[int, int] = DictProperty("order")
def get_dict(self) -> JsonDict:
return dict(self._dict)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 26e3950859..87e2bb123b 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -132,12 +132,12 @@ class EventBuilder:
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions.
- auth_events = await self._store.add_event_hashes(
- auth_event_ids
- ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
- prev_events = await self._store.add_event_hashes(
- prev_event_ids
- ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+ auth_events: Union[
+ List[str], List[Tuple[str, Dict[str, str]]]
+ ] = await self._store.add_event_hashes(auth_event_ids)
+ prev_events: Union[
+ List[str], List[Tuple[str, Dict[str, str]]]
+ ] = await self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_event_ids
prev_events = prev_event_ids
@@ -156,7 +156,7 @@ class EventBuilder:
# the db)
depth = min(depth, MAX_DEPTH)
- event_dict = {
+ event_dict: Dict[str, Any] = {
"auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,
@@ -166,7 +166,7 @@ class EventBuilder:
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
- } # type: Dict[str, Any]
+ }
if self.is_state():
event_dict["state_key"] = self._state_key
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index efec16c226..57f1d53fa8 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"""Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement.
"""
- spam_checkers = [] # type: List[Any]
+ spam_checkers: List[Any] = []
api = hs.get_module_api()
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
@@ -239,7 +239,7 @@ class SpamChecker:
will be used as the error message returned to the user.
"""
for callback in self._check_event_for_spam_callbacks:
- res = await callback(event) # type: Union[bool, str]
+ res: Union[bool, str] = await callback(event)
if res:
return res
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ed09c6af1f..c767d30627 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -86,7 +86,7 @@ class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
+ self.pdu_destination_tried: Dict[str, Dict[str, int]] = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -94,13 +94,13 @@ class FederationClient(FederationBase):
self.hostname = hs.hostname
self.signing_key = hs.signing_key
- self._get_pdu_cache = ExpiringCache(
+ self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache(
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
- ) # type: ExpiringCache[str, EventBase]
+ )
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
@@ -293,10 +293,10 @@ class FederationClient(FederationBase):
transaction_data,
)
- pdu_list = [
+ pdu_list: List[EventBase] = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
- ] # type: List[EventBase]
+ ]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ac0f2ccfb3..d91f0ff32f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -122,12 +122,12 @@ class FederationServer(FederationBase):
# origins that we are currently processing a transaction from.
# a dict from origin to txn id.
- self._active_transactions = {} # type: Dict[str, str]
+ self._active_transactions: Dict[str, str] = {}
# We cache results for transaction with the same ID
- self._transaction_resp_cache = ResponseCache(
+ self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "fed_txn_handler", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, str]]
+ )
self.transaction_actions = TransactionActions(self.store)
@@ -135,12 +135,12 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
- self._state_resp_cache = ResponseCache(
- hs.get_clock(), "state_resp", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, Optional[str]]]
- self._state_ids_resp_cache = ResponseCache(
+ self._state_resp_cache: ResponseCache[
+ Tuple[str, Optional[str]]
+ ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
+ self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, str]]
+ )
self._federation_metrics_domains = (
hs.config.federation.federation_metrics_domains
@@ -337,7 +337,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
- pdus_by_room = {} # type: Dict[str, List[EventBase]]
+ pdus_by_room: Dict[str, List[EventBase]] = {}
newest_pdu_ts = 0
@@ -516,9 +516,9 @@ class FederationServer(FederationBase):
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
if event_id:
- pdus = await self.handler.get_state_for_pdu(
+ pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
room_id, event_id
- ) # type: Iterable[EventBase]
+ )
else:
pdus = (await self.state.get_current_state(room_id)).values()
@@ -791,7 +791,7 @@ class FederationServer(FederationBase):
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self.store.claim_e2e_one_time_keys(query)
- json_result = {} # type: Dict[str, Dict[str, dict]]
+ json_result: Dict[str, Dict[str, dict]] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
@@ -1119,17 +1119,13 @@ class FederationHandlerRegistry:
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
- self.edu_handlers = (
- {}
- ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = (
- {}
- ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
+ self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
# EDU received.
- self._edu_type_to_instance = {} # type: Dict[str, List[str]]
+ self._edu_type_to_instance: Dict[str, List[str]] = {}
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 65d76ea974..1fbf325fdc 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -71,34 +71,32 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# We may have multiple federation sender instances, so we need to track
# their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances
- self._sender_positions = {} # type: Dict[str, int]
+ self._sender_positions: Dict[str, int] = {}
# Pending presence map user_id -> UserPresenceState
- self.presence_map = {} # type: Dict[str, UserPresenceState]
+ self.presence_map: Dict[str, UserPresenceState] = {}
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = (
- SortedDict()
- ) # type: SortedDict[int, Tuple[str, Iterable[str]]]
+ self.presence_destinations: SortedDict[
+ int, Tuple[str, Iterable[str]]
+ ] = SortedDict()
# (destination, key) -> EDU
- self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
+ self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}
# stream position -> (destination, key)
- self.keyed_edu_changed = (
- SortedDict()
- ) # type: SortedDict[int, Tuple[str, tuple]]
+ self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict()
- self.edus = SortedDict() # type: SortedDict[int, Edu]
+ self.edus: SortedDict[int, Edu] = SortedDict()
# stream ID for the next entry into keyed_edu_changed/edus.
self.pos = 1
# map from stream ID to the time that stream entry was generated, so that we
# can clear out entries after a while
- self.pos_time = SortedDict() # type: SortedDict[int, int]
+ self.pos_time: SortedDict[int, int] = SortedDict()
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -291,7 +289,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = [] # type: List[Tuple[int, BaseFederationRow]]
+ rows: List[Tuple[int, BaseFederationRow]] = []
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
@@ -445,11 +443,11 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-_rowtypes = (
+_rowtypes: Tuple[Type[BaseFederationRow], ...] = (
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
-) # type: Tuple[Type[BaseFederationRow], ...]
+)
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index deb40f4610..d980e0d986 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,9 +14,12 @@
import abc
import logging
+from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
+import attr
from prometheus_client import Counter
+from typing_extensions import Literal
from twisted.internet import defer
@@ -33,8 +36,12 @@ from synapse.metrics import (
event_processing_loop_room_count,
events_processed_counter,
)
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
+from synapse.util import Clock
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -137,6 +144,84 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
raise NotImplementedError()
+@attr.s
+class _PresenceQueue:
+ """A queue of destinations that need to be woken up due to new presence
+ updates.
+
+ Staggers waking up of per destination queues to ensure that we don't attempt
+ to start TLS connections with many hosts all at once, leading to pinned CPU.
+ """
+
+ # The maximum duration in seconds between queuing up a destination and it
+ # being woken up.
+ _MAX_TIME_IN_QUEUE = 30.0
+
+ # The maximum duration in seconds between waking up consecutive destination
+ # queues.
+ _MAX_DELAY = 0.1
+
+ sender: "FederationSender" = attr.ib()
+ clock: Clock = attr.ib()
+ queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict)
+ processing: bool = attr.ib(default=False)
+
+ def add_to_queue(self, destination: str) -> None:
+ """Add a destination to the queue to be woken up."""
+
+ self.queue[destination] = None
+
+ if not self.processing:
+ self._handle()
+
+ @wrap_as_background_process("_PresenceQueue.handle")
+ async def _handle(self) -> None:
+ """Background process to drain the queue."""
+
+ if not self.queue:
+ return
+
+ assert not self.processing
+ self.processing = True
+
+ try:
+ # We start with a delay that should drain the queue quickly enough that
+ # we process all destinations in the queue in _MAX_TIME_IN_QUEUE
+ # seconds.
+ #
+ # We also add an upper bound to the delay, to gracefully handle the
+ # case where the queue only has a few entries in it.
+ current_sleep_seconds = min(
+ self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue)
+ )
+
+ while self.queue:
+ destination, _ = self.queue.popitem(last=False)
+
+ queue = self.sender._get_per_destination_queue(destination)
+
+ if not queue._new_data_to_send:
+ # The per destination queue has already been woken up.
+ continue
+
+ queue.attempt_new_transaction()
+
+ await self.clock.sleep(current_sleep_seconds)
+
+ if not self.queue:
+ break
+
+ # More destinations may have been added to the queue, so we may
+ # need to reduce the delay to ensure everything gets processed
+ # within _MAX_TIME_IN_QUEUE seconds.
+ current_sleep_seconds = min(
+ current_sleep_seconds, self._MAX_TIME_IN_QUEUE / len(self.queue)
+ )
+
+ finally:
+ self.processing = False
+
+
class FederationSender(AbstractFederationSender):
def __init__(self, hs: "HomeServer"):
self.hs = hs
@@ -148,14 +233,14 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
- self._presence_router = None # type: Optional[PresenceRouter]
+ self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config
# map from destination to PerDestinationQueue
- self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
+ self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
@@ -192,9 +277,7 @@ class FederationSender(AbstractFederationSender):
# awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system.
- self._queues_awaiting_rr_flush_by_room = (
- {}
- ) # type: Dict[str, Set[PerDestinationQueue]]
+ self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
self._rr_txn_interval_per_room_ms = (
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
@@ -210,6 +293,8 @@ class FederationSender(AbstractFederationSender):
self._external_cache = hs.get_external_cache()
+ self._presence_queue = _PresenceQueue(self, self.clock)
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -265,7 +350,7 @@ class FederationSender(AbstractFederationSender):
if not event.internal_metadata.should_proactively_send():
return
- destinations = None # type: Optional[Set[str]]
+ destinations: Optional[Set[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
@@ -331,7 +416,7 @@ class FederationSender(AbstractFederationSender):
for event in events:
await handle_event(event)
- events_by_room = {} # type: Dict[str, List[EventBase]]
+ events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@@ -519,7 +604,12 @@ class FederationSender(AbstractFederationSender):
self._instance_name, destination
):
continue
- self._get_per_destination_queue(destination).send_presence(states)
+
+ self._get_per_destination_queue(destination).send_presence(
+ states, start_loop=False
+ )
+
+ self._presence_queue.add_to_queue(destination)
def build_and_send_edu(
self,
@@ -628,7 +718,7 @@ class FederationSender(AbstractFederationSender):
In order to reduce load spikes, adds a delay between each destination.
"""
- last_processed = None # type: Optional[str]
+ last_processed: Optional[str] = None
while True:
destinations_to_wake = (
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3a2efd56ee..c11d1f6d31 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -105,34 +105,34 @@ class PerDestinationQueue:
# catch-up at startup.
# New events will only be sent once this is finished, at which point
# _catching_up is flipped to False.
- self._catching_up = True # type: bool
+ self._catching_up: bool = True
# The stream_ordering of the most recent PDU that was discarded due to
# being in catch-up mode.
- self._catchup_last_skipped = 0 # type: int
+ self._catchup_last_skipped: int = 0
# Cache of the last successfully-transmitted stream ordering for this
# destination (we are the only updater so this is safe)
- self._last_successful_stream_ordering = None # type: Optional[int]
+ self._last_successful_stream_ordering: Optional[int] = None
# a queue of pending PDUs
- self._pending_pdus = [] # type: List[EventBase]
+ self._pending_pdus: List[EventBase] = []
# XXX this is never actually used: see
# https://github.com/matrix-org/synapse/issues/7549
- self._pending_edus = [] # type: List[Edu]
+ self._pending_edus: List[Edu] = []
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu]
+ self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {}
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
- self._pending_presence = {} # type: Dict[str, UserPresenceState]
+ self._pending_presence: Dict[str, UserPresenceState] = {}
# room_id -> receipt_type -> user_id -> receipt_dict
- self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]]
+ self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {}
self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message.
@@ -171,14 +171,24 @@ class PerDestinationQueue:
self.attempt_new_transaction()
- def send_presence(self, states: Iterable[UserPresenceState]) -> None:
- """Add presence updates to the queue. Start the transmission loop if necessary.
+ def send_presence(
+ self, states: Iterable[UserPresenceState], start_loop: bool = True
+ ) -> None:
+ """Add presence updates to the queue.
+
+ Args:
+ states: Presence updates to send
+ start_loop: Whether to start the transmission loop if not already
+ running.
Args:
states: presence to send
"""
self._pending_presence.update({state.user_id: state for state in states})
- self.attempt_new_transaction()
+ self._new_data_to_send = True
+
+ if start_loop:
+ self.attempt_new_transaction()
def queue_read_receipt(self, receipt: ReadReceipt) -> None:
"""Add a RR to the list to be sent. Doesn't start the transmission loop yet
@@ -243,7 +253,7 @@ class PerDestinationQueue:
)
async def _transaction_transmission_loop(self) -> None:
- pending_pdus = [] # type: List[EventBase]
+ pending_pdus: List[EventBase] = []
try:
self.transmission_loop_running = True
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index c9e7c57461..98b1bf77fd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -395,9 +395,9 @@ class TransportLayerClient:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
- data = {
+ data: Dict[str, Any] = {
"include_all_networks": "true" if include_all_networks else "false"
- } # type: Dict[str, Any]
+ }
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@@ -423,9 +423,9 @@ class TransportLayerClient:
else:
path = _create_v1_path("/publicRooms")
- args = {
+ args: Dict[str, Any] = {
"include_all_networks": "true" if include_all_networks else "false"
- } # type: Dict[str, Any]
+ }
if third_party_instance_id:
args["third_party_instance_id"] = (third_party_instance_id,)
if limit:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d37d9565fc..2974d4d0cc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1013,7 +1013,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
- limit = int(content.get("limit", 100)) # type: Optional[int]
+ limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@@ -1095,7 +1095,9 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1110,7 +1112,9 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1131,7 +1135,9 @@ class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1152,7 +1158,9 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1174,7 +1182,9 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
group_id: str,
room_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1192,7 +1202,9 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
group_id: str,
room_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1220,7 +1232,9 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
room_id: str,
config_key: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1243,7 +1257,9 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1264,7 +1280,9 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1288,7 +1306,9 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet):
group_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1354,7 +1374,9 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
group_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1487,7 +1509,9 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
category_id: str,
room_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1523,7 +1547,9 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
category_id: str,
room_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1549,7 +1575,9 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1571,7 +1599,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
group_id: str,
category_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1589,7 +1619,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
group_id: str,
category_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1618,7 +1650,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
group_id: str,
category_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1644,7 +1678,9 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1666,7 +1702,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
group_id: str,
role_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1682,7 +1720,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
group_id: str,
role_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1713,7 +1753,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
group_id: str,
role_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1750,7 +1792,9 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
role_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1784,7 +1828,9 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
role_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1825,7 +1871,9 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
query: Dict[bytes, List[bytes]],
group_id: str,
) -> Tuple[int, JsonDict]:
- requester_user_id = parse_string_from_args(query, "requester_user_id")
+ requester_user_id = parse_string_from_args(
+ query, "requester_user_id", required=True
+ )
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1943,7 +1991,7 @@ class RoomComplexityServlet(BaseFederationServlet):
return 200, complexity
-FEDERATION_SERVLET_CLASSES = (
+FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
FederationStateV1Servlet,
@@ -1971,15 +2019,13 @@ FEDERATION_SERVLET_CLASSES = (
FederationSpaceSummaryServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
-) # type: Tuple[Type[BaseFederationServlet], ...]
+)
-OPENID_SERVLET_CLASSES = (
- OpenIdUserInfo,
-) # type: Tuple[Type[BaseFederationServlet], ...]
+OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,)
-ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...]
+ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,)
-GROUP_SERVER_SERVLET_CLASSES = (
+GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsProfileServlet,
FederationGroupsSummaryServlet,
FederationGroupsRoomsServlet,
@@ -1998,19 +2044,19 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet,
FederationGroupsSettingJoinPolicyServlet,
-) # type: Tuple[Type[BaseFederationServlet], ...]
+)
-GROUP_LOCAL_SERVLET_CLASSES = (
+GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsLocalInviteServlet,
FederationGroupsRemoveLocalUserServlet,
FederationGroupsBulkPublicisedServlet,
-) # type: Tuple[Type[BaseFederationServlet], ...]
+)
-GROUP_ATTESTATION_SERVLET_CLASSES = (
+GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsRenewAttestaionServlet,
-) # type: Tuple[Type[BaseFederationServlet], ...]
+)
DEFAULT_SERVLET_GROUPS = (
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index a06d060ebf..3dc55ab861 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -707,9 +707,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
- local_attestation = self.attestations.create_attestation(
- group_id, user_id
- ) # type: Optional[JsonDict]
+ local_attestation: Optional[
+ JsonDict
+ ] = self.attestations.create_attestation(group_id, user_id)
remote_attestation = content["attestation"]
@@ -868,9 +868,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
remote_attestation, user_id=requester_user_id, group_id=group_id
)
- local_attestation = self.attestations.create_attestation(
- group_id, requester_user_id
- ) # type: Optional[JsonDict]
+ local_attestation: Optional[
+ JsonDict
+ ] = self.attestations.create_attestation(group_id, requester_user_id)
else:
local_attestation = None
remote_attestation = None
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4064a2b859..06d7012bac 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -22,6 +22,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
+ RequestSendFailed,
ShadowBanError,
StoreError,
SynapseError,
@@ -252,12 +253,14 @@ class DirectoryHandler(BaseHandler):
retry_on_dns_fail=False,
ignore_backoff=True,
)
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to fetch alias")
except CodeMessageException as e:
logging.warning("Error retrieving alias")
if e.code == 404:
fed_result = None
else:
- raise
+ raise SynapseError(502, "Failed to fetch alias")
if fed_result and "room_id" in fed_result and "servers" in fed_result:
room_id = fed_result["room_id"]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 991ec9919a..0209aee186 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1414,12 +1414,15 @@ class FederationHandler(BaseHandler):
Invites must be signed by the invitee's server before distribution.
"""
- pdu = await self.federation_client.send_invite(
- destination=target_host,
- room_id=event.room_id,
- event_id=event.event_id,
- pdu=event,
- )
+ try:
+ pdu = await self.federation_client.send_invite(
+ destination=target_host,
+ room_id=event.room_id,
+ event_id=event.event_id,
+ pdu=event,
+ )
+ except RequestSendFailed:
+ raise SynapseError(502, f"Can't connect to server {target_host}")
return pdu
@@ -3031,9 +3034,13 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
- await self.federation_client.forward_third_party_invite(
- destinations, room_id, event_dict
- )
+
+ try:
+ await self.federation_client.forward_third_party_invite(
+ destinations, room_id, event_dict
+ )
+ except (RequestSendFailed, HttpResponseException):
+ raise SynapseError(502, "Failed to forward third party invite")
async def on_exchange_third_party_invite_request(
self, event_dict: JsonDict
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3f783947bd..5ecac0732c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -518,6 +518,9 @@ class EventCreationHandler:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -772,6 +775,7 @@ class EventCreationHandler:
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
outlier: bool = False,
+ historical: bool = False,
depth: Optional[int] = None,
) -> Tuple[EventBase, int]:
"""
@@ -799,6 +803,9 @@ class EventCreationHandler:
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
+ historical: Indicates whether the message is being inserted
+ back in time around some existing events. This is used to skip
+ a few checks and mark the event as backfilled.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -847,6 +854,7 @@ class EventCreationHandler:
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
outlier=outlier,
+ historical=historical,
depth=depth,
)
@@ -1594,11 +1602,13 @@ class EventCreationHandler:
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
- # the event type hasn't changed, so there's no point in re-calculating the
- # auth events.
+ # modules can send new state events, so we re-calculate the auth events just in
+ # case.
+ prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
+
event = await builder.build(
- prev_event_ids=original_event.prev_event_ids(),
- auth_event_ids=original_event.auth_event_ids(),
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=None,
)
# we rebuild the event context, to be on the safe side. If nothing else,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f782d9db32..0059ad0f56 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -30,6 +30,8 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
+ self.event_auth_handler = hs.get_event_auth_handler()
+
self.hs = hs
# We only need to poke the federation sender explicitly if its on the
@@ -59,6 +61,19 @@ class ReceiptsHandler(BaseHandler):
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
for room_id, room_values in content.items():
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ is_in_room = await self.event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring receipt from %s as we're not in the room",
+ origin,
+ )
+ continue
+
for receipt_type, users in room_values.items():
for user_id, user_values in users.items():
if get_domain_from_id(user_id) != origin:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5019e6c1bb..1c2af01abb 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,7 +20,12 @@ import msgpack
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.errors import (
+ Codes,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
@@ -418,14 +423,17 @@ class RoomListHandler(BaseHandler):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
- return await repl_layer.get_public_rooms(
- server_name,
- limit=limit,
- since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
+ try:
+ return await repl_layer.get_public_rooms(
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
+ except (RequestSendFailed, HttpResponseException):
+ raise SynapseError(502, "Failed to fetch room list")
key = (
server_name,
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index b585057ec3..366e6211e5 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
HistoryVisibility,
+ JoinRules,
Membership,
RoomTypes,
)
@@ -150,14 +151,21 @@ class SpaceSummaryHandler:
# The room should only be included in the summary if:
# a. the user is in the room;
# b. the room is world readable; or
- # c. the user is in a space that has been granted access to
- # the room.
+ # c. the user could join the room, e.g. the join rules
+ # are set to public or the user is in a space that
+ # has been granted access to the room.
#
# Note that we know the user is not in the root room (which is
# why the remote call was made in the first place), but the user
# could be in one of the children rooms and we just didn't know
# about the link.
- include_room = room.get("world_readable") is True
+
+ # The API doesn't return the room version so assume that a
+ # join rule of knock is valid.
+ include_room = (
+ room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+ or room.get("world_readable") is True
+ )
# Check if the user is a member of any of the allowed spaces
# from the response.
@@ -420,9 +428,8 @@ class SpaceSummaryHandler:
It should be included if:
- * The requester is joined or invited to the room.
- * The requester can join without an invite (per MSC3083).
- * The origin server has any user that is joined or invited to the room.
+ * The requester is joined or can join the room (per MSC3173).
+ * The origin server has any user that is joined or can join the room.
* The history visibility is set to world readable.
Args:
@@ -441,13 +448,39 @@ class SpaceSummaryHandler:
# If there's no state for the room, it isn't known.
if not state_ids:
+ # The user might have a pending invite for the room.
+ if requester and await self._store.get_invite_for_local_user_in_room(
+ requester, room_id
+ ):
+ return True
+
logger.info("room %s is unknown, omitting from summary", room_id)
return False
room_version = await self._store.get_room_version(room_id)
- # if we have an authenticated requesting user, first check if they are able to view
- # stripped state in the room.
+ # Include the room if it has join rules of public or knock.
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_event_id:
+ join_rules_event = await self._store.get_event(join_rules_event_id)
+ join_rule = join_rules_event.content.get("join_rule")
+ if join_rule == JoinRules.PUBLIC or (
+ room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+ ):
+ return True
+
+ # Include the room if it is peekable.
+ hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_event_id:
+ hist_vis_ev = await self._store.get_event(hist_vis_event_id)
+ hist_vis = hist_vis_ev.content.get("history_visibility")
+ if hist_vis == HistoryVisibility.WORLD_READABLE:
+ return True
+
+ # Otherwise we need to check information specific to the user or server.
+
+ # If we have an authenticated requesting user, check if they are a member
+ # of the room (or can join the room).
if requester:
member_event_id = state_ids.get((EventTypes.Member, requester), None)
@@ -470,9 +503,11 @@ class SpaceSummaryHandler:
return True
# If this is a request over federation, check if the host is in the room or
- # is in one of the spaces specified via the join rules.
+ # has a user who could join the room.
elif origin:
- if await self._event_auth_handler.check_host_in_room(room_id, origin):
+ if await self._event_auth_handler.check_host_in_room(
+ room_id, origin
+ ) or await self._store.is_host_invited(room_id, origin):
return True
# Alternately, if the host has a user in any of the spaces specified
@@ -490,18 +525,10 @@ class SpaceSummaryHandler:
):
return True
- # otherwise, check if the room is peekable
- hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
- if hist_vis_event_id:
- hist_vis_ev = await self._store.get_event(hist_vis_event_id)
- hist_vis = hist_vis_ev.content.get("history_visibility")
- if hist_vis == HistoryVisibility.WORLD_READABLE:
- return True
-
logger.info(
- "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
+ "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary",
room_id,
- requester,
+ requester or origin,
)
return False
@@ -535,6 +562,7 @@ class SpaceSummaryHandler:
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
+ "join_rules": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4e45d1da57..814d08efcb 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -45,7 +45,6 @@ class StatsHandler:
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.stats_bucket_size = hs.config.stats_bucket_size
self.stats_enabled = hs.config.stats_enabled
@@ -106,20 +105,6 @@ class StatsHandler:
room_deltas = {}
user_deltas = {}
- # Then count deltas for total_events and total_event_bytes.
- (
- room_count,
- user_count,
- ) = await self.store.get_changes_room_total_events_and_bytes(
- self.pos, max_pos
- )
-
- for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, Counter()).update(fields)
-
- for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, Counter()).update(fields)
-
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -181,12 +166,10 @@ class StatsHandler:
event_content = {} # type: JsonDict
- sender = None
if event_id is not None:
event = await self.store.get_event(event_id, allow_none=True)
if event:
event_content = event.content or {}
- sender = event.sender
# All the values in this dict are deltas (RELATIVE changes)
room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
@@ -244,12 +227,6 @@ class StatsHandler:
room_stats_delta["joined_members"] += 1
elif membership == Membership.INVITE:
room_stats_delta["invited_members"] += 1
-
- if sender and self.is_mine_id(sender):
- user_to_stats_deltas.setdefault(sender, Counter())[
- "invites_sent"
- ] += 1
-
elif membership == Membership.LEAVE:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
@@ -279,10 +256,6 @@ class StatsHandler:
room_state["is_federatable"] = (
event_content.get("m.federate", True) is True
)
- if sender and self.is_mine_id(sender):
- user_to_stats_deltas.setdefault(sender, Counter())[
- "rooms_created"
- ] += 1
elif typ == EventTypes.JoinRules:
room_state["join_rules"] = event_content.get("join_rule")
elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e22393adc4..c0a8364755 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -208,6 +208,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
+ self.event_auth_handler = hs.get_event_auth_handler()
self.hs = hs
@@ -326,6 +327,19 @@ class TypingWriterHandler(FollowerTypingHandler):
room_id = content["room_id"]
user_id = content["user_id"]
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ is_in_room = await self.event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring typing update from %s as we're not in the room",
+ origin,
+ )
+ return
+
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index ed4671b7de..578fc48ef4 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes:
return hostname
# no Host header, use the address/port that the request arrived on
- host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
+ host: Union[address.IPv4Address, address.IPv6Address] = request.getHost()
hostname = host.host.encode("ascii")
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 1ca6624fd5..2ac76b15c2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -160,7 +160,7 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
- addresses = [] # type: List[IAddress]
+ addresses: List[IAddress] = []
def _callback() -> None:
has_bad_ip = False
@@ -333,9 +333,9 @@ class SimpleHttpClient:
if self._ip_blacklist:
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
- self.reactor = BlacklistingReactorWrapper(
+ self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
- ) # type: ISynapseReactor
+ )
else:
self.reactor = hs.get_reactor()
@@ -349,14 +349,14 @@ class SimpleHttpClient:
pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
pool.cachedConnectionTimeout = 2 * 60
- self.agent = ProxyAgent(
+ self.agent: IAgent = ProxyAgent(
self.reactor,
hs.get_reactor(),
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- ) # type: IAgent
+ )
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -411,7 +411,7 @@ class SimpleHttpClient:
cooperator=self._cooperator,
)
- request_deferred = treq.request(
+ request_deferred: defer.Deferred = treq.request(
method,
uri,
agent=self.agent,
@@ -421,7 +421,7 @@ class SimpleHttpClient:
# response bodies.
unbuffered=True,
**self._extra_treq_args,
- ) # type: defer.Deferred
+ )
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
@@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
- transport = None # type: Optional[ITCPTransport]
+ transport: Optional[ITCPTransport] = None
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
- transport = None # type: Optional[ITCPTransport]
+ transport: Optional[ITCPTransport] = None
def __init__(
self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 20d39a4ea6..43f2140429 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__)
-_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
- "had-valid-well-known"
-) # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
@attr.s(slots=True, frozen=True)
@@ -130,9 +128,10 @@ class WellKnownResolver:
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
- result, cache_period = await self._fetch_well_known(
- server_name
- ) # type: Optional[bytes], float
+ result: Optional[bytes]
+ cache_period: float
+
+ result, cache_period = await self._fetch_well_known(server_name)
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b8849c0150..2efa15bf04 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -43,6 +43,7 @@ from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
+from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -105,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
the parsed data.
"""
- CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore
+ CONTENT_TYPE: str = abc.abstractproperty() # type: ignore
"""The expected content type of the response, e.g. `application/json`. If
the content type doesn't match we fail the request.
"""
@@ -262,6 +263,15 @@ async def _handle_response(
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
+ except ResponseFailed as e:
+ logger.warning(
+ "{%s} [%s] Failed to read response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response %s %s: %s",
@@ -317,11 +327,11 @@ class MatrixFederationHttpClient:
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
- self.reactor = BlacklistingReactorWrapper(
+ self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
hs.get_reactor(),
hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist,
- ) # type: ISynapseReactor
+ )
user_agent = hs.version_string
if hs.config.user_agent_suffix:
@@ -494,7 +504,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
- headers_dict = {} # type: Dict[bytes, List[bytes]]
+ headers_dict: Dict[bytes, List[bytes]] = {}
opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -523,9 +533,9 @@ class MatrixFederationHttpClient:
destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
- producer = QuieterFileBodyProducer(
+ producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
- ) # type: Optional[IBodyProducer]
+ )
else:
producer = None
auth_headers = self.build_auth_headers(
@@ -1137,6 +1147,24 @@ class MatrixFederationHttpClient:
msg,
)
raise SynapseError(502, msg, Codes.TOO_LARGE)
+ except defer.TimeoutError as e:
+ logger.warning(
+ "{%s} [%s] Timed out reading response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
+ except ResponseFailed as e:
+ logger.warning(
+ "{%s} [%s] Failed to read response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7dfae8b786..7a6a1717de 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase):
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
- # Parse credentials from https proxy connection string if present
+ # Parse credentials from http and https proxy connection string if present
+ self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
self.http_proxy_endpoint = _http_proxy_endpoint(
@@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase):
and self.http_proxy_endpoint
and not should_skip_proxy
):
+ # Determine whether we need to set Proxy-Authorization headers
+ if self.http_proxy_creds:
+ # Set a Proxy-Authorization header
+ if headers is None:
+ headers = Headers()
+ headers.addRawHeader(
+ b"Proxy-Authorization",
+ self.http_proxy_creds.as_proxy_authorization_value(),
+ )
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index efbc6d5b25..b79fa722e9 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
if f.check(SynapseError):
# mypy doesn't understand that f.check asserts the type.
- exc = f.value # type: SynapseError # type: ignore
+ exc: SynapseError = f.value # type: ignore
error_code = exc.code
error_dict = exc.error_dict()
@@ -132,7 +132,7 @@ def return_html_error(
"""
if f.check(CodeMessageException):
# mypy doesn't understand that f.check asserts the type.
- cme = f.value # type: CodeMessageException # type: ignore
+ cme: CodeMessageException = f.value # type: ignore
code = cme.code
msg = cme.msg
@@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource):
key word arguments to pass to the callback
"""
# At this point the path must be bytes.
- request_path_bytes = request.path # type: bytes # type: ignore
+ request_path_bytes: bytes = request.path # type: ignore
request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
request_method = request.method
@@ -557,7 +557,7 @@ class _ByteProducer:
request: Request,
iterator: Iterator[bytes],
):
- self._request = request # type: Optional[Request]
+ self._request: Optional[Request] = request
self._iterator = iterator
self._paused = False
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6ba2ce1e53..04560fb589 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -205,7 +205,7 @@ def parse_string(
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
- args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
return parse_string_from_args(
args,
name,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 40754b7bea..3b0a38124e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -64,16 +64,16 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
- self.site = channel.site # type: SynapseSite
+ self.site: SynapseSite = channel.site
self._channel = channel # this is used by the tests
self.start_time = 0.0
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
- self._requester = None # type: Optional[Union[Requester, str]]
+ self._requester: Optional[Union[Requester, str]] = None
# we can't yet create the logcontext, as we don't know the method.
- self.logcontext = None # type: Optional[LoggingContext]
+ self.logcontext: Optional[LoggingContext] = None
global _next_request_seq
self.request_seq = _next_request_seq
@@ -152,7 +152,7 @@ class SynapseRequest(Request):
Returns:
The redacted URI as a string.
"""
- uri = self.uri # type: Union[bytes, str]
+ uri: Union[bytes, str] = self.uri
if isinstance(uri, bytes):
uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
@@ -167,7 +167,7 @@ class SynapseRequest(Request):
Returns:
The request method as a string.
"""
- method = self.method # type: Union[bytes, str]
+ method: Union[bytes, str] = self.method
if isinstance(method, bytes):
return self.method.decode("ascii")
return method
@@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest):
"""
# the client IP and ssl flag, as extracted from the headers.
- _forwarded_for = None # type: Optional[_XForwardedForAddress]
- _forwarded_https = False # type: bool
+ _forwarded_for: "Optional[_XForwardedForAddress]" = None
+ _forwarded_https: bool = False
def requestReceived(self, command, path, version):
# this method is called by the Channel once the full request has been
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index c515690b38..8202d0494d 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -110,9 +110,9 @@ class RemoteHandler(logging.Handler):
self.port = port
self.maximum_buffer = maximum_buffer
- self._buffer = deque() # type: Deque[logging.LogRecord]
- self._connection_waiter = None # type: Optional[Deferred]
- self._producer = None # type: Optional[LogProducer]
+ self._buffer: Deque[logging.LogRecord] = deque()
+ self._connection_waiter: Optional[Deferred] = None
+ self._producer: Optional[LogProducer] = None
# Connect without DNS lookups if it's a direct IP.
if _reactor is None:
@@ -123,9 +123,9 @@ class RemoteHandler(logging.Handler):
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(
+ endpoint: IStreamClientEndpoint = TCP4ClientEndpoint(
_reactor, self.host, self.port
- ) # type: IStreamClientEndpoint
+ )
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
@@ -165,7 +165,7 @@ class RemoteHandler(logging.Handler):
def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more
# generic ITransport.
- transport = result.transport # type: Connection # type: ignore
+ transport: Connection = result.transport # type: ignore
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
@@ -188,7 +188,7 @@ class RemoteHandler(logging.Handler):
self._producer.resumeProducing()
self._connection_waiter = None
- deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred: Deferred = self._service.whenConnected(failAfterFailures=1)
deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index c7a971a9d6..b9933a1528 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -63,7 +63,7 @@ def parse_drain_configs(
DrainType.CONSOLE_JSON,
DrainType.FILE_JSON,
):
- formatter = "json" # type: Optional[str]
+ formatter: Optional[str] = "json"
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
DrainType.NETWORK_JSON_TERSE,
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 7fc11a9ac2..18ac507802 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -113,13 +113,13 @@ class ContextResourceUsage:
self.reset()
else:
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
- self.ru_utime = copy_from.ru_utime # type: float
- self.ru_stime = copy_from.ru_stime # type: float
- self.db_txn_count = copy_from.db_txn_count # type: int
+ self.ru_utime: float = copy_from.ru_utime
+ self.ru_stime: float = copy_from.ru_stime
+ self.db_txn_count: int = copy_from.db_txn_count
- self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
- self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
- self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
+ self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec
+ self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec
+ self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count
def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self)
@@ -289,12 +289,12 @@ class LoggingContext:
# The thread resource usage when the logcontext became active. None
# if the context is not currently active.
- self.usage_start = None # type: Optional[resource._RUsage]
+ self.usage_start: Optional[resource._RUsage] = None
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
- self.scope = None # type: Optional[_LogContextScope]
+ self.scope: Optional["_LogContextScope"] = None
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the the thing that created the context thinks it should
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 140ed711e3..185844f188 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -251,7 +251,7 @@ try:
except Exception:
logger.exception("Failed to report span")
- RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]]
+ RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter
except ImportError:
RustReporter = None
@@ -286,7 +286,7 @@ class SynapseBaggage:
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
-_homeserver_whitelist = None # type: Optional[Pattern[str]]
+_homeserver_whitelist: Optional[Pattern[str]] = None
# Util methods
@@ -662,7 +662,7 @@ def inject_header_dict(
span = opentracing.tracer.active_span
- carrier = {} # type: Dict[str, str]
+ carrier: Dict[str, str] = {}
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
- carrier = {} # type: Dict[str, str]
+ carrier: Dict[str, str] = {}
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@@ -718,7 +718,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
- carrier = {} # type: Dict[str, str]
+ carrier: Dict[str, str] = {}
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index fef2846669..f237b8a236 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
running_on_pypy = platform.python_implementation() == "PyPy"
-all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]]
+all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -130,7 +130,7 @@ class InFlightGauge:
)
# Counts number of in flight blocks for a given set of label values
- self._registrations = {} # type: Dict
+ self._registrations: Dict = {}
# Protects access to _registrations
self._lock = threading.Lock()
@@ -248,7 +248,7 @@ class GaugeBucketCollector:
# We initially set this to None. We won't report metrics until
# this has been initialised after a successful data update
- self._metric = None # type: Optional[GaugeHistogramMetricFamily]
+ self._metric: Optional[GaugeHistogramMetricFamily] = None
registry.register(self)
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 8002be56e0..7e49d0d02c 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -125,7 +125,7 @@ def generate_latest(registry, emit_help=False):
)
output.append("# TYPE {0} {1}\n".format(mname, mtype))
- om_samples = {} # type: Dict[str, List[str]]
+ om_samples: Dict[str, List[str]] = {}
for s in metric.samples:
for suffix in ["_created", "_gsum", "_gcount"]:
if s.name == metric.name + suffix:
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index de96ca0821..4455fa71a8 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -93,7 +93,7 @@ _background_process_db_sched_duration = Counter(
# map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.)
-_background_process_counts = {} # type: Dict[str, int]
+_background_process_counts: Dict[str, int] = {}
# Set of all running background processes that became active active since the
# last time metrics were scraped (i.e. background processes that performed some
@@ -103,7 +103,7 @@ _background_process_counts = {} # type: Dict[str, int]
# background processes stacking up behind a lock or linearizer, where we then
# only need to iterate over and update metrics for the process that have
# actually been active and can ignore the idle ones.
-_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess]
+_background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set()
# A lock that covers the above set and dict
_bg_metrics_lock = threading.Lock()
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 721c45abac..308f045700 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -54,7 +54,7 @@ class ModuleApi:
self._state = hs.get_state_handler()
# We expose these as properties below in order to attach a helpful docstring.
- self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
+ self._http_client: SimpleHttpClient = hs.get_simple_http_client()
self._public_room_list_manager = PublicRoomListManager(hs)
self._spam_checker = hs.get_spam_checker()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 3c3cc47631..c5fbebc17d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -203,21 +203,21 @@ class Notifier:
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"):
- self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
- self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]
+ self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
+ self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
- self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
+ self.pending_new_room_events: List[_PendingRoomEventEntry] = []
# Called when there are new things to stream over replication
- self.replication_callbacks = [] # type: List[Callable[[], None]]
+ self.replication_callbacks: List[Callable[[], None]] = []
# Called when remote servers have come back online after having been
# down.
- self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]]
+ self.remote_server_up_callbacks: List[Callable[[str], None]] = []
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
@@ -237,7 +237,7 @@ class Notifier:
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
- all_user_streams = set() # type: Set[_NotifierUserStream]
+ all_user_streams: Set[_NotifierUserStream] = set()
for streams in list(self.room_to_user_streams.values()):
all_user_streams |= streams
@@ -329,8 +329,8 @@ class Notifier:
pending = self.pending_new_room_events
self.pending_new_room_events = []
- users = set() # type: Set[UserID]
- rooms = set() # type: Set[str]
+ users: Set[UserID] = set()
+ rooms: Set[str] = set()
for entry in pending:
if entry.event_pos.persisted_after(max_room_stream_token):
@@ -580,7 +580,7 @@ class Notifier:
if after_token == before_token:
return EventStreamResult([], (from_token, from_token))
- events = [] # type: List[EventBase]
+ events: List[EventBase] = []
end_token = from_token
for name, source in self.event_sources.sources.items():
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 669ea462e2..c337e530d3 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
- actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
+ actions_by_user: Dict[str, List[Union[dict, str]]] = {}
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -207,7 +207,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
- condition_cache = {} # type: Dict[str, bool]
+ condition_cache: Dict[str, bool] = {}
# If the event is not a state event check if any users ignore the sender.
if not event.is_state():
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 2ee0ccd58a..1fc9716a34 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {
+ rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
"global": {},
"device": {},
- } # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
+ }
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 99a18874d1..e08e125cb8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -66,8 +66,8 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
- self.timed_call = None # type: Optional[IDelayedCall]
- self.throttle_params = {} # type: Dict[str, ThrottleParams]
+ self.timed_call: Optional[IDelayedCall] = None
+ self.throttle_params: Dict[str, ThrottleParams] = {}
self._inited = False
self._is_processing = False
@@ -168,7 +168,7 @@ class EmailPusher(Pusher):
)
)
- soonest_due_at = None # type: Optional[int]
+ soonest_due_at: Optional[int] = None
if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 708ddbd78d..250a4861b0 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -71,7 +71,7 @@ class HttpPusher(Pusher):
self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since
- self.timed_call = None # type: Optional[IDelayedCall]
+ self.timed_call: Optional[IDelayedCall] = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5f9ea5003a..7be5fe1e9b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
- self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig
+ self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
logger.info("Created Mailer for app_name %s" % app_name)
@@ -230,7 +230,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions]
)
- notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
+ notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -356,13 +356,13 @@ class Mailer:
room_name = await calculate_room_name(self.store, room_state_ids, user_id)
- room_vars = {
+ room_vars: Dict[str, Any] = {
"title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
"link": self._make_room_link(room_id),
- } # type: Dict[str, Any]
+ }
if not is_invite:
for n in notifs:
@@ -460,9 +460,9 @@ class Mailer:
type_state_key = ("m.room.member", event.sender)
sender_state_event_id = room_state_ids.get(type_state_key)
if sender_state_event_id:
- sender_state_event = await self.store.get_event(
+ sender_state_event: Optional[EventBase] = await self.store.get_event(
sender_state_event_id
- ) # type: Optional[EventBase]
+ )
else:
# Attempt to check the historical state for the room.
historical_state = await self.state_store.get_state_for_event(
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 412941393f..0510c1cbd5 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str:
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
- ret = {} # type: Dict[str, Dict[str, str]]
+ ret: Dict[str, Dict[str, str]] = {}
for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v
return ret
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 98b90a4f51..7a8dc63976 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -195,9 +195,9 @@ class PushRuleEvaluatorForEvent:
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(
+regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
50000, "regex_push_cache"
-) # type: LruCache[Tuple[str, bool, bool], Pattern]
+)
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index c51938b8cf..021275437c 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -31,13 +31,13 @@ class PusherFactory:
self.hs = hs
self.config = hs.config
- self.pusher_types = {
+ self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = {
"http": HttpPusher
- } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
+ }
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
- self.mailers = {} # type: Dict[str, Mailer]
+ self.mailers: Dict[str, Mailer] = {}
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 579fcdf472..2519ad76db 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -87,7 +87,7 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher
- self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
+ self.pushers: Dict[str, Dict[str, Pusher]] = {}
def start(self) -> None:
"""Starts the pushers off in a background process."""
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 271c17c226..cdcbdd772b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -115,7 +115,7 @@ CONDITIONAL_REQUIREMENTS = {
"cache_memory": ["pympler"],
}
-ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
+ALL_OPTIONAL_REQUIREMENTS: Set[str] = set()
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
@@ -193,7 +193,7 @@ def check_requirements(for_feature=None):
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
- OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
+ OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), [])
for dependency in OPTS:
try:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f13a7c23b4..25589b0042 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
is received.
"""
- NAME = abc.abstractproperty() # type: str # type: ignore
- PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
+ NAME: str = abc.abstractproperty() # type: ignore
+ PATH_ARGS: Tuple[str, ...] = abc.abstractproperty() # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
def __init__(self, hs: "HomeServer"):
if self.CACHE:
- self.response_cache = ResponseCache(
+ self.response_cache: ResponseCache[str] = ResponseCache(
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
- ) # type: ResponseCache[str]
+ )
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
@@ -232,7 +232,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
- headers = {} # type: Dict[bytes, List[bytes]]
+ headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index faa99387a7..e460dd85cd 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = MultiWriterIdGenerator(
+ self._cache_id_gen: Optional[
+ MultiWriterIdGenerator
+ ] = MultiWriterIdGenerator(
db_conn,
database,
stream_name="caches",
@@ -41,7 +43,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
],
sequence_name="cache_invalidation_stream_seq",
writers=[],
- ) # type: Optional[MultiWriterIdGenerator]
+ )
else:
self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 13ed87adc4..436d39c320 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.client_ip_last_seen = LruCache(
+ self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
cache_name="client_ip_last_seen", max_size=50000
- ) # type: LruCache[tuple, int]
+ )
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 62d7809175..9d4859798b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -121,13 +121,13 @@ class ReplicationDataHandler:
self._pusher_pool = hs.get_pusherpool()
self._presence_handler = hs.get_presence_handler()
- self.send_handler = None # type: Optional[FederationSenderHandler]
+ self.send_handler: Optional[FederationSenderHandler] = None
if hs.should_send_federation():
self.send_handler = FederationSenderHandler(hs)
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
+ self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -173,7 +173,7 @@ class ReplicationDataHandler:
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == DeviceListsStream.NAME:
- all_room_ids = set() # type: Set[str]
+ all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
@@ -201,7 +201,7 @@ class ReplicationDataHandler:
if row.data.rejected:
continue
- extra_users = () # type: Tuple[UserID, ...]
+ extra_users: Tuple[UserID, ...] = ()
if row.data.type == EventTypes.Member and row.data.state_key:
extra_users = (UserID.from_string(row.data.state_key),)
@@ -348,7 +348,7 @@ class FederationSenderHandler:
# Stores the latest position in the federation stream we've gotten up
# to. This is always set before we use it.
- self.federation_position = None # type: Optional[int]
+ self.federation_position: Optional[int] = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 505d450e19..1311b013da 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta):
A full command line on the wire is constructed from `NAME + " " + to_line()`
"""
- NAME = None # type: str
+ NAME: str
@classmethod
@abc.abstractmethod
@@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"
-_COMMANDS = (
+_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
PositionCommand,
@@ -393,7 +393,7 @@ _COMMANDS = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
-) # type: Tuple[Type[Command], ...]
+)
# Map of command name to command type.
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2ad7a200bb..eae4515363 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -105,12 +105,12 @@ class ReplicationCommandHandler:
hs.get_instance_name() in hs.config.worker.writers.presence
)
- self._streams = {
+ self._streams: Dict[str, Stream] = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
- } # type: Dict[str, Stream]
+ }
# List of streams that this instance is the source of
- self._streams_to_replicate = [] # type: List[Stream]
+ self._streams_to_replicate: List[Stream] = []
for stream in self._streams.values():
if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
@@ -180,14 +180,14 @@ class ReplicationCommandHandler:
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
- self._pending_batches = {} # type: Dict[str, List[Any]]
+ self._pending_batches: Dict[str, List[Any]] = {}
# The factory used to create connections.
- self._factory = None # type: Optional[ReconnectingClientFactory]
+ self._factory: Optional[ReconnectingClientFactory] = None
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[IReplicationConnection]
+ self._connections: List[IReplicationConnection] = []
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -200,7 +200,7 @@ class ReplicationCommandHandler:
# them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_queue
- self._processing_streams = set() # type: Set[str]
+ self._processing_streams: Set[str] = set()
# for each stream, a queue of commands that are awaiting processing, and the
# connection that they arrived on.
@@ -210,7 +210,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
+ self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
LaterGauge(
"synapse_replication_tcp_command_queue",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 6e3705364f..8c80153ab6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter(
# A list of all connected protocols. This allows us to send metrics about the
# connections.
-connected_connections = [] # type: List[BaseReplicationStreamProtocol]
+connected_connections: "List[BaseReplicationStreamProtocol]" = []
logger = logging.getLogger(__name__)
@@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
- transport = None # type: Connection
+ transport: Connection
delimiter = b"\n"
# Valid commands we expect to receive
- VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+ VALID_INBOUND_COMMANDS: Collection[str] = []
# Valid commands we can send
- VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
+ VALID_OUTBOUND_COMMANDS: Collection[str] = []
max_line_buffer = 10000
@@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
# When we requested the connection be closed
- self.time_we_closed = None # type: Optional[int]
+ self.time_we_closed: Optional[int] = None
self.received_ping = False # Have we received a ping from the other side
@@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = [] # type: List[Command]
+ self.pending_commands: List[Command] = []
# The LoopingCall for sending pings.
- self._send_ping_loop = None # type: Optional[task.LoopingCall]
+ self._send_ping_loop: Optional[task.LoopingCall] = None
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 6a2c2655e4..8c0df627c8 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]):
it.
"""
- constant = attr.ib() # type: V
+ constant: V = attr.ib()
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
return self.constant
@@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
commands.
"""
- synapse_handler = None # type: ReplicationCommandHandler
- synapse_stream_name = None # type: str
- synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler: "ReplicationCommandHandler"
+ synapse_stream_name: str
+ synapse_outbound_redis_connection: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b03824925a..3716c41bea 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -85,9 +85,9 @@ class Stream:
time it was called.
"""
- NAME = None # type: str # The name of the stream
+ NAME: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
- ROW_TYPE = None # type: Any
+ ROW_TYPE: Any = None
@classmethod
def parse_row(cls, row: StreamRow):
@@ -283,9 +283,7 @@ class PresenceStream(Stream):
assert isinstance(presence_handler, PresenceHandler)
- update_function = (
- presence_handler.get_all_presence_updates
- ) # type: UpdateFunction
+ update_function: UpdateFunction = presence_handler.get_all_presence_updates
else:
# Query presence writer process
update_function = make_http_update_function(hs, self.NAME)
@@ -334,9 +332,9 @@ class TypingStream(Stream):
if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler
typing_writer_handler = hs.get_typing_writer_handler()
- update_function = (
- typing_writer_handler.get_all_typing_updates
- ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ update_function: Callable[
+ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+ ] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e7e87bac92..a030e9299e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -65,7 +65,7 @@ class BaseEventsStreamRow:
"""
# Unique string that ids the type. Must be overridden in sub classes.
- TypeId = None # type: str
+ TypeId: str
@classmethod
def from_data(cls, data):
@@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
-_EventRows = (
+_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
-) # type: Tuple[Type[BaseEventsStreamRow], ...]
+)
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@@ -157,9 +157,9 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows = await self._store.get_all_new_forward_event_rows(
+ event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
- ) # type: List[Tuple]
+ )
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
# that we know it is safe to just take upper_limit = event_rows[-1][0].
@@ -172,7 +172,7 @@ class EventsStream(Stream):
if len(event_rows) == target_row_count:
limited = True
- upper_limit = event_rows[-1][0] # type: int
+ upper_limit: int = event_rows[-1][0]
else:
limited = False
upper_limit = current_token
@@ -191,30 +191,30 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
- ) # type: List[Tuple]
+ )
# we now need to turn the raw database rows returned into tuples suitable
# for the replication protocol (basically, we add an identifier to
# distinguish the row type). At the same time, we can limit the event_rows
# to the max stream_id from state_rows.
- event_updates = (
+ event_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in event_rows
if stream_id <= upper_limit
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
- state_updates = (
+ state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
- ex_outliers_updates = (
+ ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in ex_outliers_rows
- ) # type: Iterable[Tuple[int, Tuple]]
+ )
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 096a85d363..c445af9bd9 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -51,9 +51,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance(
federation_sender.get_current_token
)
- update_function = (
- federation_sender.get_replication_rows
- ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ update_function: Callable[
+ [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+ ] = federation_sender.get_replication_rows
elif hs.should_send_federation():
# federation sender: Query master process
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f0cddd2d2c..3c51a742bf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -462,6 +462,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -500,7 +501,13 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
admin_user_id = None
for admin_user in reversed(admin_users):
- if room_state.get((EventTypes.Member, admin_user)):
+ (
+ current_membership_type,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ admin_user, room_id
+ )
+ if current_membership_type == "join":
admin_user_id = admin_user
break
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 92ebe838fd..ebf4e32230 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.filtering import Filter
+from synapse.appservice import ApplicationService
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
RestServlet,
@@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
+ Requester,
RoomAlias,
RoomID,
StreamToken,
ThirdPartyInstanceID,
UserID,
+ create_requester,
)
from synapse.util import json_decoder
from synapse.util.stringutils import parse_and_validate_server_name, random_string
@@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
@@ -349,6 +352,54 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
return depth
+ def _create_insertion_event_dict(
+ self, sender: str, room_id: str, origin_server_ts: int
+ ):
+ """Creates an event dict for an "insertion" event with the proper fields
+ and a random chunk ID.
+
+ Args:
+ sender: The event author MXID
+ room_id: The room ID that the event belongs to
+ origin_server_ts: Timestamp when the event was sent
+
+ Returns:
+ Tuple of event ID and stream ordering position
+ """
+
+ next_chunk_id = random_string(8)
+ insertion_event = {
+ "type": EventTypes.MSC2716_INSERTION,
+ "sender": sender,
+ "room_id": room_id,
+ "content": {
+ EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+ EventContentFields.MSC2716_HISTORICAL: True,
+ },
+ "origin_server_ts": origin_server_ts,
+ }
+
+ return insertion_event
+
+ async def _create_requester_for_user_id_from_app_service(
+ self, user_id: str, app_service: ApplicationService
+ ) -> Requester:
+ """Creates a new requester for the given user_id
+ and validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ user_id: The author MXID that the app service is controlling
+ app_service: The app service that controls the user
+
+ Returns:
+ Requester object
+ """
+
+ await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+ return create_requester(user_id, app_service=app_service)
+
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
@@ -414,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
if event_dict["type"] == EventTypes.Member:
membership = event_dict["content"].get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ state_event["sender"], requester.app_service
+ ),
target=UserID.from_string(event_dict["state_key"]),
room_id=room_id,
action=membership,
@@ -434,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ state_event["sender"], requester.app_service
+ ),
event_dict,
outlier=True,
prev_event_ids=[fake_prev_event_id],
@@ -449,37 +504,73 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
events_to_create = body["events"]
- # If provided, connect the chunk to the last insertion point
- # The chunk ID passed in comes from the chunk_id in the
- # "insertion" event from the previous chunk.
+ prev_event_ids = prev_events_from_query
+ inherited_depth = await self._inherit_depth_from_prev_ids(
+ prev_events_from_query
+ )
+
+ # Figure out which chunk to connect to. If they passed in
+ # chunk_id_from_query let's use it. The chunk ID passed in comes
+ # from the chunk_id in the "insertion" event from the previous chunk.
+ last_event_in_chunk = events_to_create[-1]
+ chunk_id_to_connect_to = chunk_id_from_query
+ base_insertion_event = None
if chunk_id_from_query:
- last_event_in_chunk = events_to_create[-1]
- last_event_in_chunk["content"][
- EventContentFields.MSC2716_CHUNK_ID
- ] = chunk_id_from_query
+ # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+ pass
+ # Otherwise, create an insertion event to act as a starting point.
+ #
+ # We don't always have an insertion event to start hanging more history
+ # off of (ideally there would be one in the main DAG, but that's not the
+ # case if we're wanting to add history to e.g. existing rooms without
+ # an insertion event), in which case we just create a new insertion event
+ # that can then get pointed to by a "marker" event later.
+ else:
+ base_insertion_event_dict = self._create_insertion_event_dict(
+ sender=requester.user.to_string(),
+ room_id=room_id,
+ origin_server_ts=last_event_in_chunk["origin_server_ts"],
+ )
+ base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
+
+ (
+ base_insertion_event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ await self._create_requester_for_user_id_from_app_service(
+ base_insertion_event_dict["sender"],
+ requester.app_service,
+ ),
+ base_insertion_event_dict,
+ prev_event_ids=base_insertion_event_dict.get("prev_events"),
+ auth_event_ids=auth_event_ids,
+ historical=True,
+ depth=inherited_depth,
+ )
+
+ chunk_id_to_connect_to = base_insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_CHUNK_ID
+ ]
- # Add an "insertion" event to the start of each chunk (next to the oldest
+ # Connect this current chunk to the insertion event from the previous chunk
+ last_event_in_chunk["content"][
+ EventContentFields.MSC2716_CHUNK_ID
+ ] = chunk_id_to_connect_to
+
+ # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
# event in the chunk) so the next chunk can be connected to this one.
- next_chunk_id = random_string(64)
- insertion_event = {
- "type": EventTypes.MSC2716_INSERTION,
- "sender": requester.user.to_string(),
- "content": {
- EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
- EventContentFields.MSC2716_HISTORICAL: True,
- },
+ insertion_event = self._create_insertion_event_dict(
+ sender=requester.user.to_string(),
+ room_id=room_id,
# Since the insertion event is put at the start of the chunk,
- # where the oldest event is, copy the origin_server_ts from
+ # where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
- "origin_server_ts": events_to_create[0]["origin_server_ts"],
- }
+ origin_server_ts=events_to_create[0]["origin_server_ts"],
+ )
# Prepend the insertion event to the start of the chunk
events_to_create = [insertion_event] + events_to_create
- inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query)
-
event_ids = []
- prev_event_ids = prev_events_from_query
events_to_persist = []
for ev in events_to_create:
assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
@@ -498,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
}
event, context = await self.event_creation_handler.create_event(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ ev["sender"], requester.app_service
+ ),
event_dict,
prev_event_ids=event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
@@ -528,15 +621,23 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
# where topological_ordering is just depth.
for (event, context) in reversed(events_to_persist):
ev = await self.event_creation_handler.handle_new_client_event(
- requester=requester,
+ await self._create_requester_for_user_id_from_app_service(
+ event["sender"], requester.app_service
+ ),
event=event,
context=context,
)
+ # Add the base_insertion_event to the bottom of the list we return
+ if base_insertion_event is not None:
+ event_ids.append(base_insertion_event.event_id)
+
return 200, {
"state_events": auth_event_ids,
"events": event_ids,
- "next_chunk_id": next_chunk_id,
+ "next_chunk_id": insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_CHUNK_ID
+ ],
}
def on_GET(self, request, room_id):
diff --git a/synapse/server.py b/synapse/server.py
index 2c27d2a7e8..095dba9ad0 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -247,15 +247,15 @@ class HomeServer(metaclass=abc.ABCMeta):
# the key we use to sign events and requests
self.signing_key = config.key.signing_key[0]
self.config = config
- self._listening_services = [] # type: List[twisted.internet.tcp.Port]
- self.start_time = None # type: Optional[int]
+ self._listening_services: List[twisted.internet.tcp.Port] = []
+ self.start_time: Optional[int] = None
self._instance_id = random_string(5)
self._instance_name = config.worker.instance_name
self.version_string = version_string
- self.datastores = None # type: Optional[Databases]
+ self.datastores: Optional[Databases] = None
self._module_web_resources: Dict[str, IResource] = {}
self._module_web_resources_consumed = False
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index e65f6f88fe..4e0f814035 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -34,7 +34,7 @@ class ConsentServerNotices:
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastore()
- self._users_in_progress = set() # type: Set[str]
+ self._users_in_progress: Set[str] = set()
self._current_consent_version = hs.config.user_consent_version
self._server_notice_content = hs.config.user_consent_server_notice_content
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index e4b0bc5c72..073b0d754f 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -205,7 +205,7 @@ class ResourceLimitsServerNotices:
# The user has yet to join the server notices room
pass
- referenced_events = [] # type: List[str]
+ referenced_events: List[str] = []
if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", []))
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index c875b15b32..cdf0973d05 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._server_notices = (
+ self._server_notices: Iterable[
+ Union[ConsentServerNotices, ResourceLimitsServerNotices]
+ ] = (
ConsentServerNotices(hs),
ResourceLimitsServerNotices(hs),
- ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]]
+ )
async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1770f620e..6223daf522 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -309,9 +309,9 @@ class StateHandler:
if old_state:
# if we're given the state before the event, then we use that
- state_ids_before_event = {
+ state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
- } # type: StateMap[str]
+ }
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
@@ -513,23 +513,25 @@ class StateResolutionHandler:
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
# dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = ExpiringCache(
+ self._state_cache: ExpiringCache[
+ FrozenSet[int], _StateCacheEntry
+ ] = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=100000,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
- ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
+ )
#
# stuff for tracking time spent on state-res by room
#
# tracks the amount of work done on state res per room
- self._state_res_metrics = defaultdict(
+ self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict(
_StateResMetrics
- ) # type: DefaultDict[str, _StateResMetrics]
+ )
self.clock.looping_call(self._report_metrics, 120 * 1000)
@@ -700,9 +702,9 @@ class StateResolutionHandler:
items = self._state_res_metrics.items()
# log the N biggest rooms
- biggest = heapq.nlargest(
+ biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest(
n_to_log, items, key=lambda i: extract_key(i[1])
- ) # type: List[Tuple[str, _StateResMetrics]]
+ )
metrics_logger.debug(
"%i biggest rooms for state-res by %s: %s",
len(biggest),
@@ -754,7 +756,7 @@ def _make_state_cache_entry(
# failing that, look for the closest match.
prev_group = None
- delta_ids = None # type: Optional[StateMap[str]]
+ delta_ids: Optional[StateMap[str]] = None
for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 318e998813..267193cedf 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -159,7 +159,7 @@ def _seperate(
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
- conflicted_state = {} # type: MutableStateMap[Set[str]]
+ conflicted_state: MutableStateMap[Set[str]] = {}
for state_set in state_set_iterator:
for key, value in state_set.items():
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 008644cd98..e66e6571c8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -276,7 +276,7 @@ async def _get_auth_chain_difference(
# event IDs if they appear in the `event_map`. This is the intersection of
# the event's auth chain with the events in the `event_map` *plus* their
# auth event IDs.
- events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ events_to_auth_chain: Dict[str, Set[str]] = {}
for event in event_map.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain
@@ -301,17 +301,17 @@ async def _get_auth_chain_difference(
# ((type, state_key)->event_id) mappings; and (b) we have stripped out
# unpersisted events and replaced them with the persisted events in
# their auth chain.
- state_sets_ids = [] # type: List[Set[str]]
+ state_sets_ids: List[Set[str]] = []
# For each state set, the unpersisted event IDs reachable (by their auth
# chain) from the events in that set.
- unpersisted_set_ids = [] # type: List[Set[str]]
+ unpersisted_set_ids: List[Set[str]] = []
for state_set in state_sets:
- set_ids = set() # type: Set[str]
+ set_ids: Set[str] = set()
state_sets_ids.append(set_ids)
- unpersisted_ids = set() # type: Set[str]
+ unpersisted_ids: Set[str] = set()
unpersisted_set_ids.append(unpersisted_ids)
for event_id in state_set.values():
@@ -334,7 +334,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
- difference_from_event_map = union - intersection # type: Collection[str]
+ difference_from_event_map: Collection[str] = union - intersection
else:
difference_from_event_map = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
@@ -458,7 +458,7 @@ async def _reverse_topological_power_sort(
The sorted list
"""
- graph = {} # type: Dict[str, Set[str]]
+ graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event(
"""
room_id = event.room_id
- tmp_event = event # type: Optional[EventBase]
+ tmp_event: Optional[EventBase] = event
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
@@ -767,7 +767,7 @@ def lexicographical_topological_sort(
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
- reverse_graph = {} # type: Dict[str, Set[str]]
+ reverse_graph: Dict[str, Set[str]] = {}
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index c1f4d99e19..7f975a8f16 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -93,14 +93,12 @@ class BackgroundUpdater:
self.db_pool = database
# if a background update is currently running, its name.
- self._current_background_update = None # type: Optional[str]
-
- self._background_update_performance = (
- {}
- ) # type: Dict[str, BackgroundUpdatePerformance]
- self._background_update_handlers = (
- {}
- ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
+ self._current_background_update: Optional[str] = None
+
+ self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
+ self._background_update_handlers: Dict[
+ str, Callable[[JsonDict, int], Awaitable[int]]
+ ] = {}
self._all_done = False
def start_doing_background_updates(self) -> None:
@@ -412,7 +410,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
- runner = create_index_psql # type: Optional[Callable[[Connection], None]]
+ runner: Optional[Callable[[Connection], None]] = create_index_psql
elif psql_only:
runner = None
else:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 33c42cf95a..f80d822c12 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -670,8 +670,8 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks = [] # type: List[_CallbackListEntry]
- exception_callbacks = [] # type: List[_CallbackListEntry]
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -1090,7 +1090,7 @@ class DatabasePool:
return False
# We didn't find any existing rows, so insert a new one
- allvalues = {} # type: Dict[str, Any]
+ allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -1121,7 +1121,7 @@ class DatabasePool:
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
"""
- allvalues = {} # type: Dict[str, Any]
+ allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
allvalues.update(insertion_values or {})
@@ -1257,7 +1257,7 @@ class DatabasePool:
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
- allnames = [] # type: List[str]
+ allnames: List[str] = []
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1566,7 +1566,7 @@ class DatabasePool:
"""
keyvalues = keyvalues or {}
- results = [] # type: List[Dict[str, Any]]
+ results: List[Dict[str, Any]] = []
if not iterable:
return results
@@ -1978,7 +1978,7 @@ class DatabasePool:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
- arg_list = [] # type: List[Any]
+ arg_list: List[Any] = []
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 9f182c2a89..e2d1b758bd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
- exclusive_user_pattern = re.compile(
- exclusive_user_regex
- ) # type: Optional[Pattern]
+ exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
else:
# We handle this case specially otherwise the constructed regex
# will always match
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 0e3dd4e9ca..78ae68ec68 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
txn.execute(sql, query_params)
- result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index c4474df975..d39368c20e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
# Cache of event ID to list of auth event IDs and their depths.
- self._event_auth_cache = LruCache(
+ self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
500000, "_event_auth_cache", size_callback=len
- ) # type: LruCache[str, List[Tuple[str, int]]]
+ )
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
initial_events = set(event_ids)
# All the events that we've found that are reachable from the events.
- seen_events = set() # type: Set[str]
+ seen_events: Set[str] = set()
# A map from chain ID to max sequence number of the given events.
- event_chains = {} # type: Dict[int, int]
+ event_chains: Dict[int, int] = {}
sql = """
SELECT event_id, chain_id, sequence_number
@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
- chains = {} # type: Dict[int, int]
+ chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Map from event_id -> (chain ID, seq no)
- chain_info = {} # type: Dict[str, Tuple[int, int]]
+ chain_info: Dict[str, Tuple[int, int]] = {}
# Map from chain ID -> seq no -> event Id
- chain_to_event = {} # type: Dict[int, Dict[int, str]]
+ chain_to_event: Dict[int, Dict[int, str]] = {}
# All the chains that we've found that are reachable from the state
# sets.
- seen_chains = set() # type: Set[int]
+ seen_chains: Set[int] = set()
sql = """
SELECT event_id, chain_id, sequence_number
@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
- set_to_chain = [] # type: List[Dict[int, int]]
+ set_to_chain: List[Dict[int, int]] = []
for state_set in state_sets:
- chains = {} # type: Dict[int, int]
+ chains: Dict[int, int] = {}
set_to_chain.append(chains)
for event_id in state_set:
@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
- chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
+ chain_to_gap: Dict[int, Tuple[int, int]] = {}
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
}
# The sorted list of events whose auth chains we should walk.
- search = [] # type: List[Tuple[int, str]]
+ search: List[Tuple[int, str]] = []
# We need to get the depth of the initial events for sorting purposes.
sql = """
@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
search.sort()
# Map from event to its auth events
- event_to_auth_events = {} # type: Dict[str, Set[str]]
+ event_to_auth_events: Dict[str, Set[str]] = {}
base_sql = """
SELECT a.event_id, auth_id, depth
@@ -1230,7 +1230,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
)
- (age,) = txn.fetchone()
+ (received_ts,) = txn.fetchone()
+
+ age = self._clock.time_msec() - received_ts
return count, age
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d1237c65cc..55caa6bbe7 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
- summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 897fa06639..ec8579b9ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -109,10 +109,8 @@ class PersistEventsStore:
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen = (
- self.store._backfill_id_gen
- ) # type: MultiWriterIdGenerator
- self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
+ self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
# This should only exist on instances that are configured to write
assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
- results = [] # type: List[str]
+ results: List[str] = []
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
"""
# Map from event ID to chain ID/sequence number.
- chain_map = {} # type: Dict[str, Tuple[int, int]]
+ chain_map: Dict[str, Tuple[int, int]] = {}
# Set of event IDs to calculate chain ID/seq numbers for.
events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
# new chain if the sequence number has already been allocated.
#
- existing_chains = set() # type: Set[int]
- tree = [] # type: List[Tuple[str, Optional[str]]]
+ existing_chains: Set[int] = set()
+ tree: List[Tuple[str, Optional[str]]] = []
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
)
txn.execute(sql % (clause,), args)
- chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+ chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
# Allocate the new events chain ID/sequence numbers.
#
@@ -858,8 +856,8 @@ class PersistEventsStore:
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.
- unallocated_chain_ids = set() # type: Set[object]
- new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ unallocated_chain_ids: Set[object] = set()
+ new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]
- new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ new_chain_tuple: Optional[Tuple[Any, int]] = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
)
# Map from potentially temporary chain ID to real chain ID
- chain_id_to_allocated_map = dict(
+ chain_id_to_allocated_map: Dict[Any, int] = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
- ) # type: Dict[Any, int]
+ )
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
- new_events_and_contexts = (
- OrderedDict()
- ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+ new_events_and_contexts: OrderedDict[
+ str, Tuple[EventBase, EventContext]
+ ] = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
- depth_updates = {} # type: Dict[str, int]
+ depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1580,11 +1578,11 @@ class PersistEventsStore:
# invalidate the cache for the redacted event
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
- self.db_pool.simple_insert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="redactions",
+ keyvalues={"event_id": event.event_id},
values={
- "event_id": event.event_id,
"redacts": event.redacts,
"received_ts": self._clock.time_msec(),
},
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
),
)
- room_to_event_ids = {} # type: Dict[str, List[str]]
+ room_to_event_ids: Dict[str, List[str]] = {}
for e, _ in events_and_contexts:
room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
@@ -2012,7 +2010,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room = {} # type: Dict[str, List[EventBase]]
+ events_by_room: Dict[str, List[EventBase]] = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 29f33bac55..6fcb2b8353 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
# Calculate the new last position we've processed up to.
- new_last_depth = rows[-1][3] if rows else last_depth # type: int
- new_last_stream = rows[-1][4] if rows else last_stream # type: int
- new_last_room_id = rows[-1][5] if rows else "" # type: str
+ new_last_depth: int = rows[-1][3] if rows else last_depth
+ new_last_stream: int = rows[-1][4] if rows else last_stream
+ new_last_room_id: str = rows[-1][5] if rows else ""
# Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also
@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
retcols=("event_id", "auth_id"),
)
- event_to_auth_chain = {} # type: Dict[str, List[str]]
+ event_to_auth_chain: Dict[str, List[str]] = {}
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 403a5ddaba..3c86adab56 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows = await self.db_pool.runInteraction(
+ rows: List[Tuple] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
- ) # type: List[Tuple]
+ )
# if we've got fewer rows than the limit, we're good
if len(rows) < target_row_count:
@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
mapping = {}
- txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
+ txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
for event in events:
token_id = getattr(event.internal_metadata, "token_id", None)
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index c3f551d377..e3a544d9b2 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -320,7 +320,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
Returns millisecond unixtime for start of UTC day.
"""
- now = time.gmtime()
+ now = time.gmtime(self._clock.time())
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
@@ -352,7 +352,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) udv
ON u.user_id = udv.user_id AND u.device_id=udv.device_id
INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
+ WHERE ? <= last_seen AND last_seen < ?
AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL
GROUP BY u.user_id, u.device_id
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 7fb7780d0f..664c65dac5 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
should_delete_expr = "state_key IS NULL"
- should_delete_params = () # type: Tuple[Any, ...]
+ should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
@@ -215,6 +215,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_relations",
"event_search",
"rejections",
+ "redactions",
):
logger.info("[purge] removing events from %s", table)
@@ -392,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"room_memberships",
"room_stats_state",
"room_stats_current",
- "room_stats_historical",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index db52176337..a7fb8cd848 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = StreamIdGenerator(
- db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[StreamIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen: Union[
+ StreamIdGenerator, SlavedIdTracker
+ ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e31c5864ac..6ad1a0cf7f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
- values = [v for _, v in items] # type: List[Union[str, int]]
+ values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9f0d64a325..6ddafe5434 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -25,6 +25,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.types import Cursor
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -1022,10 +1023,22 @@ class RoomWorkerStore(SQLBaseStore):
)
-class RoomBackgroundUpdateStore(SQLBaseStore):
+class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
+ POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
+ REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
+
+
+_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
+ "DROP TRIGGER populate_min_depth2_trigger ON room_depth",
+ "DROP FUNCTION populate_min_depth2()",
+ "ALTER TABLE room_depth DROP COLUMN min_depth",
+ "ALTER TABLE room_depth RENAME COLUMN min_depth2 TO min_depth",
+)
+
+class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -1037,15 +1050,25 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
)
self.db_pool.updates.register_background_update_handler(
- self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+ _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
self._remove_tombstoned_rooms_from_directory,
)
self.db_pool.updates.register_background_update_handler(
- self.ADD_ROOMS_ROOM_VERSION_COLUMN,
+ _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
self._background_add_rooms_room_version_column,
)
+ # BG updates to change the type of room_depth.min_depth
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
+ self._background_populate_room_depth_min_depth2,
+ )
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
+ self._background_replace_room_depth_min_depth,
+ )
+
async def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
@@ -1164,7 +1187,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
new_last_room_id = room_id
self.db_pool.updates._background_update_progress_txn(
- txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+ txn,
+ _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
+ {"room_id": new_last_room_id},
)
return False
@@ -1176,7 +1201,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if end:
await self.db_pool.updates._end_background_update(
- self.ADD_ROOMS_ROOM_VERSION_COLUMN
+ _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN
)
return batch_size
@@ -1215,7 +1240,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if not rooms:
await self.db_pool.updates._end_background_update(
- self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+ _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
)
return 0
@@ -1224,7 +1249,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
await self.set_room_is_public(room_id, False)
await self.db_pool.updates._background_update_progress(
- self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+ _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
)
return len(rooms)
@@ -1268,6 +1293,71 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return max_ordering is None
+ async def _background_populate_room_depth_min_depth2(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Populate room_depth.min_depth2
+
+ This is to deal with the fact that min_depth was initially created as a
+ 32-bit integer field.
+ """
+
+ def process(txn: Cursor) -> int:
+ last_room = progress.get("last_room", "")
+ txn.execute(
+ """
+ UPDATE room_depth SET min_depth2=min_depth
+ WHERE room_id IN (
+ SELECT room_id FROM room_depth WHERE room_id > ?
+ ORDER BY room_id LIMIT ?
+ )
+ RETURNING room_id;
+ """,
+ (last_room, batch_size),
+ )
+ row_count = txn.rowcount
+ if row_count == 0:
+ return 0
+ last_room = max(row[0] for row in txn)
+ logger.info("populated room_depth up to %s", last_room)
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
+ {"last_room": last_room},
+ )
+ return row_count
+
+ result = await self.db_pool.runInteraction(
+ "_background_populate_min_depth2", process
+ )
+
+ if result != 0:
+ return result
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2
+ )
+ return 0
+
+ async def _background_replace_room_depth_min_depth(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Drop the old 'min_depth' column and rename 'min_depth2' into its place."""
+
+ def process(txn: Cursor) -> None:
+ for sql in _REPLACE_ROOM_DEPTH_SQL_COMMANDS:
+ logger.info("completing room_depth migration: %s", sql)
+ txn.execute(sql)
+
+ await self.db_pool.runInteraction("_background_replace_room_depth", process)
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
+ )
+
+ return 0
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2796354a1f..4d82c4c26d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -703,13 +703,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
+ return await self._check_host_room_membership(room_id, host, Membership.JOIN)
+
+ @cached(max_entries=10000)
+ async def is_host_invited(self, room_id: str, host: str) -> bool:
+ return await self._check_host_room_membership(room_id, host, Membership.INVITE)
+
+ async def _check_host_room_membership(
+ self, room_id: str, host: str, membership: str
+ ) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
SELECT state_key FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (event_id)
- WHERE m.membership = 'join'
+ WHERE m.membership = ?
AND type = 'm.room.member'
AND c.room_id = ?
AND state_key LIKE ?
@@ -722,7 +731,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, room_id, like_clause
+ "is_host_joined", None, sql, membership, room_id, like_clause
)
if not rows:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82a1833509..59d67c255b 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -26,7 +26,6 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state_deltas import StateDeltasStore
-from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -49,14 +48,6 @@ ABSOLUTE_STATS_FIELDS = {
"user": ("joined_rooms",),
}
-# these fields are per-timeslice and so should be reset to 0 upon a new slice
-# You can draw these stats on a histogram.
-# Example: number of events sent locally during a time slice
-PER_SLICE_FIELDS = {
- "room": ("total_events", "total_event_bytes"),
- "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"),
-}
-
TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
# these are the tables (& ID columns) which contain our actual subjects
@@ -106,7 +97,6 @@ class StatsStore(StateDeltasStore):
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats_enabled
- self.stats_bucket_size = hs.config.stats_bucket_size
self.stats_delta_processing_lock = DeferredLock()
@@ -122,22 +112,6 @@ class StatsStore(StateDeltasStore):
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
- def quantise_stats_time(self, ts):
- """
- Quantises a timestamp to be a multiple of the bucket size.
-
- Args:
- ts (int): the timestamp to quantise, in milliseconds since the Unix
- Epoch
-
- Returns:
- int: a timestamp which
- - is divisible by the bucket size;
- - is no later than `ts`; and
- - is the largest such timestamp.
- """
- return (ts // self.stats_bucket_size) * self.stats_bucket_size
-
async def _populate_stats_process_users(self, progress, batch_size):
"""
This is a background update which regenerates statistics for users.
@@ -288,56 +262,6 @@ class StatsStore(StateDeltasStore):
desc="update_room_state",
)
- async def get_statistics_for_subject(
- self, stats_type: str, stats_id: str, start: str, size: int = 100
- ) -> List[dict]:
- """
- Get statistics for a given subject.
-
- Args:
- stats_type: The type of subject
- stats_id: The ID of the subject (e.g. room_id or user_id)
- start: Pagination start. Number of entries, not timestamp.
- size: How many entries to return.
-
- Returns:
- A list of dicts, where the dict has the keys of
- ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
- """
- return await self.db_pool.runInteraction(
- "get_statistics_for_subject",
- self._get_statistics_for_subject_txn,
- stats_type,
- stats_id,
- start,
- size,
- )
-
- def _get_statistics_for_subject_txn(
- self, txn, stats_type, stats_id, start, size=100
- ):
- """
- Transaction-bound version of L{get_statistics_for_subject}.
- """
-
- table, id_col = TYPE_TO_TABLE[stats_type]
- selected_columns = list(
- ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
- )
-
- slice_list = self.db_pool.simple_select_list_paginate_txn(
- txn,
- table + "_historical",
- "end_ts",
- start,
- size,
- retcols=selected_columns + ["bucket_size", "end_ts"],
- keyvalues={id_col: stats_id},
- order_direction="DESC",
- )
-
- return slice_list
-
@cached()
async def get_earliest_token_for_stats(
self, stats_type: str, id: str
@@ -451,14 +375,10 @@ class StatsStore(StateDeltasStore):
table, id_col = TYPE_TO_TABLE[stats_type]
- quantised_ts = self.quantise_stats_time(int(ts))
- end_ts = quantised_ts + self.stats_bucket_size
-
# Lets be paranoid and check that all the given field names are known
abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type]
- slice_field_names = PER_SLICE_FIELDS[stats_type]
for field in chain(fields.keys(), absolute_field_overrides.keys()):
- if field not in abs_field_names and field not in slice_field_names:
+ if field not in abs_field_names:
# guard against potential SQL injection dodginess
raise ValueError(
"%s is not a recognised field"
@@ -491,20 +411,6 @@ class StatsStore(StateDeltasStore):
additive_relatives=deltas_of_absolute_fields,
)
- per_slice_additive_relatives = {
- key: fields.get(key, 0) for key in slice_field_names
- }
- self._upsert_copy_from_table_with_additive_relatives_txn(
- txn=txn,
- into_table=table + "_historical",
- keyvalues={id_col: stats_id},
- extra_dst_insvalues={"bucket_size": self.stats_bucket_size},
- extra_dst_keyvalues={"end_ts": end_ts},
- additive_relatives=per_slice_additive_relatives,
- src_table=table + "_current",
- copy_columns=abs_field_names,
- )
-
def _upsert_with_additive_relatives_txn(
self, txn, table, keyvalues, absolutes, additive_relatives
):
@@ -528,7 +434,7 @@ class StatsStore(StateDeltasStore):
]
relative_updates = [
- "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s"
+ "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
% {"table": table, "field": field}
for field in additive_relatives.keys()
]
@@ -568,205 +474,13 @@ class StatsStore(StateDeltasStore):
self.db_pool.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
- current_row[key] += val
+ if current_row[key] is None:
+ current_row[key] = val
+ else:
+ current_row[key] += val
current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
- def _upsert_copy_from_table_with_additive_relatives_txn(
- self,
- txn,
- into_table,
- keyvalues,
- extra_dst_keyvalues,
- extra_dst_insvalues,
- additive_relatives,
- src_table,
- copy_columns,
- ):
- """Updates the historic stats table with latest updates.
-
- This involves copying "absolute" fields from the `_current` table, and
- adding relative fields to any existing values.
-
- Args:
- txn: Transaction
- into_table (str): The destination table to UPSERT the row into
- keyvalues (dict[str, any]): Row-identifying key values
- extra_dst_keyvalues (dict[str, any]): Additional keyvalues
- for `into_table`.
- extra_dst_insvalues (dict[str, any]): Additional values to insert
- on new row creation for `into_table`.
- additive_relatives (dict[str, any]): Fields that will be added onto
- if existing row present. (Must be disjoint from copy_columns.)
- src_table (str): The source table to copy from
- copy_columns (iterable[str]): The list of columns to copy
- """
- if self.database_engine.can_native_upsert:
- ins_columns = chain(
- keyvalues,
- copy_columns,
- additive_relatives,
- extra_dst_keyvalues,
- extra_dst_insvalues,
- )
- sel_exprs = chain(
- keyvalues,
- copy_columns,
- (
- "?"
- for _ in chain(
- additive_relatives, extra_dst_keyvalues, extra_dst_insvalues
- )
- ),
- )
- keyvalues_where = ("%s = ?" % f for f in keyvalues)
-
- sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
- sets_ar = (
- "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
- for f in additive_relatives
- )
-
- sql = """
- INSERT INTO %(into_table)s (%(ins_columns)s)
- SELECT %(sel_exprs)s
- FROM %(src_table)s
- WHERE %(keyvalues_where)s
- ON CONFLICT (%(keyvalues)s)
- DO UPDATE SET %(sets)s
- """ % {
- "into_table": into_table,
- "ins_columns": ", ".join(ins_columns),
- "sel_exprs": ", ".join(sel_exprs),
- "keyvalues_where": " AND ".join(keyvalues_where),
- "src_table": src_table,
- "keyvalues": ", ".join(
- chain(keyvalues.keys(), extra_dst_keyvalues.keys())
- ),
- "sets": ", ".join(chain(sets_cc, sets_ar)),
- }
-
- qargs = list(
- chain(
- additive_relatives.values(),
- extra_dst_keyvalues.values(),
- extra_dst_insvalues.values(),
- keyvalues.values(),
- )
- )
- txn.execute(sql, qargs)
- else:
- self.database_engine.lock_table(txn, into_table)
- src_row = self.db_pool.simple_select_one_txn(
- txn, src_table, keyvalues, copy_columns
- )
- all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
- dest_current_row = self.db_pool.simple_select_one_txn(
- txn,
- into_table,
- keyvalues=all_dest_keyvalues,
- retcols=list(chain(additive_relatives.keys(), copy_columns)),
- allow_none=True,
- )
-
- if dest_current_row is None:
- merged_dict = {
- **keyvalues,
- **extra_dst_keyvalues,
- **extra_dst_insvalues,
- **src_row,
- **additive_relatives,
- }
- self.db_pool.simple_insert_txn(txn, into_table, merged_dict)
- else:
- for (key, val) in additive_relatives.items():
- src_row[key] = dest_current_row[key] + val
- self.db_pool.simple_update_txn(
- txn, into_table, all_dest_keyvalues, src_row
- )
-
- async def get_changes_room_total_events_and_bytes(
- self, min_pos: int, max_pos: int
- ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
- """Fetches the counts of events in the given range of stream IDs.
-
- Args:
- min_pos
- max_pos
-
- Returns:
- Mapping of room ID to field changes.
- """
-
- return await self.db_pool.runInteraction(
- "stats_incremental_total_events_and_bytes",
- self.get_changes_room_total_events_and_bytes_txn,
- min_pos,
- max_pos,
- )
-
- def get_changes_room_total_events_and_bytes_txn(
- self, txn, low_pos: int, high_pos: int
- ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
- """Gets the total_events and total_event_bytes counts for rooms and
- senders, in a range of stream_orderings (including backfilled events).
-
- Args:
- txn
- low_pos: Low stream ordering
- high_pos: High stream ordering
-
- Returns:
- The room and user deltas for total_events/total_event_bytes in the
- format of `stats_id` -> fields
- """
-
- if low_pos >= high_pos:
- # nothing to do here.
- return {}, {}
-
- if isinstance(self.database_engine, PostgresEngine):
- new_bytes_expression = "OCTET_LENGTH(json)"
- else:
- new_bytes_expression = "LENGTH(CAST(json AS BLOB))"
-
- sql = """
- SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes
- FROM events INNER JOIN event_json USING (event_id)
- WHERE (? < stream_ordering AND stream_ordering <= ?)
- OR (? <= stream_ordering AND stream_ordering <= ?)
- GROUP BY events.room_id
- """ % (
- new_bytes_expression,
- )
-
- txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
-
- room_deltas = {
- room_id: {"total_events": new_events, "total_event_bytes": new_bytes}
- for room_id, new_events, new_bytes in txn
- }
-
- sql = """
- SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes
- FROM events INNER JOIN event_json USING (event_id)
- WHERE (? < stream_ordering AND stream_ordering <= ?)
- OR (? <= stream_ordering AND stream_ordering <= ?)
- GROUP BY events.sender
- """ % (
- new_bytes_expression,
- )
-
- txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
-
- user_deltas = {
- user_id: {"total_events": new_events, "total_event_bytes": new_bytes}
- for user_id, new_events, new_bytes in txn
- if self.hs.is_mine_id(user_id)
- }
-
- return room_deltas, user_deltas
-
async def _calculate_and_set_initial_state_for_room(
self, room_id: str
) -> Tuple[dict, dict, int]:
@@ -893,6 +607,7 @@ class StatsStore(StateDeltasStore):
"invited_members": membership_counts.get(Membership.INVITE, 0),
"left_members": membership_counts.get(Membership.LEAVE, 0),
"banned_members": membership_counts.get(Membership.BAN, 0),
+ "knocked_members": membership_counts.get(Membership.KNOCK, 0),
"local_users_in_room": len(local_users_in_room),
},
)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7581c7d3ff..959f13de47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
# stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
# then filtering the results.
if from_token.topological is not None:
- from_bound = (
- from_token.as_historical_tuple()
- ) # type: Tuple[Optional[int], int]
+ from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
elif direction == "b":
from_bound = (
None,
@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
from_token.stream,
)
- to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ to_bound: Optional[Tuple[Optional[int], int]] = None
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_historical_tuple()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 1d62c6140f..f93ff0a545 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
+ tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 22c05cdde7..38bfdf5dad 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value.
- result = self.db_pool.simple_select_one_txn(
+ result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- ) # type: Dict[str, Any] # type: ignore
+ )
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 051095fea9..a39877f0d5 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -307,7 +307,7 @@ class EventsPersistenceStorage:
matched the transcation ID; the existing event is returned in such
a case.
"""
- partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+ partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -384,7 +384,7 @@ class EventsPersistenceStorage:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
- replaced_events = {} # type: Dict[str, str]
+ replaced_events: Dict[str, str] = {}
if not events_and_contexts:
return replaced_events
@@ -440,16 +440,14 @@ class EventsPersistenceStorage:
# Set of remote users which were in rooms the server has left. We
# should check if we still share any rooms and if not we mark their
# device lists as stale.
- potentially_left_users = set() # type: Set[str]
+ potentially_left_users: Set[str] = set()
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = (
- {}
- ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+ events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -622,9 +620,9 @@ class EventsPersistenceStorage:
)
# Remove any events which are prev_events of any existing events.
- existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
- result
- ) # type: Collection[str]
+ existing_prevs: Collection[
+ str
+ ] = await self.persist_events_store._get_events_which_are_prevs(result)
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 683e5e3b90..82a7686df0 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -256,7 +256,7 @@ def _setup_new_database(
for database in databases
)
- directory_entries = [] # type: List[_DirectoryListing]
+ directory_entries: List[_DirectoryListing] = []
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -424,10 +424,10 @@ def _upgrade_existing_database(
directories.append(os.path.join(schema_path, database, "delta", str(v)))
# Used to check if we have any duplicate file names
- file_name_counter = Counter() # type: CounterType[str]
+ file_name_counter: CounterType[str] = Counter()
# Now find which directories have anything of interest.
- directory_entries = [] # type: List[_DirectoryListing]
+ directory_entries: List[_DirectoryListing] = []
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 0a53b73ccc..36340a652a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 60
+SCHEMA_VERSION = 61
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -21,6 +21,10 @@ older versions of Synapse).
See `README.md <synapse/storage/schema/README.md>`_ for more information on how this
works.
+
+Changes in SCHEMA_VERSION = 61:
+ - The `user_stats_historical` and `room_stats_historical` tables are not written and
+ are not read (previously, they were written but not read).
"""
diff --git a/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres
new file mode 100644
index 0000000000..c8aec78e60
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we use bigint elsewhere in the database for appservice txn ids (notably,
+-- application_services_state.last_txn), and generally we use bigints everywhere else
+-- we have monotonic counters, so let's bring this one in line.
+--
+-- assuming there aren't thousands of rows for decommisioned/non-functional ASes, this
+-- table should be pretty small, so safe to do a synchronous ALTER TABLE.
+
+ALTER TABLE application_services_txns ALTER COLUMN txn_id SET DATA TYPE BIGINT;
diff --git a/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql
new file mode 100644
index 0000000000..35ca7a40c0
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is redundant; there is another UNIQUE index on this table.
+DROP INDEX IF EXISTS room_depth_room;
+
diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
new file mode 100644
index 0000000000..f8d7db9f2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
@@ -0,0 +1,70 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration handles the process of changing the type of `room_depth.min_depth` to
+a BIGINT.
+"""
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ # this only applies to postgres - sqlite does not distinguish between big and
+ # little ints.
+ return
+
+ # First add a new column to contain the bigger min_depth
+ cur.execute("ALTER TABLE room_depth ADD COLUMN min_depth2 BIGINT")
+
+ # Create a trigger which will keep it populated.
+ cur.execute(
+ """
+ CREATE OR REPLACE FUNCTION populate_min_depth2() RETURNS trigger AS $BODY$
+ BEGIN
+ new.min_depth2 := new.min_depth;
+ RETURN NEW;
+ END;
+ $BODY$ LANGUAGE plpgsql
+ """
+ )
+
+ cur.execute(
+ """
+ CREATE TRIGGER populate_min_depth2_trigger BEFORE INSERT OR UPDATE ON room_depth
+ FOR EACH ROW
+ EXECUTE PROCEDURE populate_min_depth2()
+ """
+ )
+
+ # Start a bg process to populate it for old rooms
+ cur.execute(
+ """
+ INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6103, 'populate_room_depth_min_depth2', '{}')
+ """
+ )
+
+ # and another to switch them over once it completes.
+ cur.execute(
+ """
+ INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (6103, 'replace_room_depth_min_depth', '{}', 'populate_room_depth2')
+ """
+ )
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres
new file mode 100644
index 0000000000..35a153da7b
--- /dev/null
+++ b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- By default the postgres statistics collector massively underestimates the
+-- number of distinct state groups are in the `state_groups_state`, which can
+-- cause postgres to use table scans for queries for multiple state groups.
+--
+-- To work around this we can manually tell postgres the number of distinct state
+-- groups there are by setting `n_distinct` (a negative value here is the number
+-- of distinct values divided by the number of rows, so -0.02 means on average
+-- there are 50 rows per distinct value). We don't need a particularly
+-- accurate number here, as a) we just want it to always use index scans and b)
+-- our estimate is going to be better than the one made by the statistics
+-- collector.
+
+ALTER TABLE state_groups_state ALTER COLUMN state_group SET (n_distinct = -0.02);
+
+-- Ideally we'd do an `ANALYZE state_groups_state (state_group)` here so that
+-- the above gets picked up immediately, but that can take a bit of time so we
+-- rely on the autovacuum eventually getting run and doing that in the
+-- background for us.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c9dce726cb..f8fbba9d38 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -91,7 +91,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- type_dict = {} # type: Dict[str, Optional[Set[str]]]
+ type_dict: Dict[str, Optional[Set[str]]] = {}
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@@ -194,7 +194,7 @@ class StateFilter:
"""
where_clause = ""
- where_args = [] # type: List[str]
+ where_args: List[str] = []
if self.is_full():
return where_clause, where_args
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f1e62f9e85..c768fdea56 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -112,7 +112,7 @@ class StreamIdGenerator:
# insertion ordering will ensure its in the correct ordering.
#
# The key and values are the same, but we never look at the values.
- self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
+ self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def get_next(self):
"""
@@ -236,15 +236,15 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = {} # type: Dict[str, int]
+ self._current_positions: Dict[str, int] = {}
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
- self._unfinished_ids = set() # type: Set[int]
+ self._unfinished_ids: Set[int] = set()
# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
- self._finished_ids = set() # type: Set[int]
+ self._finished_ids: Set[int] = set()
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
@@ -265,7 +265,7 @@ class MultiWriterIdGenerator:
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 1
)
- self._known_persisted_positions = [] # type: List[int]
+ self._known_persisted_positions: List[int] = []
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
@@ -465,7 +465,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None # type: Optional[int]
+ new_cur: Optional[int] = None
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 30b6b8e0ca..bb33e04fb1 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator):
get_next_id_txn; should return the curreent maximum id
"""
# the callback. this is cleared after it is called, so that it can be GCed.
- self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+ self._callback: Optional[GetFirstCallbackType] = get_first_callback
# The current max value, or None if we haven't looked in the DB yet.
- self._current_max_id = None # type: Optional[int]
+ self._current_max_id: Optional[int] = None
self._lock = threading.Lock()
def get_next_id_txn(self, txn: Cursor) -> int:
@@ -274,7 +274,7 @@ def build_sequence_generator(
`check_consistency` details.
"""
if isinstance(database_engine, PostgresEngine):
- seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
+ seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
else:
seq = LocalSequenceGenerator(get_first_callback)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 20fceaa935..99b0aac2fb 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -32,9 +32,9 @@ class EventSources:
}
def __init__(self, hs):
- self.sources = {
+ self.sources: Dict[str, Any] = {
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
- } # type: Dict[str, Any]
+ }
self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken:
diff --git a/synapse/types.py b/synapse/types.py
index 8d2fa00f71..fad23c8700 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -182,14 +182,14 @@ def create_requester(
)
-def get_domain_from_id(string):
+def get_domain_from_id(string: str) -> str:
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
return string[idx + 1 :]
-def get_localpart_from_id(string):
+def get_localpart_from_id(string: str) -> str:
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
@@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name
"""
- SIGIL = abc.abstractproperty() # type: str # type: ignore
+ SIGIL: str = abc.abstractproperty() # type: ignore
localpart = attr.ib(type=str)
domain = attr.ib(type=str)
@@ -304,7 +304,7 @@ class GroupID(DomainSpecificString):
@classmethod
def from_string(cls: Type[DS], s: str) -> DS:
- group_id = super().from_string(s) # type: DS # type: ignore
+ group_id: DS = super().from_string(s) # type: ignore
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
@@ -600,7 +600,7 @@ class StreamToken:
groups_key = attr.ib(type=int)
_SEPARATOR = "_"
- START = None # type: StreamToken
+ START: "StreamToken"
@classmethod
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 061102c3c8..014db1355b 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -257,7 +257,7 @@ class Linearizer:
max_count: The maximum number of concurrent accesses
"""
if name is None:
- self.name = id(self) # type: Union[str, int]
+ self.name: Union[str, int] = id(self)
else:
self.name = name
@@ -269,7 +269,7 @@ class Linearizer:
self.max_count = max_count
# key_to_defer is a map from the key to a _LinearizerEntry.
- self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry]
+ self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
def is_queued(self, key: Hashable) -> bool:
"""Checks whether there is a process queued up waiting"""
@@ -409,10 +409,10 @@ class ReadWriteLock:
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
+ self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
# Latest writer queued
- self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
+ self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
index 8fd5bfb69b..274cea7eb7 100644
--- a/synapse/util/batching_queue.py
+++ b/synapse/util/batching_queue.py
@@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]):
self._clock = clock
# The set of keys currently being processed.
- self._processing_keys = set() # type: Set[Hashable]
+ self._processing_keys: Set[Hashable] = set()
# The currently pending batch of values by key, with a Deferred to call
# with the result of the corresponding `_process_batch_callback` call.
- self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+ self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}
# The function to call with batches of values.
self._process_batch_callback = process_batch_callback
@@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]):
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
- self._number_in_flight_metric = number_in_flight.labels(
- self._name
- ) # type: Gauge
+ self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
"""Adds the value to the queue with the given key, returning the result
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ca36f07c20..9012034b7a 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
TRACK_MEMORY_USAGE = False
-caches_by_name = {} # type: Dict[str, Sized]
-collectors_by_name = {} # type: Dict[str, CacheMetric]
+caches_by_name: Dict[str, Sized] = {}
+collectors_by_name: Dict[str, "CacheMetric"] = {}
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index a301c9e89b..891bee0b33 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
f: The underlying function. Only one call to this function will be alive
at once (per instance of CachedCall)
"""
- self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
- self._deferred = None # type: Optional[Deferred]
- self._result = None # type: Union[None, Failure, TV]
+ self._callable: Optional[Callable[[], Awaitable[TV]]] = f
+ self._deferred: Optional[Deferred] = None
+ self._result: Union[None, Failure, TV] = None
async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1044139119..8c6fafc677 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
cache_type = TreeCache if tree else dict
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
- self._pending_deferred_cache = (
- cache_type()
- ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
+ self._pending_deferred_cache: Union[
+ TreeCache, "MutableMapping[KT, CacheEntry]"
+ ] = cache_type()
def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
- self.cache = LruCache(
+ self.cache: LruCache[KT, VT] = LruCache(
max_size=max_entries,
cache_name=name,
cache_type=cache_type,
size_callback=(lambda d: len(d) or 1) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
- ) # type: LruCache[KT, VT]
+ )
- self.thread = None # type: Optional[threading.Thread]
+ self.thread: Optional[threading.Thread] = None
@property
def max_entries(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d77e8edeea..1e8e6b1d01 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Generic[F]):
- invalidate = None # type: Any
- invalidate_all = None # type: Any
- prefill = None # type: Any
- cache = None # type: Any
- num_args = None # type: Any
+ invalidate: Any = None
+ invalidate_all: Any = None
+ prefill: Any = None
+ cache: Any = None
+ num_args: Any = None
- __name__ = None # type: str
+ __name__: str
# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
- __call__ = None # type: F
+ __call__: F
class _CacheDescriptorBase:
@@ -115,8 +115,8 @@ class _CacheDescriptorBase:
class _LruCachedFunction(Generic[F]):
- cache = None # type: LruCache[CacheKey, Any]
- __call__ = None # type: F
+ cache: LruCache[CacheKey, Any]
+ __call__: F
def lru_cache(
@@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
self.max_entries = max_entries
def __get__(self, obj, owner):
- cache = LruCache(
+ cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
- ) # type: LruCache[CacheKey, Any]
+ )
get_cache_key = self.cache_key_builder
sentinel = LruCacheDescriptor._Sentinel.sentinel
@@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
def __get__(self, obj, owner):
- cache = DeferredCache(
+ cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
- ) # type: DeferredCache[CacheKey, Any]
+ )
get_cache_key = self.cache_key_builder
@@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
- cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
+ cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
@@ -472,15 +472,15 @@ class _CacheContext:
Cache = Union[DeferredCache, LruCache]
- _cache_context_objects = (
- WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
+ _cache_context_objects: """WeakValueDictionary[
+ Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
+ ]""" = WeakValueDictionary()
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
- def invalidate(self): # type: () -> None
+ def invalidate(self) -> None:
"""Invalidates the cache entry referred to by the context."""
self._cache.invalidate(self._cache_key)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 56d94d96ce..3f852edd7f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
"""
def __init__(self, name: str, max_entries: int = 1000):
- self.cache = LruCache(
+ self.cache: LruCache[KT, DictionaryEntry] = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
- ) # type: LruCache[KT, DictionaryEntry]
+ )
self.name = name
self.sequence = 0
- self.thread = None # type: Optional[threading.Thread]
+ self.thread: Optional[threading.Thread] = None
def check_thread(self) -> None:
expected_thread = self.thread
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ac47a31cd7..bde16b8577 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object() # type: Any
+SENTINEL: Any = object()
T = TypeVar("T")
@@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
- self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]
+ self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
self.iterable = iterable
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4b9d0433ff..efeba0cb96 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -226,7 +226,7 @@ class _Node:
# footprint down. Storing `None` is free as its a singleton, while empty
# lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
# thing and used sets).
- self.callbacks = None # type: Optional[List[Callable[[], None]]]
+ self.callbacks: Optional[List[Callable[[], None]]] = None
self.add_callbacks(callbacks)
@@ -362,15 +362,15 @@ class LruCache(Generic[KT, VT]):
# register_cache might call our "set_cache_factor" callback; there's nothing to
# do yet when we get resized.
- self._on_resize = None # type: Optional[Callable[[],None]]
+ self._on_resize: Optional[Callable[[], None]] = None
if cache_name is not None:
- metrics = register_cache(
+ metrics: Optional[CacheMetric] = register_cache(
"lru_cache",
cache_name,
self,
collect_callback=metrics_collection_callback,
- ) # type: Optional[CacheMetric]
+ )
else:
metrics = None
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 34c662c4db..ed7204336f 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
# This is poorly-named: it includes both complete and incomplete results.
# We keep complete results rather than switching to absolute values because
# that makes it easier to cache Failure results.
- self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred]
+ self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e81e468899..3a41a8baa6 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -45,10 +45,10 @@ class StreamChangeCache:
):
self._original_max_size = max_size
self._max_size = math.floor(max_size)
- self._entity_to_key = {} # type: Dict[EntityType, int]
+ self._entity_to_key: Dict[EntityType, int] = {}
# map from stream id to the a set of entities which changed at that stream id.
- self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]]
+ self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
# the earliest stream_pos for which we can reliably answer
# get_all_entities_changed. In other words, one less than the earliest
@@ -155,7 +155,7 @@ class StreamChangeCache:
if stream_pos < self._earliest_known_stream_pos:
return None
- changed_entities = [] # type: List[EntityType]
+ changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index c276107d56..46afe3f934 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object() # type: Any
+SENTINEL: Any = object()
T = TypeVar("T")
KT = TypeVar("KT")
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
- self._data = {} # type: Dict[KT, _CacheEntry]
+ self._data: Dict[KT, _CacheEntry] = {}
# the _CacheEntries, sorted by expiry time
- self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
+ self._expiry_list: SortedList[_CacheEntry] = SortedList()
self._timer = timer
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 886afa9d19..8ac3eab2f5 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -68,7 +68,7 @@ def sorted_topologically(
# This is implemented by Kahn's algorithm.
degree_map = {node: 0 for node in nodes}
- reverse_graph = {} # type: Dict[T, Set[T]]
+ reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items():
if node not in degree_map:
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index f6ebfd7e7d..d1f76e3dc5 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
caveat in the macaroon, or if the caveat was not found in the macaroon.
"""
prefix = key + " = "
- result = None # type: Optional[str]
+ result: Optional[str] = None
for caveat in macaroon.caveats:
if not caveat.caveat_id.startswith(prefix):
continue
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 45353d41c5..1b82dca81b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -124,7 +124,7 @@ class Measure:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
- self.start = None # type: Optional[int]
+ self.start: Optional[int] = None
def __enter__(self) -> "Measure":
if self.start is not None:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index eed0291cae..99f01e325c 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -41,7 +41,7 @@ def do_patch():
@functools.wraps(f)
def wrapped(*args, **kwargs):
start_context = current_context()
- changes = [] # type: List[str]
+ changes: List[str] = []
orig = orig_inline_callbacks(_check_yield_points(f, changes))
try:
@@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
- result = None # type: Any
+ result: Any = None
while True:
expected_context = current_context()
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 490fb26e81..1dc6b90275 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -90,7 +90,7 @@ async def filter_events_for_client(
AccountDataTypes.IGNORED_USER_LIST, user_id
)
- ignore_list = frozenset() # type: FrozenSet[str]
+ ignore_list: FrozenSet[str] = frozenset()
if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict):
|