From e07a8caf58fad2c56518560cfc31d90a761bd5a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jun 2020 14:13:41 +0100 Subject: Add support for using rust-python-jaeger-reporter (#7697) --- mypy.ini | 3 +++ 1 file changed, 3 insertions(+) (limited to 'mypy.ini') diff --git a/mypy.ini b/mypy.ini index 3533797d68..a61009b197 100644 --- a/mypy.ini +++ b/mypy.ini @@ -78,3 +78,6 @@ ignore_missing_imports = True [mypy-authlib.*] ignore_missing_imports = True + +[mypy-rust_python_jaeger_reporter.*] +ignore_missing_imports = True -- cgit 1.4.1 From 5dd73d029eff32668b3ca69b7fb8529fc7c58745 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2020 15:05:50 +0100 Subject: Add type hints to handlers.message and events.builder (#8067) --- changelog.d/8067.misc | 1 + mypy.ini | 3 ++ synapse/events/builder.py | 58 ++++++++++++++++++++----------------- synapse/handlers/message.py | 22 ++++++++------ synapse/handlers/room_member.py | 12 +++++--- tests/rest/client/test_retention.py | 4 ++- tox.ini | 2 ++ 7 files changed, 61 insertions(+), 41 deletions(-) create mode 100644 changelog.d/8067.misc (limited to 'mypy.ini') diff --git a/changelog.d/8067.misc b/changelog.d/8067.misc new file mode 100644 index 0000000000..f4404b7506 --- /dev/null +++ b/changelog.d/8067.misc @@ -0,0 +1 @@ +Add type hints to `synapse.handlers.message` and `synapse.events.builder`. diff --git a/mypy.ini b/mypy.ini index a61009b197..c69cb5dc40 100644 --- a/mypy.ini +++ b/mypy.ini @@ -81,3 +81,6 @@ ignore_missing_imports = True [mypy-rust_python_jaeger_reporter.*] ignore_missing_imports = True + +[mypy-nacl.*] +ignore_missing_imports = True diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4e179d49b3..9ed24380dd 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,6 +17,7 @@ from typing import Optional import attr from nacl.signing import SigningKey +from synapse.api.auth import Auth from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -27,6 +28,8 @@ from synapse.api.room_versions import ( ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -42,45 +45,46 @@ class EventBuilder(object): Attributes: room_version: Version of the target room - room_id (str) - type (str) - sender (str) - content (dict) - unsigned (dict) - internal_metadata (_EventInternalMetadata) - - _state (StateHandler) - _auth (synapse.api.Auth) - _store (DataStore) - _clock (Clock) - _hostname (str): The hostname of the server creating the event + room_id + type + sender + content + unsigned + internal_metadata + + _state + _auth + _store + _clock + _hostname: The hostname of the server creating the event _signing_key: The signing key to use to sign the event as the server """ - _state = attr.ib() - _auth = attr.ib() - _store = attr.ib() - _clock = attr.ib() - _hostname = attr.ib() - _signing_key = attr.ib() + _state = attr.ib(type=StateHandler) + _auth = attr.ib(type=Auth) + _store = attr.ib(type=DataStore) + _clock = attr.ib(type=Clock) + _hostname = attr.ib(type=str) + _signing_key = attr.ib(type=SigningKey) room_version = attr.ib(type=RoomVersion) - room_id = attr.ib() - type = attr.ib() - 
sender = attr.ib() + room_id = attr.ib(type=str) + type = attr.ib(type=str) + sender = attr.ib(type=str) - content = attr.ib(default=attr.Factory(dict)) - unsigned = attr.ib(default=attr.Factory(dict)) + content = attr.ib(default=attr.Factory(dict), type=JsonDict) + unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict) # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key = attr.ib(default=None) - _redacts = attr.ib(default=None) - _origin_server_ts = attr.ib(default=None) + _state_key = attr.ib(default=None, type=Optional[str]) + _redacts = attr.ib(default=None, type=Optional[str]) + _origin_server_ts = attr.ib(default=None, type=Optional[int]) internal_metadata = attr.ib( - default=attr.Factory(lambda: _EventInternalMetadata({})) + default=attr.Factory(lambda: _EventInternalMetadata({})), + type=_EventInternalMetadata, ) @property diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8ddded8389..2643438e84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json, json @@ -93,11 +93,11 @@ class MessageHandler(object): async def get_room_data( self, - user_id: str = None, - room_id: str = None, - event_type: Optional[str] = None, - state_key: str = "", - is_guest: bool = False, + user_id: str, + room_id: str, + event_type: str, + state_key: str, + is_guest: bool, ) -> dict: """ Get data from a room. @@ -407,7 +407,7 @@ class EventCreationHandler(object): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are @@ -707,7 +707,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester: Requester, - event_dict: EventBase, + event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: @@ -971,7 +971,7 @@ class EventCreationHandler(object): # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = set() + original_alt_aliases = [] # type: List[str] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -1019,6 +1019,10 @@ class EventCreationHandler(object): current_state_ids = await context.get_current_state_ids() + # We know this event is not an outlier, so this must be + # non-None. 
+ assert current_state_ids is not None + state_to_include_ids = [ e_id for k, e_id in current_state_ids.items() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8e409f24e8..31705cdbdb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -37,6 +37,10 @@ from synapse.util.distributor import user_joined_room, user_left_room from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -207,7 +211,7 @@ class RoomMemberHandler(object): return duplicate.event_id, stream_id stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit + requester, event, context, extra_users=[target], ratelimit=ratelimit, ) prev_state_ids = await context.get_prev_state_ids() @@ -1000,7 +1004,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): check_complexity = self.hs.config.limit_remote_rooms.enabled if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: - check_complexity = not await self.hs.auth.is_server_admin(user) + check_complexity = not await self.auth.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e54ffea150..0b191d13c6 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -144,7 +144,9 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( - message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + message_handler.get_room_data( + self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + ) ) # Send a first event to the room. This is the event we'll want to be purged at the diff --git a/tox.ini b/tox.ini index 45e129580f..e5413eb110 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ commands = mypy \ synapse/appservice \ synapse/config \ synapse/event_auth.py \ + synapse/events/builder.py \ synapse/events/spamcheck.py \ synapse/federation \ synapse/handlers/auth.py \ @@ -186,6 +187,7 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/federation.py \ synapse/handlers/identity.py \ + synapse/handlers/message.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/room_member.py \ -- cgit 1.4.1 From 98125bba7a63f34bf623fdef3902f2e4ab7c1231 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 09:59:37 -0400 Subject: Allow running mypy directly. 
(#8175) --- changelog.d/8175.misc | 1 + mypy.ini | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ tox.ini | 52 +-------------------------------------------------- 3 files changed, 51 insertions(+), 51 deletions(-) create mode 100644 changelog.d/8175.misc (limited to 'mypy.ini') diff --git a/changelog.d/8175.misc b/changelog.d/8175.misc new file mode 100644 index 0000000000..28af294dcf --- /dev/null +++ b/changelog.d/8175.misc @@ -0,0 +1 @@ +Standardize the mypy configuration. diff --git a/mypy.ini b/mypy.ini index c69cb5dc40..4213e31b03 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,55 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +files = + synapse/api, + synapse/appservice, + synapse/config, + synapse/event_auth.py, + synapse/events/builder.py, + synapse/events/spamcheck.py, + synapse/federation, + synapse/handlers/auth.py, + synapse/handlers/cas_handler.py, + synapse/handlers/directory.py, + synapse/handlers/federation.py, + synapse/handlers/identity.py, + synapse/handlers/message.py, + synapse/handlers/oidc_handler.py, + synapse/handlers/presence.py, + synapse/handlers/room.py, + synapse/handlers/room_member.py, + synapse/handlers/room_member_worker.py, + synapse/handlers/saml_handler.py, + synapse/handlers/sync.py, + synapse/handlers/ui_auth, + synapse/http/server.py, + synapse/http/site.py, + synapse/logging/, + synapse/metrics, + synapse/module_api, + synapse/notifier.py, + synapse/push/pusherpool.py, + synapse/push/push_rule_evaluator.py, + synapse/replication, + synapse/rest, + synapse/server.py, + synapse/server_notices, + synapse/spam_checker_api, + synapse/state, + synapse/storage/databases/main/ui_auth.py, + synapse/storage/database.py, + synapse/storage/engines, + synapse/storage/state.py, + synapse/storage/util, + synapse/streams, + synapse/types.py, + synapse/util/caches/stream_change_cache.py, + synapse/util/metrics.py, + tests/replication, + tests/test_utils, + tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] ignore_missing_imports = True diff --git a/tox.ini b/tox.ini index edeb757f7b..df473bd234 100644 --- a/tox.ini +++ b/tox.ini @@ -171,58 +171,8 @@ deps = {[base]deps} mypy==0.782 mypy-zope -env = - MYPYPATH = stubs/ extras = all -commands = mypy \ - synapse/api \ - synapse/appservice \ - synapse/config \ - synapse/event_auth.py \ - synapse/events/builder.py \ - synapse/events/spamcheck.py \ - synapse/federation \ - synapse/handlers/auth.py \ - synapse/handlers/cas_handler.py \ - synapse/handlers/directory.py \ - synapse/handlers/federation.py \ - synapse/handlers/identity.py \ - synapse/handlers/message.py \ - synapse/handlers/oidc_handler.py \ - synapse/handlers/presence.py \ - synapse/handlers/room.py \ - synapse/handlers/room_member.py \ - synapse/handlers/room_member_worker.py \ - synapse/handlers/saml_handler.py \ - synapse/handlers/sync.py \ - synapse/handlers/ui_auth \ - synapse/http/server.py \ - synapse/http/site.py \ - synapse/logging/ \ - synapse/metrics \ - synapse/module_api \ - synapse/notifier.py \ - synapse/push/pusherpool.py \ - synapse/push/push_rule_evaluator.py \ - synapse/replication \ - synapse/rest \ - synapse/server.py \ - synapse/server_notices \ - synapse/spam_checker_api \ - synapse/state \ - synapse/storage/databases/main/ui_auth.py \ - synapse/storage/database.py \ - synapse/storage/engines \ - synapse/storage/state.py \ - synapse/storage/util \ - synapse/streams \ - synapse/types.py \ - synapse/util/caches/stream_change_cache.py \ - 
synapse/util/metrics.py \ - tests/replication \ - tests/test_utils \ - tests/rest/client/v2_alpha/test_auth.py \ - tests/util/test_stream_change_cache.py +commands = mypy # To find all folders that pass mypy you run: # -- cgit 1.4.1 From 5bf8e5f55b49f9e46a7fe7d7872e6b16d38bffd3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:15:22 -0400 Subject: Convert the well known resolver to async (#8214) --- changelog.d/8214.misc | 1 + mypy.ini | 1 + synapse/http/federation/matrix_federation_agent.py | 4 +- synapse/http/federation/well_known_resolver.py | 57 ++++++++++++---------- .../federation/test_matrix_federation_agent.py | 24 ++++++--- 5 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8214.misc (limited to 'mypy.ini') diff --git a/changelog.d/8214.misc b/changelog.d/8214.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8214.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/mypy.ini b/mypy.ini index 4213e31b03..21c6f523a0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,7 @@ files = synapse/handlers/saml_handler.py, synapse/handlers/sync.py, synapse/handlers/ui_auth, + synapse/http/federation/well_known_resolver.py, synapse/http/server.py, synapse/http/site.py, synapse/logging/, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 369bf9c2fc..782d39d4ca 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -134,8 +134,8 @@ class MatrixFederationAgent(object): and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.hostname + well_known_result = yield defer.ensureDeferred( + self._well_known_resolver.get_well_known(parsed_uri.hostname) ) delegated_server = well_known_result.delegated_server diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index f794315deb..cdb6bec56e 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -16,6 +16,7 @@ import logging import random import time +from typing import Callable, Dict, Optional, Tuple import attr @@ -23,6 +24,7 @@ from twisted.internet import defer from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -99,15 +101,14 @@ class WellKnownResolver(object): self._well_known_agent = RedirectAgent(agent) self.user_agent = user_agent - @defer.inlineCallbacks - def get_well_known(self, server_name): + async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: """Attempt to fetch and parse a .well-known file for the given server Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Returns: - Deferred[WellKnownLookupResult]: The result of the lookup + The result of the lookup """ try: prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( @@ -124,7 +125,9 @@ class WellKnownResolver(object): # requests for the same server in parallel? 
try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._fetch_well_known(server_name) + result, cache_period = await self._fetch_well_known( + server_name + ) # type: Tuple[Optional[bytes], float] except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -153,18 +156,17 @@ class WellKnownResolver(object): return WellKnownLookupResult(delegated_server=result) - @defer.inlineCallbacks - def _fetch_well_known(self, server_name): + async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]: """Actually fetch and parse a .well-known, without checking the cache Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Raises: _FetchWellKnownFailure if we fail to lookup a result Returns: - Deferred[Tuple[bytes,int]]: The lookup result and cache period. + The lookup result and cache period. """ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) @@ -172,7 +174,7 @@ class WellKnownResolver(object): # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - response, body = yield self._make_well_known_request( + response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known ) @@ -215,20 +217,20 @@ class WellKnownResolver(object): return result, cache_period - @defer.inlineCallbacks - def _make_well_known_request(self, server_name, retry): + async def _make_well_known_request( + self, server_name: bytes, retry: bool + ) -> Tuple[IResponse, bytes]: """Make the well known request. This will retry the request if requested and it fails (with unable to connect or receives a 5xx error). Args: - server_name (bytes) - retry (bool): Whether to retry the request if it fails. + server_name: name of the server, from the requested url + retry: Whether to retry the request if it fails. Returns: - Deferred[tuple[IResponse, bytes]] Returns the response object and - body. Response may be a non-200 response. + Returns the response object and body. Response may be a non-200 response. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") @@ -243,12 +245,12 @@ class WellKnownResolver(object): logger.info("Fetching %s", uri_str) try: - response = yield make_deferred_yieldable( + response = await make_deferred_yieldable( self._well_known_agent.request( b"GET", uri, headers=Headers(headers) ) ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -265,21 +267,24 @@ class WellKnownResolver(object): logger.info("Error fetching %s: %s. 
Retrying", uri_str, e) # Sleep briefly in the hopes that they come back up - yield self._clock.sleep(0.5) + await self._clock.sleep(0.5) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + max_age = cache_controls[b"max-age"] + if max_age: + try: + return int(max_age) + except ValueError: + pass expires = headers.getRawHeaders(b"expires") if expires is not None: @@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 69945a8f98..eb78ab412a 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. 
- fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) -- cgit 1.4.1 From 112266eafd457204a34a76fa51d7074d0809a1db Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Sep 2020 17:52:38 +0100 Subject: Add StreamStore to mypy (#8232) --- changelog.d/8232.misc | 1 + mypy.ini | 1 + synapse/events/__init__.py | 4 +-- synapse/storage/database.py | 34 +++++++++++++++++++++++ synapse/storage/databases/main/stream.py | 46 +++++++++++++++++++------------- 5 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 changelog.d/8232.misc (limited to 'mypy.ini') diff --git a/changelog.d/8232.misc b/changelog.d/8232.misc new file mode 100644 index 0000000000..3a7a352c4f --- /dev/null +++ b/changelog.d/8232.misc @@ -0,0 +1 @@ +Add type hints to `StreamStore`. diff --git a/mypy.ini b/mypy.ini index 21c6f523a0..ae3290d5bb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -43,6 +43,7 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 67db763dbf..62ea44fa49 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,7 +18,7 @@ import abc import os from distutils.util import strtobool -from typing import Dict, Optional, Type +from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -120,7 +120,7 @@ class _EventInternalMetadata(object): # be here before = DictProperty("before") # type: str after = DictProperty("after") # type: str - order = DictProperty("order") # type: int + order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: return dict(self._dict) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7ab370efef..af8796ad92 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -604,6 +604,18 @@ class DatabasePool(object): results = [dict(zip(col_headers, row)) for row in cursor] return results + @overload + async def execute( + self, desc: str, decoder: Literal[None], query: str, *args: Any + ) -> List[Tuple[Any, ...]]: + ... + + @overload + async def execute( + self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any + ) -> R: + ... + async def execute( self, desc: str, @@ -1088,6 +1100,28 @@ class DatabasePool(object): desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) + @overload + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one_onecol", + ) -> Any: + ... + + @overload + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one_onecol", + ) -> Optional[Any]: + ... 
+ async def simple_select_one_onecol( self, table: str, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 83c1ddf95a..be6df8a6d1 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,7 +39,7 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from twisted.internet import defer @@ -54,9 +54,12 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import RoomStreamToken +from synapse.types import Collection, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -206,7 +209,7 @@ def _make_generic_sql_bound( ) -def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: +def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super(StreamWorkerStore, self).__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() async def get_room_events_stream_for_rooms( self, - room_ids: Iterable[str], + room_ids: Collection[str], from_key: str, to_key: str, limit: int = 0, @@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results - def get_rooms_that_changed(self, room_ids, from_key): + def get_rooms_that_changed( + self, room_ids: Collection[str], from_key: str + ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. 
Args: - room_ids (list) - from_key (str): The room_key portion of a StreamToken + room_ids + from_key: The room_key portion of a StreamToken """ - from_key = RoomStreamToken.parse_stream_token(from_key).stream + from_id = RoomStreamToken.parse_stream_token(from_key).stream return { room_id for room_id in room_ids - if self._events_stream_cache.has_entity_changed(room_id, from_key) + if self._events_stream_cache.has_entity_changed(room_id, from_id) } async def get_room_events_stream_for_room( @@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - async def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user( + self, user_id: str, from_key: str, to_key: str + ) -> List[EventBase]: from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return row[0][0] if row else 0 - def _get_max_topological_txn(self, txn, room_id): + def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int: txn.execute( "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id,), @@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_events_around_txn( self, - txn, + txn: LoggingTransaction, room_id: str, event_id: str, before_limit: int, @@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=["stream_ordering", "topological_ordering"], ) + # This cannot happen as `allow_none=False`. + assert results is not None + # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( @@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn) -> None: + def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ @@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): GROUP BY type """ txn.execute(sql) - min_positions = dict(txn) # Map from type -> min position + min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position # Ensure we do actually have some values here assert set(min_positions) == {"federation", "events"} @@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _paginate_room_events_txn( self, - txn, + txn: LoggingTransaction, room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, -- cgit 1.4.1 From 208e1d3eb345dca12e25696e30cee7e788b65ae2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Sep 2020 15:38:32 +0100 Subject: Fix typing for `@cached` wrapped functions (#8240) This requires adding a mypy plugin to fiddle with the type signatures a bit. 
--- changelog.d/8240.misc | 1 + mypy.ini | 3 +- scripts-dev/mypy_synapse_plugin.py | 85 ++++++++++++++++++++++++++++++++++++++ synapse/handlers/federation.py | 10 ++--- synapse/util/caches/descriptors.py | 42 ++++++++++++------- 5 files changed, 121 insertions(+), 20 deletions(-) create mode 100644 changelog.d/8240.misc create mode 100644 scripts-dev/mypy_synapse_plugin.py (limited to 'mypy.ini') diff --git a/changelog.d/8240.misc b/changelog.d/8240.misc new file mode 100644 index 0000000000..acfbd89e24 --- /dev/null +++ b/changelog.d/8240.misc @@ -0,0 +1 @@ +Fix type hints for functions decorated with `@cached`. diff --git a/mypy.ini b/mypy.ini index ae3290d5bb..8a351eabfe 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ [mypy] namespace_packages = True -plugins = mypy_zope:plugin +plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py follow_imports = silent check_untyped_defs = True show_error_codes = True @@ -51,6 +51,7 @@ files = synapse/storage/util, synapse/streams, synapse/types.py, + synapse/util/caches/descriptors.py, synapse/util/caches/stream_change_cache.py, synapse/util/metrics.py, tests/replication, diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py new file mode 100644 index 0000000000..a5b88731f1 --- /dev/null +++ b/scripts-dev/mypy_synapse_plugin.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This is a mypy plugin for Synpase to deal with some of the funky typing that +can crop up, e.g the cache descriptors. +""" + +from typing import Callable, Optional + +from mypy.plugin import MethodSigContext, Plugin +from mypy.typeops import bind_self +from mypy.types import CallableType + + +class SynapsePlugin(Plugin): + def get_method_signature_hook( + self, fullname: str + ) -> Optional[Callable[[MethodSigContext], CallableType]]: + if fullname.startswith( + "synapse.util.caches.descriptors._CachedFunction.__call__" + ): + return cached_function_method_signature + return None + + +def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: + """Fixes the `_CachedFunction.__call__` signature to be correct. + + It already has *almost* the correct signature, except: + + 1. the `self` argument needs to be marked as "bound"; and + 2. any `cache_context` argument should be removed. + """ + + # First we mark this as a bound function signature. + signature = bind_self(ctx.default_signature) + + # Secondly, we remove any "cache_context" args. + # + # Note: We should be only doing this if `cache_context=True` is set, but if + # it isn't then the code will raise an exception when its called anyway, so + # its not the end of the world. 
+ context_arg_index = None + for idx, name in enumerate(signature.arg_names): + if name == "cache_context": + context_arg_index = idx + break + + if context_arg_index: + arg_types = list(signature.arg_types) + arg_types.pop(context_arg_index) + + arg_names = list(signature.arg_names) + arg_names.pop(context_arg_index) + + arg_kinds = list(signature.arg_kinds) + arg_kinds.pop(context_arg_index) + + signature = signature.copy_modified( + arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + ) + + return signature + + +def plugin(version: str): + # This is the entry point of the plugin, and let's us deal with the fact + # that the mypy plugin interface is *not* stable by looking at the version + # string. + # + # However, since we pin the version of mypy Synapse uses in CI, we don't + # really care. + return SynapsePlugin diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bd8efbb768..310c7f7138 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -440,11 +440,11 @@ class FederationHandler(BaseHandler): if not prevs - seen: return - latest = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest) + latest = set(latest_list) latest |= seen logger.info( @@ -781,7 +781,7 @@ class FederationHandler(BaseHandler): # keys across all devices. current_keys = [ key - for device in cached_devices + for device in cached_devices.values() for key in device.get("keys", {}).get("keys", {}).values() ] @@ -2119,8 +2119,8 @@ class FederationHandler(BaseHandler): if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids) + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 49d9fddcf0..825810eb16 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,11 +18,10 @@ import functools import inspect import logging import threading -from typing import Any, Tuple, Union, cast +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary from prometheus_client import Gauge -from typing_extensions import Protocol from twisted.internet import defer @@ -38,8 +37,10 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] +F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Protocol): + +class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any invalidate_many = None # type: Any @@ -47,8 +48,11 @@ class _CachedFunction(Protocol): cache = None # type: Any num_args = None # type: Any - def __name__(self): - ... + __name__ = None # type: str + + # Note: This function signature is actually fiddled with by the synapse mypy + # plugin to a) make it a bound method, and b) remove any `cache_context` arg. 
+ __call__ = None # type: F cache_pending_metric = Gauge( @@ -123,7 +127,7 @@ class Cache(object): self.name = name self.keylen = keylen - self.thread = None + self.thread = None # type: Optional[threading.Thread] self.metrics = register_cache( "cache", name, @@ -662,9 +666,13 @@ class _CacheContext: def cached( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, +) -> Callable[[F], _CachedFunction[F]]: + func = lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -673,8 +681,12 @@ def cached( iterable=iterable, ) + return cast(Callable[[F], _CachedFunction[F]], func) -def cachedList(cached_method_name, list_name, num_args=None): + +def cachedList( + cached_method_name: str, list_name: str, num_args: Optional[int] = None +) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None): cache. Args: - cached_method_name (str): The name of the single-item lookup method. + cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name (str): The name of the argument that is the list to use to + list_name: The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache + num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. Example: @@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None): def batch_do_something(self, first_arg, second_args): ... """ - return lambda orig: CacheListDescriptor( + func = lambda orig: CacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, ) + + return cast(Callable[[F], _CachedFunction[F]], func) -- cgit 1.4.1 From be16ee59a87723c2da164f56dc2274ae3ac3e438 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Sep 2020 22:02:29 +0100 Subject: Add type hints to more handlers (#8244) --- changelog.d/8244.misc | 1 + mypy.ini | 3 ++ synapse/handlers/events.py | 49 ++++++++++++------------ synapse/handlers/initial_sync.py | 80 ++++++++++++++++++++++++---------------- synapse/handlers/pagination.py | 56 ++++++++++++++++------------ 5 files changed, 110 insertions(+), 79 deletions(-) create mode 100644 changelog.d/8244.misc (limited to 'mypy.ini') diff --git a/changelog.d/8244.misc b/changelog.d/8244.misc new file mode 100644 index 0000000000..e650072223 --- /dev/null +++ b/changelog.d/8244.misc @@ -0,0 +1 @@ +Add type hints to pagination, initial sync and events handlers. 
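The handler diffs below all apply the same annotation pattern: `HomeServer` is imported only under `TYPE_CHECKING`, so the type is visible to mypy while the runtime import cycle between `synapse.server` and the handler modules is avoided, and the annotation is written as the string `"HomeServer"` so it is never evaluated at import time. A minimal sketch of the pattern (`ExampleHandler` is a hypothetical name):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checking only; importing this at runtime would
    # create a cycle between synapse.server and the handler modules.
    from synapse.server import HomeServer


class ExampleHandler:
    def __init__(self, hs: "HomeServer"):
        # The quoted annotation is resolved by mypy but never at runtime.
        self.hs = hs
        self.store = hs.get_datastore()
```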
diff --git a/mypy.ini b/mypy.ini index 8a351eabfe..7764f17856 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,10 +17,13 @@ files = synapse/handlers/auth.py, synapse/handlers/cas_handler.py, synapse/handlers/directory.py, + synapse/handlers/events.py, synapse/handlers/federation.py, synapse/handlers/identity.py, + synapse/handlers/initial_sync.py, synapse/handlers/message.py, synapse/handlers/oidc_handler.py, + synapse/handlers/pagination.py, synapse/handlers/presence.py, synapse/handlers/room.py, synapse/handlers/room_member.py, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1924636c4d..b05e32f457 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,29 +15,30 @@ import logging import random +from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function -from synapse.types import UserID +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventStreamHandler, self).__init__(hs) - # Count of active streams per user - self._streams_per_user = {} - # Grace timers per user to delay the "stopped" signal - self._stop_timer_per_user = {} - self.distributor = hs.get_distributor() self.distributor.declare("started_user_eventstream") self.distributor.declare("stopped_user_eventstream") @@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler): @log_function async def get_stream( self, - auth_user_id, - pagin_config, - timeout=0, - as_client_event=True, - affect_presence=True, - room_id=None, - is_guest=False, - ): + auth_user_id: str, + pagin_config: PaginationConfig, + timeout: int = 0, + as_client_event: bool = True, + affect_presence: bool = True, + room_id: Optional[str] = None, + is_guest: bool = False, + ) -> JsonDict: """Fetches the events stream for a given user. """ @@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler): # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] + to_add = [] # type: List[JsonDict] for event in events: if not isinstance(event, EventBase): continue @@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler): # Send down presence for everyone in the room. users = await self.state.get_current_users_in_room( event.room_id - ) + ) # type: Iterable[str] else: users = [event.state_key] @@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - async def get_event(self, user, room_id, event_id): + async def get_event( + self, user: UserID, room_id: Optional[str], event_id: str + ) -> Optional[EventBase]: """Retrieve a single specified event. Args: - user (synapse.types.UserID): The user requesting the event - room_id (str|None): The expected room id. We'll return None if the + user: The user requesting the event + room_id: The expected room id. 
We'll return None if the event's room does not match. - event_id (str): The event ID to obtain. + event_id: The event ID to obtain. Returns: - dict: An event, or None if there is no event matching this ID. + An event, or None if there is no event matching this ID. Raises: SynapseError if there was a problem retrieving this event, or AuthError if the user does not have the rights to inspect this diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index ae6bd1d352..d5ddc583ad 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.internet import defer @@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken, UserID +from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(InitialSyncHandler, self).__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler): def snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is joined, depending on the pagination config. Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left + user_id: The ID of the user making the request. + pagin_config: The pagination config used to determine how many + messages *PER ROOM* to return. + as_client_event: True to get events in client-server format. + include_archived: True to get rooms that the user has left Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. 
+ A JsonDict with the same format as the response to `/intialSync` + API """ key = ( user_id, @@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler): async def _snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - async def handle_room(event): + async def handle_room(event: RoomsForUser): d = { "room_id": event.room_id, "membership": event.membership, @@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler): return ret - async def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync( + self, requester: Requester, room_id: str, pagin_config: PaginationConfig + ) -> JsonDict: """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. + requester: The user to get a snapshot for. + room_id: The room to get a snapshot of. + pagin_config: The pagination config used to determine how many + messages to return. Raises: AuthError if the user wasn't in the room. Returns: @@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler): return result async def _room_initial_sync_parted( - self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + member_event_id: str, + is_peeking: bool, + ) -> JsonDict: room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler): } async def _room_initial_sync_joined( - self, user_id, room_id, pagin_config, membership, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + is_peeking: bool, + ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5a1aa7d830..63d7edff87 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Set from twisted.python.failure import Failure @@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -68,7 +72,7 @@ class PaginationHandler(object): paginating during a purge. 
""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -78,9 +82,9 @@ class PaginationHandler(object): self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() + self._purges_in_progress_by_room = set() # type: Set[str] # map from purge id to PurgeStatus - self._purges_by_id = {} + self._purges_by_id = {} # type: Dict[str, PurgeStatus] self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime @@ -102,7 +106,9 @@ class PaginationHandler(object): job["longest_max_lifetime"], ) - async def purge_history_for_rooms_in_range(self, min_ms, max_ms): + async def purge_history_for_rooms_in_range( + self, min_ms: Optional[int], max_ms: Optional[int] + ): """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -110,10 +116,10 @@ class PaginationHandler(object): retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, it means that the range has no lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, it means that the range has no upper limit. """ @@ -220,18 +226,19 @@ class PaginationHandler(object): "_purge_history", self._purge_history, purge_id, room_id, token, True, ) - def start_purge_history(self, room_id, token, delete_local_events=False): + def start_purge_history( + self, room_id: str, token: str, delete_local_events: bool = False + ) -> str: """Start off a history purge on a room. Args: - room_id (str): The room to purge from - - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones Returns: - str: unique ID for this purge transaction. + unique ID for this purge transaction. """ if room_id in self._purges_in_progress_by_room: raise SynapseError( @@ -284,14 +291,11 @@ class PaginationHandler(object): self.hs.get_reactor().callLater(24 * 3600, clear_purge) - def get_purge_status(self, purge_id): + def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge Args: - purge_id (str): purge_id returned by start_purge_history - - Returns: - PurgeStatus|None + purge_id: purge_id returned by start_purge_history """ return self._purges_by_id.get(purge_id) @@ -312,8 +316,8 @@ class PaginationHandler(object): async def get_messages( self, requester: Requester, - room_id: Optional[str] = None, - pagin_config: Optional[PaginationConfig] = None, + room_id: str, + pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, ) -> Dict[str, Any]: @@ -368,11 +372,15 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. + + # This is only None if the room is world_readable, in which + # case "JOIN" would have been returned. 
+                assert member_event_id
+
                 leave_token = await self.store.get_topological_token_for_event(
                     member_event_id
                 )
-                leave_token = RoomStreamToken.parse(leave_token)
-                if leave_token.topological < max_topo:
+                if RoomStreamToken.parse(leave_token).topological < max_topo:
                     source_config.from_key = str(leave_token)

             await self.hs.get_handlers().federation_handler.maybe_backfill(
@@ -419,8 +427,8 @@ class PaginationHandler(object):
             )

             if state_ids:
-                state = await self.store.get_events(list(state_ids.values()))
-                state = state.values()
+                state_dict = await self.store.get_events(list(state_ids.values()))
+                state = state_dict.values()

         time_now = self.clock.time_msec()
-- cgit 1.4.1
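A closing note on the `simple_select_one_onecol` overloads added in the `StreamStore` commit above: `Literal` lets mypy narrow the return type from the value of `allow_none` at each call site. A self-contained sketch of the same technique (`select_one` and `_TABLE` are illustrative stand-ins, not Synapse code):

```python
from typing import Dict, Optional, overload

from typing_extensions import Literal

_TABLE: Dict[str, str] = {"room_version": "5"}  # stand-in for a database row


@overload
def select_one(key: str, allow_none: Literal[False] = False) -> str:
    ...


@overload
def select_one(key: str, allow_none: Literal[True] = True) -> Optional[str]:
    ...


def select_one(key: str, allow_none: bool = False) -> Optional[str]:
    value = _TABLE.get(key)
    if value is None and not allow_none:
        raise KeyError(key)
    return value


# mypy infers `str` for the default call, and `Optional[str]` (so callers
# must handle None) when allow_none=True is passed explicitly.
version = select_one("room_version")
maybe = select_one("missing", allow_none=True)
```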