diff options
author | Sean Quah <8349537+squahtx@users.noreply.github.com> | 2021-10-22 18:15:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-22 18:15:41 +0100 |
commit | 2b82ec425fccb0ef626242779f7ccd4d77a0685c (patch) | |
tree | d541487cf7936d98807c5b1ef5128bb1bb5c783c | |
parent | Fix synapse.config module "read" command (#11145) (diff) | |
download | synapse-2b82ec425fccb0ef626242779f7ccd4d77a0685c.tar.xz |
Add type hints for most `HomeServer` parameters (#11095)
58 files changed, 342 insertions, 143 deletions
diff --git a/changelog.d/11095.misc b/changelog.d/11095.misc new file mode 100644 index 0000000000..786e90b595 --- /dev/null +++ b/changelog.d/11095.misc @@ -0,0 +1 @@ +Add type hints to most `HomeServer` parameters. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index bb4d53d778..2ca2e051e4 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -294,7 +294,7 @@ def listen_ssl( return r -def refresh_certificate(hs): +def refresh_certificate(hs: "HomeServer"): """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -419,11 +419,11 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs): +def setup_sentry(hs: "HomeServer"): """Enable sentry integration, if enabled in configuration Args: - hs (synapse.server.HomeServer) + hs """ if not hs.config.metrics.sentry_enabled: @@ -449,7 +449,7 @@ def setup_sentry(hs): scope.set_tag("worker_name", name) -def setup_sdnotify(hs): +def setup_sdnotify(hs: "HomeServer"): """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b156b93bf3..2fc848596d 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer): DATASTORE_CLASS = AdminCmdSlavedStore -async def export_data_command(hs, args): +async def export_data_command(hs: HomeServer, args): """Export data for a user. Args: - hs (HomeServer) + hs args (argparse.Namespace) """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 7489f31d9a..51eadf122d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: HomeServer): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ super().__init__() self.auth = hs.get_auth() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 422f03cc04..93e2299266 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]: e = e.__cause__ -def run(hs): +def run(hs: HomeServer): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index fcd01e833c..126450e17a 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,11 +15,15 @@ import logging import math import resource import sys +from typing import TYPE_CHECKING from prometheus_client import Gauge from synapse.metrics.background_process_metrics import wrap_as_background_process +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger("synapse.app.homeserver") # Contains the list of processes we will be monitoring @@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs): +def start_phone_stats_home(hs: "HomeServer"): """ Start the background tasks which report phone home stats. """ diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 935f24263c..d08f6bbd7f 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: from synapse.appservice import ApplicationService + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient): pushing. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 0a08231e5a..5252e61a99 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -18,6 +18,7 @@ import os import sys import threading from string import Template +from typing import TYPE_CHECKING import yaml from zope.interface import implementer @@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + DEFAULT_LOG_CONFIG = Template( """\ # Log configuration for Synapse. @@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path): def setup_logging( - hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner + hs: "HomeServer", + config, + use_worker_options=False, + logBeginner: LogBeginner = globalLogBeginner, ) -> None: """ Set up the logging subsystem. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 0cd424e12a..f56344a3b9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging from collections import namedtuple +from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict from synapse.types import JsonDict, get_domain_from_id +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class FederationBase: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d8c0b86f23..0d66034f44 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -467,7 +467,7 @@ class FederationServer(FederationBase): async def on_room_state_request( self, origin: str, room_id: str, event_id: Optional[str] - ) -> Tuple[int, Dict[str, Any]]: + ) -> Tuple[int, JsonDict]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -481,7 +481,7 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = dict( + resp: JsonDict = dict( await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, @@ -1061,11 +1061,12 @@ class FederationServer(FederationBase): origin, event = next - lock = await self.store.try_acquire_lock( + new_lock = await self.store.try_acquire_lock( _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id ) - if not lock: + if not new_lock: return + lock = new_lock def __str__(self) -> str: return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4f59224686..203d723d41 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -21,6 +21,7 @@ import typing import urllib.parse from io import BytesIO, StringIO from typing import ( + TYPE_CHECKING, Callable, Dict, Generic, @@ -73,6 +74,9 @@ from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter( @@ -319,7 +323,7 @@ class MatrixFederationHttpClient: requests. """ - def __init__(self, hs, tls_client_options_factory): + def __init__(self, hs: "HomeServer", tls_client_options_factory): self.hs = hs self.signing_key = hs.signing_key self.server_name = hs.hostname @@ -711,7 +715,7 @@ class MatrixFederationHttpClient: Returns: A list of headers to be added as "Authorization:" headers """ - request = { + request: JsonDict = { "method": method.decode("ascii"), "uri": url_bytes.decode("ascii"), "origin": self.server_name, diff --git a/synapse/http/server.py b/synapse/http/server.py index 897ba5e453..1af0d9a31d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,6 +22,7 @@ import urllib from http import HTTPStatus from inspect import isawaitable from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -61,6 +62,9 @@ from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """<!DOCTYPE html> @@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) +_PathEntry = collections.namedtuple( + "_PathEntry", ["pattern", "callback", "servlet_classname"] +) + + class JsonResource(DirectServeJsonResource): """This implements the HttpServer interface and provides JSON support for Resources. @@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - _PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] - ) - - def __init__(self, hs, canonical_json=True, extract_context=False): + def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() - self.path_regexs = {} + self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs def register_paths(self, method, path_patterns, callback, servlet_classname): @@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource): for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback, servlet_classname) + _PathEntry(path_pattern, callback, servlet_classname) ) def _get_handler_for_request( diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index ba8114ac9e..1457d9d59b 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import JsonResource from synapse.replication.http import ( account_data, @@ -26,16 +28,19 @@ from synapse.replication.http import ( streams, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # We enable extracting jaeger contexts here as these are internal APIs. super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) - def register_servlets(self, hs): + def register_servlets(self, hs: "HomeServer"): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index e047ec74d8..585332b244 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import logging import re import urllib from inspect import signature -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): pass @classmethod - def make_client(cls, hs): + def make_client(cls, hs: "HomeServer"): """Create a client that makes requests. Returns a callable that accepts the same parameters as @@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): url_args.append(txn_id) if cls.METHOD == "POST": - request_func = client.post_json_get_json + request_func: Callable[ + ..., Awaitable[Any] + ] = client.post_json_get_json elif cls.METHOD == "PUT": request_func = client.put_json elif cls.METHOD == "GET": diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 70e951af63..5f0f225aa9 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "tag") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): ) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 5a5818ef61..42dffb39cb 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,9 +13,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater @@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index a0b3145f4e..5ed535c90d 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): NAME = "fed_send_events" PATH_ARGS = () - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): NAME = "fed_send_edu" PATH_ARGS = ("edu_type",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): # This is a query, so let's not bother caching CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): NAME = "fed_cleanup_room" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 550bd5c95f..0db419ea57 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): NAME = "device_check_registered" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registration_handler = hs.get_registration_handler() @@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): return 200, res -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): RegisterDeviceReplicationServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 34206c5060..7371c240b2 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): NAME = "remote_join" PATH_ARGS = ("room_id", "user_id") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_handler = hs.get_federation_handler() @@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id", "change") CACHE = False # No point caching as should return instantly. - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() @@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index bb00247953..63143085d5 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index 139427cb1f..6c8db3061e 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemovePusherRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d6dd7242eb..7adfbb666f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): NAME = "register_user" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): NAME = "post_register" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRegisterServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fae5ffa451..9f6851d059 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): NAME = "send_event" PATH_ARGS = ("event_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 9afa147d00..3223bc2432 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -13,11 +13,15 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._instance_name = hs.get_instance_name() @@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index e460dd85cd..7ecb446e7c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -13,18 +13,21 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen: Optional[ diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 436d39c320..61cd7e5228 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.client_ip_last_seen: LruCache[tuple, int] = LruCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 26bdead565..0a58296089 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream @@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4d3f8c448..63ed50caa5 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_federation import EventFederationWorkerStore @@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -54,7 +58,7 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 37875bc973..90284c202d 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e9bdc38470..497e16c69e 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream @@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index b402f82810..aaf91e5e02 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.server import HomeServer set_counter = Counter( @@ -59,7 +61,12 @@ class ExternalCache: """ def __init__(self, hs: "HomeServer"): - self._redis_connection = hs.get_outbound_redis_connection() + if hs.config.redis.redis_enabled: + self._redis_connection: Optional[ + "RedisProtocol" + ] = hs.get_outbound_redis_connection() + else: + self._redis_connection = None def _get_redis_key(self, cache_name: str, key: str) -> str: return "cache_v1:%s:%s" % (cache_name, key) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6aa9318027..06fd06fdf3 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -294,7 +294,7 @@ class ReplicationCommandHandler: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs): + def start_replication(self, hs: "HomeServer"): """Helper method to start a replication connection to the remote server using TCP. """ @@ -321,6 +321,8 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_port, self._factory, + timeout=30, + bindAddress=None, ) else: client_name = hs.get_instance_name() @@ -331,6 +333,8 @@ class ReplicationCommandHandler: host, # type: ignore[arg-type] port, self._factory, + timeout=30, + bindAddress=None, ) def get_streams(self) -> Dict[str, Stream]: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 80f9b23bfd..55326877fd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,6 +16,7 @@ import logging import random +from typing import TYPE_CHECKING from prometheus_client import Counter @@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) @@ -37,7 +41,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server.server_name @@ -65,7 +69,7 @@ class ReplicationStreamer: data is available it will propagate to all connected clients. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9b905aba9d..c8b188ae4e 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -241,7 +241,7 @@ class BackfillStream(Stream): NAME = "backfill" ROW_TYPE = BackfillStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -363,7 +363,7 @@ class ReceiptsStream(Stream): NAME = "receipts" ROW_TYPE = ReceiptsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -380,7 +380,7 @@ class PushRulesStream(Stream): NAME = "push_rules" ROW_TYPE = PushRulesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( @@ -405,7 +405,7 @@ class PushersStream(Stream): NAME = "pushers" ROW_TYPE = PushersStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( @@ -438,7 +438,7 @@ class CachesStream(Stream): NAME = "caches" ROW_TYPE = CachesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -459,7 +459,7 @@ class DeviceListsStream(Stream): NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -476,7 +476,7 @@ class ToDeviceStream(Stream): NAME = "to_device" ROW_TYPE = ToDeviceStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -582,7 +582,7 @@ class GroupServerStream(Stream): NAME = "groups" ROW_TYPE = GroupsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -599,7 +599,7 @@ class UserSignatureStream(Stream): NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index a6fa03c90f..80fbf32f17 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/server.py b/synapse/server.py index a64c846d1c..0fbf36ba99 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta): return ExternalCache(self) @cache_in_self - def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: - if not self.config.redis.redis_enabled: - return None + def get_outbound_redis_connection(self) -> "RedisProtocol": + """ + The Redis connection used for replication. + + Raises: + AssertionError: if Redis is not enabled in the homeserver config. + """ + assert self.config.redis.redis_enabled # We only want to import redis module if we're using it, as we have # `txredisapi` as an optional dependency. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f5a8f90a0f..fa4e89d35c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -19,6 +19,7 @@ from collections import defaultdict from sys import intern from time import monotonic as monotonic_time from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.server import HomeServer + # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 @@ -392,7 +396,7 @@ class DatabasePool: def __init__( self, - hs, + hs: "HomeServer", database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, ): diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 20b755056b..cfe887b7f7 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -13,33 +13,49 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar +from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class Databases: +DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True) + + +class Databases(Generic[DataStoreT]): """The various databases. These are low level interfaces to physical databases. Attributes: - main (DataStore) + databases + main + state + persist_events """ - def __init__(self, main_store_class, hs): + databases: List[DatabasePool] + main: DataStoreT + state: StateGroupDataStore + persist_events: Optional[PersistEventsStore] + + def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main # store. self.databases = [] - main = None - state = None - persist_events = None + main: Optional[DataStoreT] = None + state: Optional[StateGroupDataStore] = None + persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: db_name = database_config.name diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5c21402dea..259cae5b37 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool @@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -126,7 +129,7 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 70ca3e09f7..f8bec266ac 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -28,6 +28,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore): `get_max_account_data_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c57ae5ef15..36e8422fc6 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -15,7 +15,7 @@ import itertools import logging -from typing import Any, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 3154906d45..8143168107 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -26,11 +26,14 @@ from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 6464520386..a01bf2c5b7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,17 @@ # limitations under the License. import abc import logging -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( @@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ba9f71a230..ef5d1ef01e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Gauge @@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + oldest_pdu_in_federation_staging = Gauge( "synapse_federation_server_oldest_inbound_pdu_in_staging", "The age in seconds since we received the oldest pdu in the federation staging area", @@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 97b3e92d3f..d957e770dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr @@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 1afc59fafb..fc49112063 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr @@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -76,7 +79,7 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 2fa945d171..717487be28 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +if TYPE_CHECKING: + from synapse.server import HomeServer + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( "media_repository_drop_index_wo_method" ) @@ -43,7 +46,7 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index dac3d14da8..d901933ae4 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,7 +14,7 @@ import calendar import logging import time -from typing import Dict +from typing import TYPE_CHECKING, Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Collect metrics on the number of forward extremities that exist. @@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index a14ac03d4b..b5284e4f67 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp @@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fc720f5947..fa782023d4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,7 +14,7 @@ # limitations under the License. import abc import logging -from typing import Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules @@ -33,6 +33,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -75,7 +78,7 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 01a4281301..c99f8aebdb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from twisted.internet import defer @@ -29,11 +29,14 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 835d7889cb..f879bbe7c7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -17,7 +17,7 @@ import collections import logging from abc import abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -32,6 +32,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.stringutils import MXC_REGEX +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -69,7 +72,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ddb162a4fc..4b288bb2e7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.state import _StateCacheEntry logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index c85383c975..7fe233767f 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -15,7 +15,7 @@ import logging import re from collections import namedtuple -from typing import Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SearchEntry = namedtuple( @@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a8e8dd4577..fa2c3b1feb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -15,7 +15,7 @@ import collections.abc import logging from collections import namedtuple -from typing import Iterable, Optional, Set +from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -30,6 +30,9 @@ from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -53,7 +56,7 @@ class _GetStateGroupDelta( class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index e20033bb28..5d7b59d861 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing_extensions import Counter @@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # these fields track absolutes (e.g. total number of rooms on the server) @@ -93,7 +96,7 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 860146cd1b..d7dc1f73ac 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + db_binary_type = memoryview logger = logging.getLogger(__name__) @@ -57,7 +60,7 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0e8270746d..402f134d89 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -18,6 +18,7 @@ import itertools import logging from collections import deque from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -56,6 +57,9 @@ from synapse.types import ( from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # The number of times we are recalculating the current state @@ -272,7 +276,7 @@ class EventsPersistenceStorage: current state and forward extremity changes. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. |