From 3857de2194e3b2057c4af71e095eb6759508f25f Mon Sep 17 00:00:00 2001 From: lugino-emeritus Date: Tue, 28 Jul 2020 14:41:44 +0200 Subject: Option to allow server admins to join complex rooms (#7902) Fixes #7901. Signed-off-by: Niklas Tittjung --- synapse/handlers/room_member.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'synapse/handlers/room_member.py') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a1a8fa1d3b..5a40e8c144 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -952,7 +952,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - if self.hs.config.limit_remote_rooms.enabled: + check_complexity = self.hs.config.limit_remote_rooms.enabled + if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = not await self.hs.auth.is_server_admin(user) + + if check_complexity: # Fetch the room complexity too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts @@ -975,7 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. - if self.hs.config.limit_remote_rooms.enabled: + if check_complexity: if too_complex is False: # We checked, and we're under the limit. return event_id, stream_id -- cgit 1.5.1 From 0a7fb2471685cf4bccfd76f51ddd614a3aeb4536 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 30 Jul 2020 16:58:57 +0100 Subject: Fix invite rejection when we have no forward-extremeties (#7980) Thanks to some slightly overzealous cleanup in the `delete_old_current_state_events`, it's possible to end up with no `event_forward_extremities` in a room where we have outstanding local invites. The user would then get a "no create event in auth events" when trying to reject the invite. We can hack around it by using the dangling invite as the prev event. --- changelog.d/7980.bugfix | 1 + synapse/handlers/room_member.py | 29 +++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7980.bugfix (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/7980.bugfix b/changelog.d/7980.bugfix new file mode 100644 index 0000000000..fa351b4b77 --- /dev/null +++ b/changelog.d/7980.bugfix @@ -0,0 +1 @@ +Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5a40e8c144..78586a0a1e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -469,26 +469,39 @@ class RoomMemberHandler(object): user_id=target.to_string(), room_id=room_id ) # type: Optional[RoomsForUser] if not invite: + logger.info( + "%s sent a leave request to %s, but that is not an active room " + "on this server, and there is no pending invite", + target, + room_id, + ) + raise SynapseError(404, "Not a known room") logger.info( "%s rejects invite to %s from %s", target, room_id, invite.sender ) - if self.hs.is_mine_id(invite.sender): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: + if not self.hs.is_mine_id(invite.sender): # send the rejection to the inviter's HS (with fallback to # local event) return await self.remote_reject_invite( invite.event_id, txn_id, requester, content, ) + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + return await self._local_membership_update( requester=requester, target=target, -- cgit 1.5.1 From 18de00adb4471a55b504f4afb9f29facf0a51785 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 31 Jul 2020 14:34:42 +0100 Subject: Add ratelimiting on joins --- docs/sample_config.yaml | 12 ++++++++++++ synapse/config/ratelimiting.py | 21 +++++++++++++++++++++ synapse/handlers/room_member.py | 37 +++++++++++++++++++++++++++++++++++-- tests/utils.py | 4 ++++ 4 files changed, 72 insertions(+), 2 deletions(-) (limited to 'synapse/handlers/room_member.py') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b21e36bb6d..fef503479e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -731,6 +731,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. +# - two for ratelimiting number of rooms a user can join, "local" for when +# users are joining rooms the server is already in (this is cheap) vs +# "remote" for when users are trying to join rooms not on the server (which +# can be more expensive) # # The defaults are as shown below. # @@ -756,6 +760,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_admin_redaction: # per_second: 1 # burst_count: 50 +# +#rc_joins: +# local: +# per_second: 0.1 +# burst_count: 3 +# remote: +# per_second: 0.01 +# burst_count: 3 # Ratelimiting settings for incoming federation diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 2dd94bae2b..b2c78ac40c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -93,6 +93,15 @@ class RatelimitConfig(Config): if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) + self.rc_joins_local = RateLimitConfig( + config.get("rc_joins", {}).get("local", {}), + defaults={"per_second": 0.1, "burst_count": 3}, + ) + self.rc_joins_remote = RateLimitConfig( + config.get("rc_joins", {}).get("remote", {}), + defaults={"per_second": 0.01, "burst_count": 3}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -118,6 +127,10 @@ class RatelimitConfig(Config): # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. + # - two for ratelimiting number of rooms a user can join, "local" for when + # users are joining rooms the server is already in (this is cheap) vs + # "remote" for when users are trying to join rooms not on the server (which + # can be more expensive) # # The defaults are as shown below. # @@ -143,6 +156,14 @@ class RatelimitConfig(Config): #rc_admin_redaction: # per_second: 1 # burst_count: 50 + # + #rc_joins: + # local: + # per_second: 0.1 + # burst_count: 3 + # remote: + # per_second: 0.01 + # burst_count: 3 # Ratelimiting settings for incoming federation diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a1a8fa1d3b..822ca9da6a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -22,7 +22,8 @@ from unpaddedbase64 import encode_base64 from synapse import types from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import EventFormatVersions from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase @@ -77,6 +78,17 @@ class RoomMemberHandler(object): if self._is_on_event_persistence_instance: self.persist_event_storage = hs.get_storage().persistence + self._join_rate_limiter_local = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, + burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + ) + self._join_rate_limiter_remote = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, + burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -441,7 +453,28 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - if not is_host_in_room: + if is_host_in_room: + time_now_s = self.clock.time() + allowed, time_allowed = self._join_rate_limiter_local.can_do_action( + requester.user.to_string(), + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + + else: + time_now_s = self.clock.time() + allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( + requester.user.to_string(), + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) diff --git a/tests/utils.py b/tests/utils.py index ac643679aa..a8e85436f9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -154,6 +154,10 @@ def default_config(name, parse=False): "account": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins": { + "local": {"per_second": 10000, "burst_count": 10000}, + "remote": {"per_second": 10000, "burst_count": 10000}, + }, "saml2_enabled": False, "public_baseurl": None, "default_identity_server": None, -- cgit 1.5.1 From 5dd73d029eff32668b3ca69b7fb8529fc7c58745 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2020 15:05:50 +0100 Subject: Add type hints to handlers.message and events.builder (#8067) --- changelog.d/8067.misc | 1 + mypy.ini | 3 ++ synapse/events/builder.py | 58 ++++++++++++++++++++----------------- synapse/handlers/message.py | 22 ++++++++------ synapse/handlers/room_member.py | 12 +++++--- tests/rest/client/test_retention.py | 4 ++- tox.ini | 2 ++ 7 files changed, 61 insertions(+), 41 deletions(-) create mode 100644 changelog.d/8067.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8067.misc b/changelog.d/8067.misc new file mode 100644 index 0000000000..f4404b7506 --- /dev/null +++ b/changelog.d/8067.misc @@ -0,0 +1 @@ +Add type hints to `synapse.handlers.message` and `synapse.events.builder`. diff --git a/mypy.ini b/mypy.ini index a61009b197..c69cb5dc40 100644 --- a/mypy.ini +++ b/mypy.ini @@ -81,3 +81,6 @@ ignore_missing_imports = True [mypy-rust_python_jaeger_reporter.*] ignore_missing_imports = True + +[mypy-nacl.*] +ignore_missing_imports = True diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4e179d49b3..9ed24380dd 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,6 +17,7 @@ from typing import Optional import attr from nacl.signing import SigningKey +from synapse.api.auth import Auth from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -27,6 +28,8 @@ from synapse.api.room_versions import ( ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -42,45 +45,46 @@ class EventBuilder(object): Attributes: room_version: Version of the target room - room_id (str) - type (str) - sender (str) - content (dict) - unsigned (dict) - internal_metadata (_EventInternalMetadata) - - _state (StateHandler) - _auth (synapse.api.Auth) - _store (DataStore) - _clock (Clock) - _hostname (str): The hostname of the server creating the event + room_id + type + sender + content + unsigned + internal_metadata + + _state + _auth + _store + _clock + _hostname: The hostname of the server creating the event _signing_key: The signing key to use to sign the event as the server """ - _state = attr.ib() - _auth = attr.ib() - _store = attr.ib() - _clock = attr.ib() - _hostname = attr.ib() - _signing_key = attr.ib() + _state = attr.ib(type=StateHandler) + _auth = attr.ib(type=Auth) + _store = attr.ib(type=DataStore) + _clock = attr.ib(type=Clock) + _hostname = attr.ib(type=str) + _signing_key = attr.ib(type=SigningKey) room_version = attr.ib(type=RoomVersion) - room_id = attr.ib() - type = attr.ib() - sender = attr.ib() + room_id = attr.ib(type=str) + type = attr.ib(type=str) + sender = attr.ib(type=str) - content = attr.ib(default=attr.Factory(dict)) - unsigned = attr.ib(default=attr.Factory(dict)) + content = attr.ib(default=attr.Factory(dict), type=JsonDict) + unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict) # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key = attr.ib(default=None) - _redacts = attr.ib(default=None) - _origin_server_ts = attr.ib(default=None) + _state_key = attr.ib(default=None, type=Optional[str]) + _redacts = attr.ib(default=None, type=Optional[str]) + _origin_server_ts = attr.ib(default=None, type=Optional[int]) internal_metadata = attr.ib( - default=attr.Factory(lambda: _EventInternalMetadata({})) + default=attr.Factory(lambda: _EventInternalMetadata({})), + type=_EventInternalMetadata, ) @property diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8ddded8389..2643438e84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json, json @@ -93,11 +93,11 @@ class MessageHandler(object): async def get_room_data( self, - user_id: str = None, - room_id: str = None, - event_type: Optional[str] = None, - state_key: str = "", - is_guest: bool = False, + user_id: str, + room_id: str, + event_type: str, + state_key: str, + is_guest: bool, ) -> dict: """ Get data from a room. @@ -407,7 +407,7 @@ class EventCreationHandler(object): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are @@ -707,7 +707,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester: Requester, - event_dict: EventBase, + event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: @@ -971,7 +971,7 @@ class EventCreationHandler(object): # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = set() + original_alt_aliases = [] # type: List[str] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -1019,6 +1019,10 @@ class EventCreationHandler(object): current_state_ids = await context.get_current_state_ids() + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + state_to_include_ids = [ e_id for k, e_id in current_state_ids.items() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8e409f24e8..31705cdbdb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -37,6 +37,10 @@ from synapse.util.distributor import user_joined_room, user_left_room from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -207,7 +211,7 @@ class RoomMemberHandler(object): return duplicate.event_id, stream_id stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit + requester, event, context, extra_users=[target], ratelimit=ratelimit, ) prev_state_ids = await context.get_prev_state_ids() @@ -1000,7 +1004,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): check_complexity = self.hs.config.limit_remote_rooms.enabled if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: - check_complexity = not await self.hs.auth.is_server_admin(user) + check_complexity = not await self.auth.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e54ffea150..0b191d13c6 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -144,7 +144,9 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( - message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + message_handler.get_room_data( + self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + ) ) # Send a first event to the room. This is the event we'll want to be purged at the diff --git a/tox.ini b/tox.ini index 45e129580f..e5413eb110 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ commands = mypy \ synapse/appservice \ synapse/config \ synapse/event_auth.py \ + synapse/events/builder.py \ synapse/events/spamcheck.py \ synapse/federation \ synapse/handlers/auth.py \ @@ -186,6 +187,7 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/federation.py \ synapse/handlers/identity.py \ + synapse/handlers/message.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/room_member.py \ -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. (#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From e259d63f73fd7599520d0c4a6f5082e5cd383d25 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:07:42 -0400 Subject: Stop shadow-banned users from sending invites. (#8095) --- changelog.d/8095.feature | 1 + synapse/api/errors.py | 8 +++ synapse/handlers/room.py | 16 +++++- synapse/handlers/room_member.py | 62 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 3 ++ synapse/rest/client/v1/room.py | 67 +++++++++++++++---------- tests/rest/client/v1/test_rooms.py | 100 +++++++++++++++++++++++++++++++++++++ 7 files changed, 226 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8095.feature (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8095.feature b/changelog.d/8095.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8095.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a3f314118a..4888c0ec4d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -604,3 +604,11 @@ class HttpResponseException(CodeMessageException): errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) + + +class ShadowBanError(Exception): + """ + Raised when a shadow-banned user attempts to perform an action. + + This should be caught and a proper "fake" success response sent to the user. + """ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 442cca28e6..0fc71475c3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,6 +20,7 @@ import itertools import logging import math +import random import string from collections import OrderedDict from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple @@ -626,6 +627,7 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) + invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) for i in invite_list: try: @@ -634,6 +636,14 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) + if (invite_list or invite_3pid_list) and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + + # Allow the request to go through, but remove any associated invites. + invite_3pid_list = [] + invite_list = [] + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -648,8 +658,6 @@ class RoomCreationHandler(BaseHandler): % (user_id,), ) - invite_3pid_list = config.get("invite_3pid", []) - visibility = config.get("visibility", None) is_public = visibility == "public" @@ -744,6 +752,8 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct + # Note that update_membership with an action of "invite" can raise a + # ShadowBanError, but this was handled above by emptying invite_list. _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -758,6 +768,8 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] + # Note that do_3pid_invite can raise a ShadowBanError, but this was + # handled above by emptying invite_3pid_list. last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index aa1ccde211..3a6ee6378d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -15,6 +15,7 @@ import abc import logging +import random from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -22,7 +23,13 @@ from unpaddedbase64 import encode_base64 from synapse import types from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import EventFormatVersions from synapse.crypto.event_signing import compute_event_reference_hash @@ -285,6 +292,31 @@ class RoomMemberHandler(object): content: Optional[dict] = None, require_consent: bool = True, ) -> Tuple[str, int]: + """Update a user's membership in a room. + + Params: + requester: The user who is performing the update. + target: The user whose membership is being updated. + room_id: The room ID whose membership is being updated. + action: The membership change, see synapse.api.constants.Membership. + txn_id: The transaction ID, if given. + remote_room_hosts: Remote servers to send the update to. + third_party_signed: Information from a 3PID invite. + ratelimit: Whether to rate limit the request. + content: The content of the created event. + require_consent: Whether consent is required. + + Returns: + A tuple of the new event ID and stream ID. + + Raises: + ShadowBanError if a shadow-banned requester attempts to send an invite. + """ + if action == Membership.INVITE and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -773,6 +805,25 @@ class RoomMemberHandler(object): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: + """Invite a 3PID to a room. + + Args: + room_id: The room to invite the 3PID to. + inviter: The user sending the invite. + medium: The 3PID's medium. + address: The 3PID's address. + id_server: The identity server to use. + requester: The user making the request. + txn_id: The transaction ID this is part of, or None if this is not + part of a transaction. + id_access_token: The optional identity server access token. + + Returns: + The new stream ID. + + Raises: + ShadowBanError if the requester has been shadow-banned. + """ if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -780,6 +831,11 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. await self.base_handler.ratelimit(requester) @@ -804,6 +860,8 @@ class RoomMemberHandler(object): ) if invitee: + # Note that update_membership with an action of "invite" can raise + # a ShadowBanError, but this was done above already. _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) @@ -1042,7 +1100,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, None) + requester = types.create_requester(user, None, False, False, None) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7c292ef3f9..09726d52d6 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet): join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + # update_membership with an action of "invite" can raise a + # ShadowBanError. This is not handled since it is assumed that + # an admin isn't going to call this API with a shadow-banned user. await self.room_member_handler.update_membership( requester=requester, target=fake_requester.user, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f216382636..a9dd3a6aec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( Codes, HttpResponseException, InvalidClientCredentialsError, + ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter @@ -45,6 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder +from synapse.util.stringutils import random_string MYPY = False if MYPY: @@ -200,14 +202,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): event_dict["state_key"] = state_key if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) + try: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + except ShadowBanError: + event_id = "$" + random_string(43) else: ( event, @@ -719,16 +724,20 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - content.get("id_access_token"), - ) + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return 200, {} target = requester.user @@ -740,15 +749,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content: event_content = {"reason": content["reason"]} - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return_value = {} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ef6b775ed2..e674eb90d7 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1974,3 +1974,103 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ShadowBannedTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) -- cgit 1.5.1 From cbbf9126cbd2ace90c1c0f615b87bcec30fdcbd8 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 21 Aug 2020 15:07:56 +0100 Subject: Do not apply ratelimiting on joins to appservices (#8139) Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8139.bugfix | 1 + synapse/api/ratelimiting.py | 37 +++++++++++++++++++++ synapse/handlers/room_member.py | 14 ++++---- tests/api/test_ratelimiting.py | 73 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8139.bugfix (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix new file mode 100644 index 0000000000..21f65d87b7 --- /dev/null +++ b/changelog.d/8139.bugfix @@ -0,0 +1 @@ +Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..e62ae50ac2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3a6ee6378d..a03cb02792 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -491,9 +491,10 @@ class RoomMemberHandler(object): if is_host_in_room: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_local.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( @@ -502,9 +503,10 @@ class RoomMemberHandler(object): else: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) -- cgit 1.5.1 From 2df82ae451e03d76fae5381961dd6229d5796400 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 21 Aug 2020 15:07:56 +0100 Subject: Do not apply ratelimiting on joins to appservices (#8139) Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8139.bugfix | 1 + synapse/api/ratelimiting.py | 37 +++++++++++++++++++++ synapse/handlers/room_member.py | 14 ++++---- tests/api/test_ratelimiting.py | 73 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8139.bugfix (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix new file mode 100644 index 0000000000..21f65d87b7 --- /dev/null +++ b/changelog.d/8139.bugfix @@ -0,0 +1 @@ +Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..e62ae50ac2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..0cd59bce3b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -459,9 +459,10 @@ class RoomMemberHandler(object): if is_host_in_room: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_local.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( @@ -470,9 +471,10 @@ class RoomMemberHandler(object): else: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) -- cgit 1.5.1 From 393a811a41d51d7967f6d455017176a20eacd85c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 18:06:04 +0100 Subject: Fix join ratelimiter breaking profile updates and idempotency (#8153) --- changelog.d/8153.bugfix | 1 + synapse/handlers/room_member.py | 46 +++++++++++--------- tests/rest/client/v1/test_rooms.py | 87 +++++++++++++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 10 +++-- 4 files changed, 119 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8153.bugfix (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8153.bugfix b/changelog.d/8153.bugfix new file mode 100644 index 0000000000..87a1f46ca1 --- /dev/null +++ b/changelog.d/8153.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0cd59bce3b..9fcabb22c7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -210,24 +210,40 @@ class RoomMemberHandler(object): _, stream_id = await self.store.get_event_ordering(duplicate.event_id) return duplicate.event_id, stream_id - stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, - ) - prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = False if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. newly_joined = True if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN + + # Only rate-limit if the user actually joined the room, otherwise we'll end + # up blocking profile updates. if newly_joined: - await self._user_joined_room(target, room_id) + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + + stream_id = await self.event_creation_handler.handle_new_client_event( + requester, event, context, extra_users=[target], ratelimit=ratelimit, + ) + + if event.membership == Membership.JOIN and newly_joined: + # Only fire user_joined_room if the user has actually joined the + # room. Don't bother if the user is just changing their profile + # info. + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -457,19 +473,7 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - if is_host_in_room: - time_now_s = self.clock.time() - ( - allowed, - time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester,) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) - ) - - else: + if not is_host_in_room: time_now_s = self.clock.time() ( allowed, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ef6b775ed2..e74bddc1e5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias +from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest @@ -675,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for i in range(5): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join more rooms than the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + self.reactor.advance(1) + room_ids = room_ids + [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + store.create_profile(UserID.from_string(self.user_id).localpart) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.render(request) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. + paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for i in range(6): + request, channel = self.make_request("POST", path % room_id, {}) + self.render(request) + self.assertEquals(channel.code, 200) + + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8933b560d2..e66c9a4c4c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,9 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator=None, is_public=True, tok=None): + def create_room_as( + self, room_creator=None, is_public=True, tok=None, expect_code=200, + ): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -54,9 +56,11 @@ class RestHelper(object): ) render(request, self.resource, self.hs.get_reactor()) - assert channel.result["code"] == b"200", channel.result + assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - return channel.json_body["room_id"] + + if expect_code == 200: + return channel.json_body["room_id"] def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( -- cgit 1.5.1 From 5758dcf30c245efa1032385cd1af7853d39642a9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 14:25:27 -0400 Subject: Add type hints for state. (#8140) --- changelog.d/8140.misc | 1 + stubs/frozendict.pyi | 47 +++++++ synapse/federation/sender/__init__.py | 4 +- synapse/handlers/federation.py | 10 +- synapse/handlers/presence.py | 6 +- synapse/handlers/room_member.py | 20 +-- synapse/state/__init__.py | 192 +++++++++++++++---------- synapse/state/v1.py | 87 ++++++++---- synapse/state/v2.py | 255 ++++++++++++++++++++++------------ tox.ini | 1 + 10 files changed, 420 insertions(+), 203 deletions(-) create mode 100644 changelog.d/8140.misc create mode 100644 stubs/frozendict.pyi (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8140.misc b/changelog.d/8140.misc new file mode 100644 index 0000000000..78d8834328 --- /dev/null +++ b/changelog.d/8140.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi new file mode 100644 index 0000000000..3f3af59f26 --- /dev/null +++ b/stubs/frozendict.pyi @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for frozendict. + +from typing import ( + Any, + Hashable, + Iterable, + Iterator, + Mapping, + overload, + Tuple, + TypeVar, +) + +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. + +class frozendict(Mapping[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + def __getitem__(self, key: _KT) -> _VT: ... + def __contains__(self, key: Any) -> bool: ... + def copy(self, **add_or_replace: Any) -> frozendict: ... + def __iter__(self) -> Iterator[_KT]: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index e53b6ac456..4662008bfd 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -329,10 +329,10 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = await self.state.get_current_hosts_in_room(room_id) + domains_set = await self.state.get_current_hosts_in_room(room_id) domains = [ d - for d in domains + for d in domains_set if d != self.server_name and self._federation_shard_config.should_handle(self._instance_name, d) ] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5b270228e7..f8b234cee2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler): ) state_sets = list(state_sets.values()) state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( + current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + current_state_ids = {k: e.event_id for k, e in current_states.items()} else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) - current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] - auth_events_map = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 24e1940ee5..1846068150 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -1318,7 +1318,7 @@ async def get_interested_parties( async def get_interested_remotes( store: DataStore, states: List[UserPresenceState], state_handler: StateHandler -) -> List[Tuple[List[str], List[UserPresenceState]]]: +) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1334,7 +1334,7 @@ async def get_interested_remotes( each tuple the list of UserPresenceState should be sent to each destination """ - hosts_and_states = [] + hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a03cb02792..52548087a9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID +from synapse.types import ( + Collection, + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -738,9 +746,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -969,9 +975,7 @@ class RoomMemberHandler(object): ) return stream_id - async def _is_host_in_room( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dba8d91eef..a601303fa3 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,11 +16,22 @@ import logging from collections import namedtuple -from typing import Awaitable, Dict, Iterable, List, Optional, Set +from typing import ( + Awaitable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Union, + overload, +) import attr from frozendict import frozendict from prometheus_client import Histogram +from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -30,7 +41,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import StateMap +from synapse.types import Collection, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -68,8 +79,14 @@ def _gen_state_id(): class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group, prev_group=None, delta_ids=None): - # dict[(str, str), str] map from (type, state_key) to event_id + def __init__( + self, + state: StateMap[str], + state_group: Optional[int], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ): + # A map from (type, state_key) to event_id. self.state = frozendict(state) # the ID of a state group if one and only one is involved. @@ -107,24 +124,49 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() + @overload async def get_current_state( - self, room_id, event_type=None, state_key="", latest_event_ids=None - ): - """ Retrieves the current state for the room. This is done by + self, + room_id: str, + event_type: Literal[None] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> StateMap[EventBase]: + ... + + @overload + async def get_current_state( + self, + room_id: str, + event_type: str, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Optional[EventBase]: + ... + + async def get_current_state( + self, + room_id: str, + event_type: Optional[str] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Union[Optional[EventBase], StateMap[EventBase]]: + """Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. This is equivalent to getting the state of an event that were to send next before receiving any new events. - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - Returns: - map from (type, state_key) to event + If `event_type` is specified, then the method returns only the one + event (or None) with that `event_type` and `state_key`. + + Otherwise, a map from (type, state_key) to event. """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) @@ -140,34 +182,30 @@ class StateHandler(object): state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) - state = { + return { key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - return state - - async def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids( + self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: - room_id (str): - - latest_event_ids (iterable[str]|None): if given, the forward - extremities to resolve. If None, we look them up from the - database (via a cache) + room_id: + latest_event_ids: if given, the forward extremities to resolve. If + None, we look them up from the database (via a cache). Returns: - Deferred[dict[(str, str), str)]]: the state dict, mapping from - (event_type, state_key) -> event_id + the state dict, mapping from (event_type, state_key) -> event_id """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state - - return state + return dict(ret.state) async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -183,32 +221,34 @@ class StateHandler(object): """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None + logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = await self.store.get_joined_users_from_state(room_id, entry) - return joined_users + return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id): + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) - async def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events( + self, room_id: str, event_ids: List[str] + ) -> Set[str]: """Get the hosts that were in a room at the given event ids Args: - room_id (str): - event_ids (list[str]): + room_id: + event_ids: Returns: - Deferred[list[str]]: the hosts in the room at the given events + The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = await self.store.get_joined_hosts(room_id, entry) - return joined_hosts + return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None - ): + ) -> EventContext: """Build an EventContext structure for the event. This works out what the current state should be for the event, and @@ -221,7 +261,7 @@ class StateHandler(object): when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. Returns: - synapse.events.snapshot.EventContext: + The event context. """ if event.internal_metadata.is_outlier(): @@ -275,7 +315,7 @@ class StateHandler(object): event.room_id, event.prev_event_ids() ) - state_ids_before_event = entry.state + state_ids_before_event = dict(entry.state) state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -346,19 +386,18 @@ class StateHandler(object): ) @measure_func() - async def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events( + self, room_id: str, event_ids: Iterable[str] + ) -> _StateCacheEntry: """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str) - event_ids (list[str]) - explicit_room_version (str|None): If set uses the the given room - version to choose the resolution algorithm. If None, then - checks the database for room version. + room_id + event_ids Returns: - Deferred[_StateCacheEntry]: resolved state + The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -394,7 +433,12 @@ class StateHandler(object): ) return result - async def resolve_events(self, room_version, state_sets, event): + async def resolve_events( + self, + room_version: str, + state_sets: Collection[Iterable[EventBase]], + event: EventBase, + ) -> StateMap[EventBase]: logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,9 +458,7 @@ class StateHandler(object): state_res_store=StateResolutionStore(self.store), ) - new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()} - - return new_state + return {key: state_map[ev_id] for key, ev_id in new_state.items()} class StateResolutionHandler(object): @@ -444,7 +486,12 @@ class StateResolutionHandler(object): @log_function async def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store + self, + room_id: str, + room_version: str, + state_groups_ids: Dict[int, StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", ): """Resolves conflicts between a set of state groups @@ -452,13 +499,13 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging and sanity checks) - room_version (str): version of the room - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group + room_id: room we are resolving for (used for logging and sanity checks) + room_version: version of the room + state_groups_ids: + A map from state group id to the state in that state group (where 'state' is a map from state key to event id) - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -466,10 +513,10 @@ class StateResolutionHandler(object): If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store Returns: - _StateCacheEntry: resolved state + The resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) @@ -530,21 +577,22 @@ class StateResolutionHandler(object): return cache -def _make_state_cache_entry(new_state, state_groups_ids): +def _make_state_cache_entry( + new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] +) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. Args: - new_state (dict[(str, str), str]): resolved state map (mapping from - (type, state_key) to event_id) + new_state: resolved state map (mapping from (type, state_key) to event_id) - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group - (where 'state' is a map from state key to event id) + state_groups_ids: + map from state group id to the state in that state group (where + 'state' is a map from state key to event id) Returns: - _StateCacheEntry + The cache entry. """ # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -585,7 +633,7 @@ def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> Awaitable[StateMap[str]]: @@ -633,15 +681,17 @@ class StateResolutionStore(object): store = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. + An awaitable which resolves to a dict from event_id to event. """ return self.store.get_events( @@ -651,7 +701,9 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + def get_auth_chain_difference( + self, state_sets: List[Set[str]] + ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -660,7 +712,7 @@ class StateResolutionStore(object): chain. Returns: - Deferred[Set[str]]: Set of event IDs. + An awaitable that resolves to a set of event IDs. """ return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index ab5e24841d..0eb7fdd9e5 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,17 @@ import hashlib import logging -from typing import Awaitable, Callable, Dict, List, Optional +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "") async def resolve_events_with_store( room_id: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[List[str]], Awaitable], -): + state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], +) -> StateMap[str]: """ Args: room_id: the room we are working in @@ -56,8 +66,7 @@ async def resolve_events_with_store( an Awaitable that resolves to a dict of event_id to event. Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -75,8 +84,8 @@ async def resolve_events_with_store( "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) + # A map from state event id to event. Only includes the state events which + # are in conflict (and those in event_map). state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -91,8 +100,6 @@ async def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -122,29 +129,30 @@ async def resolve_events_with_store( ) -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. Args: - state_sets(iterable[dict[(str, str), str]]): + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: + A tuple of (unconflicted_state, conflicted_state), where: - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} + conflicted_state = {} # type: StateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -171,7 +179,21 @@ def _seperate(state_sets): return unconflicted_state, conflicted_state -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): +def _create_auth_events_from_maps( + unconflicted_state: StateMap[str], + conflicted_state: StateMap[Set[str]], + state_map: Dict[str, EventBase], +) -> StateMap[str]: + """ + + Args: + unconflicted_state: The unconflicted state map. + conflicted_state: The conflicted state map. + state_map: + + Returns: + A map from state key to event id. + """ auth_events = {} for event_ids in conflicted_state.values(): for event_id in event_ids: @@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma keys = event_auth.auth_types_for_event(state_map[event_id]) for key in keys: if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id + auth_event_id = unconflicted_state.get(key, None) + if auth_event_id: + auth_events[key] = auth_event_id return auth_events def _resolve_with_state( - unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map + unconflicted_state_ids: StateMap[str], + conflicted_state_ids: StateMap[Set[str]], + auth_event_ids: StateMap[str], + state_map: Dict[str, EventBase], ): conflicted_state = {} for key, event_ids in conflicted_state_ids.items(): @@ -215,7 +240,9 @@ def _resolve_with_state( return new_state -def _resolve_state_events(conflicted_state, auth_events): +def _resolve_state_events( + conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] +) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. @@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events): return resolved_state -def _resolve_auth_events(events, auth_events): +def _resolve_auth_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { @@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events): return event -def _resolve_normal_events(events, auth_events): +def _resolve_normal_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: for event in _ordered_events(events): try: # The signatures have already been checked at this point @@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events): return event -def _ordered_events(events): +def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]: def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 6634955cdc..0e9ffbd6e6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,21 @@ import heapq import itertools import logging -from typing import Dict, List, Optional +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + overload, +) + +from typing_extensions import Literal import synapse.state from synapse import event_auth @@ -40,10 +54,10 @@ async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", -): +) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm Args: @@ -63,8 +77,7 @@ async def resolve_events_with_store( state_res_store: Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ logger.debug("Computing conflicted state") @@ -171,18 +184,23 @@ async def resolve_events_with_store( return resolved_state -async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Return the power level of the sender of the given event according to their auth events. Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + room_id + event_id + event_map + state_res_store Returns: - Deferred[int] + The power level. """ event = await _get_event(room_id, event_id, event_map, state_res_store) @@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st return int(level) -async def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference( + state_sets: Sequence[StateMap[str]], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. Args: - state_sets (list) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + state_sets + event_map + state_res_store Returns: - Deferred[set[str]]: Set of event IDs + Set of event IDs """ difference = await state_res_store.get_auth_chain_difference( @@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store): return difference -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Return the unconflicted and conflicted state. This is different than in the original algorithm, as this defines a key to be conflicted if one of the state sets doesn't have that key. Args: - state_sets (list) + state_sets Returns: - tuple[dict, dict]: A tuple of unconflicted and conflicted state. The - conflicted state dict is a map from type/state_key to set of event IDs + A tuple of unconflicted and conflicted state. The conflicted state dict + is a map from type/state_key to set of event IDs """ unconflicted_state = {} conflicted_state = {} @@ -260,18 +284,20 @@ def _seperate(state_sets): event_ids.discard(None) conflicted_state[key] = event_ids - return unconflicted_state, conflicted_state + # mypy doesn't understand that discarding None above means that conflicted + # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. + return unconflicted_state, conflicted_state # type: ignore -def _is_power_event(event): +def _is_power_event(event: EventBase) -> bool: """Return whether or not the event is a "power event", as defined by the v2 state resolution algorithm Args: - event (FrozenEvent) + event Returns: - boolean + True if the event is a power event. """ if (event.type, event.state_key) in ( (EventTypes.PowerLevels, ""), @@ -288,19 +314,23 @@ def _is_power_event(event): async def _add_event_and_auth_chain_to_graph( - graph, room_id, event_id, event_map, state_res_store, auth_diff -): + graph: Dict[str, Set[str]], + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> None: """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph Args: - graph (dict[str, set[str]]): A map from event ID to the events auth - event IDs - room_id (str): the room we are working in - event_id (str): Event to add to the graph - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + graph: A map from event ID to the events auth event IDs + room_id: the room we are working in + event_id: Event to add to the graph + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. """ state = [event_id] @@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph( async def _reverse_topological_power_sort( - clock, room_id, event_ids, event_map, state_res_store, auth_diff -): + clock: Clock, + room_id: str, + event_ids: Iterable[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: - clock (Clock) - room_id (str): the room we are working in - event_ids (list[str]): The events to sort - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + clock + room_id: the room we are working in + event_ids: The events to sort + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. Returns: - Deferred[list[str]]: The sorted list + The sorted list """ - graph = {} + graph = {} # type: Dict[str, Set[str]] for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff @@ -372,22 +407,28 @@ async def _reverse_topological_power_sort( async def _iterative_auth_checks( - clock, room_id, room_version, event_ids, base_state, event_map, state_res_store -): + clock: Clock, + room_id: str, + room_version: str, + event_ids: List[str], + base_state: StateMap[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> StateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: - clock (Clock) - room_id (str) - room_version (str) - event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (StateMap[str]): The set of state to start with - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id + room_version + event_ids: Ordered list of events to apply auth checks to + base_state: The set of state to start with + event_map + state_res_store Returns: - Deferred[StateMap[str]]: Returns the final updated state + Returns the final updated state """ resolved_state = base_state.copy() room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -439,21 +480,26 @@ async def _iterative_auth_checks( async def _mainline_sort( - clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store -): + clock: Clock, + room_id: str, + event_ids: List[str], + resolved_power_event_id: Optional[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> List[str]: """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: - clock (Clock) - room_id (str): room we're working in - event_ids (list[str]): Events to sort - resolved_power_event_id (str): The final resolved power level event ID - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id: room we're working in + event_ids: Events to sort + resolved_power_event_id: The final resolved power level event ID + event_map + state_res_store Returns: - Deferred[list[str]]: The sorted list + The sorted list """ if not event_ids: # It's possible for there to be no event IDs here to sort, so we can @@ -505,59 +551,90 @@ async def _mainline_sort( async def _get_mainline_depth_for_event( - event, mainline_map, event_map, state_res_store -): + event: EventBase, + mainline_map: Dict[str, int], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Get the mainline depths for the given event based on the mainline map Args: - event (FrozenEvent) - mainline_map (dict[str, int]): Map from event_id to mainline depth for - events in the mainline. - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + event + mainline_map: Map from event_id to mainline depth for events in the mainline. + event_map + state_res_store Returns: - Deferred[int] + The mainline depth """ room_id = event.room_id + tmp_event = event # type: Optional[EventBase] # We do an iterative search, replacing `event with the power level in its # auth events (if any) - while event: + while tmp_event: depth = mainline_map.get(event.event_id) if depth is not None: return depth - auth_events = event.auth_event_ids() - event = None + auth_events = tmp_event.auth_event_ids() + tmp_event = None for aid in auth_events: aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): - event = aev + tmp_event = aev break # Didn't find a power level auth event, so we just return 0 return 0 -async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[False] = False, +) -> EventBase: + ... + + +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[True], +) -> Optional[EventBase]: + ... + + +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: bool = False, +) -> Optional[EventBase]: """Helper function to look up event in event_map, falling back to looking it up in the store Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - allow_none (bool): if the event is not found, return None rather than raising + room_id + event_id + event_map + state_res_store + allow_none: if the event is not found, return None rather than raising an exception Returns: - Deferred[Optional[FrozenEvent]] + The event, or none if the event does not exist (and allow_none is True). """ if event_id not in event_map: events = await state_res_store.get_events([event_id], allow_rejected=True) @@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F return event -def lexicographical_topological_sort(graph, key): +def lexicographical_topological_sort( + graph: Dict[str, Set[str]], key: Callable[[str], Any] +) -> Generator[str, None, None]: """Performs a lexicographic reverse topological sort on the graph. This returns a reverse topological sort (i.e. if node A references B then B @@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key): NOTE: `graph` is modified during the sort. Args: - graph (dict[str, set[str]]): A representation of the graph where each - node is a key in the dict and its value are the nodes edges. - key (func): A function that takes a node and returns a value that is - comparable and used to order nodes + graph: A representation of the graph where each node is a key in the + dict and its value are the nodes edges. + key: A function that takes a node and returns a value that is comparable + and used to order nodes Yields: - str: The next node in the topological sort + The next node in the topological sort """ # Note, this is basically Kahn's algorithm except we look at nodes with no # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} + reverse_graph = {} # type: Dict[str, Set[str]] # Lists of nodes with zero out degree. Is actually a tuple of # `(key(node), node)` so that sorting does the right thing diff --git a/tox.ini b/tox.ini index ea804108b5..edeb757f7b 100644 --- a/tox.ini +++ b/tox.ini @@ -209,6 +209,7 @@ commands = mypy \ synapse/server.py \ synapse/server_notices \ synapse/spam_checker_api \ + synapse/state \ synapse/storage/databases/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ -- cgit 1.5.1 From 2e6c90ff847b2d79ae372933e431d3a7ebaf5381 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 08:49:01 -0400 Subject: Do not propagate profile changes of shadow-banned users into rooms. (#8157) --- changelog.d/8157.feature | 1 + synapse/handlers/profile.py | 17 +- synapse/handlers/room_member.py | 2 +- tests/rest/client/test_shadow_banned.py | 272 ++++++++++++++++++++++++++++++++ tests/rest/client/v1/test_rooms.py | 159 +------------------ 5 files changed, 291 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8157.feature create mode 100644 tests/rest/client/test_shadow_banned.py (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8157.feature b/changelog.d/8157.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8157.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 31a2e5ea18..96c9d6bab4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import random from synapse.api.errors import ( AuthError, @@ -213,8 +214,14 @@ class BaseProfileHandler(BaseHandler): async def set_avatar_url( self, target_user, requester, new_avatar_url, by_admin=False ): - """target_user is the user whose avatar_url is to be changed; - auth_user is the user attempting to make this change.""" + """Set a new avatar URL for a user. + + Args: + target_user (UserID): the user whose avatar URL is to be changed. + requester (Requester): The user attempting to make this change. + new_avatar_url (str): The avatar URL to give this user. + by_admin (bool): Whether this change was made by an administrator. + """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -278,6 +285,12 @@ class BaseProfileHandler(BaseHandler): await self.ratelimit(requester) + # Do not actually update the room state for shadow-banned users. + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + return + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 804463b1c0..cae4d013b8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -380,7 +380,7 @@ class RoomMemberHandler(object): # later on. content = dict(content) - if not self.allow_per_room_profiles: + if not self.allow_per_room_profiles or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py new file mode 100644 index 0000000000..3eb9aeaa9e --- /dev/null +++ b/tests/rest/client/test_shadow_banned.py @@ -0,0 +1,272 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock, patch + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest + + +class _ShadowBannedBase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + # Create two users, one of which is shadow-banned. + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class RoomTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class ProfileTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def test_displayname(self): + """Profile changes should succeed, but don't end up in a room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), + {"displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertEqual(channel.json_body, {}) + + # The user's display name should be updated. + request, channel = self.make_request( + "GET", "/profile/%s/displayname" % (self.banned_user_id,) + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["displayname"], new_display_name) + + # But the display name in the room should not be. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, + False, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) + + def test_room_displayname(self): + """Changes to state events for a room should be processed, but not end up in the room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (room_id, self.banned_user_id), + {"membership": "join", "displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertIn("event_id", channel.json_body) + + # The display name in the room should not be changed. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, + False, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 68c4a6a8f7..0a567b032f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -21,13 +21,13 @@ import json from urllib import parse as urlparse -from mock import Mock, patch +from mock import Mock import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet +from synapse.rest.client.v2_alpha import account from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string @@ -2060,158 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) - - -# To avoid the tests timing out don't add a delay to "annoy the requester". -@patch("random.randint", new=lambda a, b: 0) -class ShadowBannedTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - room_upgrade_rest_servlet.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.banned_user_id = self.register_user("banned", "test") - self.banned_access_token = self.login("banned", "test") - - self.store = self.hs.get_datastore() - - self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) - ) - - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - def test_invite(self): - """Invites from shadow-banned users don't actually get sent.""" - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - # Inviting the user completes successfully. - self.helper.invite( - room=room_id, - src=self.banned_user_id, - tok=self.banned_access_token, - targ=self.other_user_id, - ) - - # But the user wasn't actually invited. - invited_rooms = self.get_success( - self.store.get_invited_rooms_for_local_user(self.other_user_id) - ) - self.assertEqual(invited_rooms, []) - - def test_invite_3pid(self): - """Ensure that a 3PID invite does not attempt to contact the identity server.""" - identity_handler = self.hs.get_handlers().identity_handler - identity_handler.lookup_3pid = Mock( - side_effect=AssertionError("This should not get called") - ) - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - # Inviting the user completes successfully. - request, channel = self.make_request( - "POST", - "/rooms/%s/invite" % (room_id,), - {"id_server": "test", "medium": "email", "address": "test@test.test"}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - # This should have raised an error earlier, but double check this wasn't called. - identity_handler.lookup_3pid.assert_not_called() - - def test_create_room(self): - """Invitations during a room creation should be discarded, but the room still gets created.""" - # The room creation is successful. - request, channel = self.make_request( - "POST", - "/_matrix/client/r0/createRoom", - {"visibility": "public", "invite": [self.other_user_id]}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - room_id = channel.json_body["room_id"] - - # But the user wasn't actually invited. - invited_rooms = self.get_success( - self.store.get_invited_rooms_for_local_user(self.other_user_id) - ) - self.assertEqual(invited_rooms, []) - - # Since a real room was created, the other user should be able to join it. - self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) - - # Both users should be in the room. - users = self.get_success(self.store.get_users_in_room(room_id)) - self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - - def test_message(self): - """Messages from shadow-banned users don't actually get sent.""" - - room_id = self.helper.create_room_as( - self.other_user_id, tok=self.other_access_token - ) - - # The user should be in the room. - self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) - - # Sending a message should complete successfully. - result = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "with right label"}, - tok=self.banned_access_token, - ) - self.assertIn("event_id", result) - event_id = result["event_id"] - - latest_events = self.get_success( - self.store.get_latest_event_ids_in_room(room_id) - ) - self.assertNotIn(event_id, latest_events) - - def test_upgrade(self): - """A room upgrade should fail, but look like it succeeded.""" - - # The create works fine. - room_id = self.helper.create_room_as( - self.banned_user_id, tok=self.banned_access_token - ) - - request, channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), - {"new_version": "6"}, - access_token=self.banned_access_token, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - # A new room_id should be returned. - self.assertIn("replacement_room", channel.json_body) - - new_room_id = channel.json_body["replacement_room"] - - # It doesn't really matter what API we use here, we just want to assert - # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(new_room_id)) - # The summary should be empty since the room doesn't exist. - self.assertEqual(summary, {}) -- cgit 1.5.1 From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 +++++---- synapse/handlers/message.py | 13 ++---- synapse/handlers/room_member.py | 12 +----- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 ++- synapse/storage/databases/main/openid.py | 8 +++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++--- synapse/storage/databases/main/room.py | 49 ++++++++++++---------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++++---- synapse/storage/databases/main/ui_auth.py | 4 +- .../storage/databases/main/user_erasure_store.py | 8 ++-- tests/test_utils/event_injection.py | 7 ++-- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd..7878cd7044 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ class EventBuilder(object): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ class EventCreationHandler(object): event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ class EventCreationHandler(object): self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a672..d2e0685e9e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c339..717df97301 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc09..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. @@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: -- cgit 1.5.1 From 82c1ee1c22a87b9e6e3179947014b0f11c0a1ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Sep 2020 15:48:37 +0100 Subject: Add experimental support for sharding event persister. (#8170) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. --- changelog.d/8170.feature | 1 + synapse/config/_base.py | 21 ++++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++---- synapse/handlers/federation.py | 44 ++++++++++----- synapse/handlers/message.py | 14 +++-- synapse/handlers/room.py | 14 +++-- synapse/handlers/room_member.py | 7 --- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/events_worker.py | 66 +++++++++++++++------- .../schema/delta/58/14events_instance_name.sql | 16 ++++++ .../delta/58/14events_instance_name.sql.postgres | 26 +++++++++ synapse/storage/util/id_generators.py | 10 ++-- 18 files changed, 206 insertions(+), 77 deletions(-) create mode 100644 changelog.d/8170.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8170.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1417487427..73f0717b0d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 16389a0dca..bd8efbb768 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -923,7 +923,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72bb638167..6398774c02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -904,9 +903,10 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -974,7 +974,9 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9d5b1828df..55794c3057 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a7962b0ada..e5cb3c2d5c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 16c63ff4ec..3705618b4f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 0ac854aee2..c73d54fb67 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -68,7 +68,7 @@ class Databases(object): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6313b41eef..46b11e705b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -97,6 +97,7 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -108,7 +109,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -800,6 +801,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9f3d23f0a5..8fd21c2bf8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -231,8 +231,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -362,9 +366,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From 9f8abdcc3828e942f09cc62e29dff93dbaa01ec7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 4 Sep 2020 10:19:42 +0100 Subject: Revert "Add experimental support for sharding event persister. (#8170)" (#8242) * Revert "Add experimental support for sharding event persister. (#8170)" This reverts commit 82c1ee1c22a87b9e6e3179947014b0f11c0a1ac3. * Changelog --- changelog.d/8170.feature | 1 - changelog.d/8242.feature | 1 + synapse/config/_base.py | 21 +------ synapse/config/_base.pyi | 1 - synapse/config/workers.py | 37 ++++-------- synapse/handlers/federation.py | 44 +++++---------- synapse/handlers/message.py | 14 ++--- synapse/handlers/room.py | 14 ++--- synapse/handlers/room_member.py | 7 +++ synapse/replication/http/federation.py | 12 +--- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/events_worker.py | 66 +++++++--------------- .../schema/delta/58/14events_instance_name.sql | 16 ------ .../delta/58/14events_instance_name.sql.postgres | 26 --------- synapse/storage/util/id_generators.py | 10 ++-- 19 files changed, 78 insertions(+), 206 deletions(-) delete mode 100644 changelog.d/8170.feature create mode 100644 changelog.d/8242.feature delete mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql delete mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature deleted file mode 100644 index b363e929ea..0000000000 --- a/changelog.d/8170.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for sharding event persister. diff --git a/changelog.d/8242.feature b/changelog.d/8242.feature new file mode 100644 index 0000000000..f6891e360d --- /dev/null +++ b/changelog.d/8242.feature @@ -0,0 +1 @@ +Back out experimental support for sharding event persister. **PLEASE REMOVE THIS LINE FROM THE FINAL CHANGELOG** diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 73f0717b0d..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,26 +832,11 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - # If multiple instances are not defined we always return true + + # If multiple instances are not defined we always return true. if not self.instances or len(self.instances) == 1: return True - return self.get_instance(key) == instance_name - - def get_instance(self, key: str) -> str: - """Get the instance responsible for handling the given key. - - Note: For things like federation sending the config for which instance - is sending is known only to the sender instance if there is only one. - Therefore `should_handle` should be used where possible. - """ - - if not self.instances: - return "master" - - if len(self.instances) == 1: - return self.instances[0] - # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -861,7 +846,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] + return self.instances[remainder] == instance_name __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index b8faafa9bd..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,4 +142,3 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... - def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f23e42cdf9..c784a71508 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,24 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union - import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def -def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: - """Helper for allowing parsing a string or list of strings to a config - option expecting a list of strings. - """ - - if isinstance(obj, str): - return [obj] - return obj - - @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -45,13 +33,11 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instances that write to the event and backfill streams. - typing: The instance that writes to the typing stream. + events: The instance that writes to the event and backfill streams. + events: The instance that writes to the typing stream. """ - events = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter - ) + events = attr.ib(default="master", type=str) typing = attr.ib(default="master", type=str) @@ -119,18 +105,15 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writers for events and typing also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instances = _instance_to_list_converter(getattr(self.writers, stream)) - for instance in instances: - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) - - self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 310c7f7138..43f2986f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -923,8 +923,7 @@ class FederationHandler(BaseHandler): ) ) - if ev_infos: - await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1217,7 +1216,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, event_infos, ) def _sanity_check_event(self, ev): @@ -1364,15 +1363,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, room_id, auth_chain, state, event, room_version_obj + origin, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.events_shard_config.get_instance(room_id), - "events", - max_stream_id, + self.config.worker.writers.events, "events", max_stream_id ) # Check whether this room is the result of an upgrade of a room we already know @@ -1626,7 +1625,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify(event.room_id, [(event, context)]) + await self.persist_events_and_notify([(event, context)]) return event @@ -1653,9 +1652,7 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( - event.room_id, [(event, context)] - ) + stream_id = await self.persist_events_and_notify([(event, context)]) return event, stream_id @@ -1903,7 +1900,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - event.room_id, [(event, context)], backfilled=backfilled + [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1916,7 +1913,6 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, - room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1948,7 +1944,6 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1959,7 +1954,6 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, - room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1974,7 +1968,6 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from - room_id, auth_events state event @@ -2049,20 +2042,17 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( - room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify( - room_id, [(event, new_event_context)] - ) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2913,7 +2903,6 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, - room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2921,19 +2910,14 @@ class FederationHandler(BaseHandler): necessary. Args: - room_id: The room ID of events being persisted. - event_and_contexts: Sequence of events with their associated - context that should be persisted. All events must belong to - the same room. + event_and_contexts: backfilled: Whether these events are a result of backfilling or not """ - instance = self.config.worker.events_shard_config.get_instance(room_id) - if instance != self._instance_name: + if self.config.worker.writers.events != self._instance_name: result = await self._send_events( - instance_name=instance, + instance_name=self.config.worker.writers.events, store=self.store, - room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6398774c02..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,8 +376,9 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._events_shard_config = self.config.worker.events_shard_config - self._instance_name = hs.get_instance_name() + self._is_event_writer = ( + self.config.worker.writers.events == hs.get_instance_name() + ) self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -903,10 +904,9 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - writer_instance = self._events_shard_config.get_instance(event.room_id) - if writer_instance != self._instance_name: + if not self._is_event_writer: result = await self.send_event( - instance_name=writer_instance, + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -974,9 +974,7 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._events_shard_config.should_handle( - self._instance_name, event.room_id - ) + assert self._is_event_writer if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 55794c3057..9d5b1828df 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,9 +804,7 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(room_id), - "events", - last_stream_id, + self.hs.config.worker.writers.events, "events", last_stream_id ) return result, last_stream_id @@ -1262,10 +1260,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. + # + # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(new_room_id), - "events", - stream_id, + self.hs.config.worker.writers.events, "events", stream_id ) else: new_room_id = None @@ -1295,9 +1293,7 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(room_id), - "events", - stream_id, + self.hs.config.worker.writers.events, "events", stream_id ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e5cb3c2d5c..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,6 +82,13 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5c8be747e1..6b56315148 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,11 +65,10 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, room_id, event_and_contexts, backfilled): + async def _serialize_payload(store, event_and_contexts, backfilled): """ Args: store - room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -89,11 +88,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = { - "events": event_payloads, - "backfilled": backfilled, - "room_id": room_id, - } + payload = {"events": event_payloads, "backfilled": backfilled} return payload @@ -101,7 +96,6 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) - room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -126,7 +120,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - room_id, event_and_contexts, backfilled + event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.config.worker.writers.events == hs.get_instance_name(): self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 3705618b4f..16c63ff4ec 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store._stream_id_gen.get_current_token_for_writer, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index c73d54fb67..0ac854aee2 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -68,7 +68,7 @@ class Databases(object): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.config.worker.writers.events == hs.get_instance_name(): persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4c3c162acf..0b69aa6a94 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) + raise StoreError(400, "stream_ordering too old") sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b94fe7ac17..b3d27a2ee7 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -97,7 +97,6 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() - self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -109,7 +108,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.get_instance_name() in hs.config.worker.writers.events + hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -801,7 +800,6 @@ class PersistEventsStore: table="events", values=[ { - "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 17f5997b89..a7a73cc3d8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if isinstance(database.engine, PostgresEngine): - # If we're using Postgres than we can use `MultiWriterIdGenerator` - # regardless of whether this process writes to the streams or not. - self._stream_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", - sequence_name="events_stream_seq", + if hs.config.worker.writers.events == hs.get_instance_name(): + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", ) - self._backfill_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", - sequence_name="events_backfill_stream_seq", - positive=False, + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], ) else: - # We shouldn't be running in worker mode with SQLite, but its useful - # to support it for unit tests. - # - # If this process is the writer than we need to use - # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets - # updated over replication. (Multiple writers are not supported for - # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql deleted file mode 100644 index 98ff76d709..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres deleted file mode 100644 index 97c1e6a0c5..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE SEQUENCE IF NOT EXISTS events_stream_seq; - -SELECT setval('events_stream_seq', ( - SELECT COALESCE(MAX(stream_ordering), 1) FROM events -)); - -CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; - -SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events -)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 8fd21c2bf8..9f3d23f0a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -231,12 +231,8 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. - # - # We start at 1 here as a) the first generated stream ID will be 2, and - # b) other parts of the code assume that stream IDs are strictly greater - # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 1 + min(self._current_positions.values()) if self._current_positions else 0 ) self._known_persisted_positions = [] # type: List[int] @@ -366,7 +362,9 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - return self.get_persisted_upto_position() + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From c619253db80c8d1c606dc40756dd3c9e3a55a9fb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Sep 2020 06:54:56 -0400 Subject: Stop sub-classing object (#8249) --- changelog.d/8249.misc | 1 + contrib/cmdclient/http.py | 6 ++--- contrib/experiments/cursesio.py | 2 +- contrib/experiments/test_messaging.py | 4 ++-- scripts-dev/hash_history.py | 2 +- synapse/api/auth.py | 2 +- synapse/api/auth_blocking.py | 2 +- synapse/api/constants.py | 26 +++++++++++----------- synapse/api/errors.py | 2 +- synapse/api/filtering.py | 6 ++--- synapse/api/ratelimiting.py | 2 +- synapse/api/room_versions.py | 10 ++++----- synapse/api/urls.py | 2 +- synapse/app/_base.py | 4 ++-- synapse/app/generic_worker.py | 2 +- synapse/appservice/__init__.py | 6 ++--- synapse/appservice/scheduler.py | 8 +++---- synapse/config/_base.py | 4 ++-- synapse/config/cache.py | 2 +- synapse/config/key.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/ratelimiting.py | 4 ++-- synapse/config/room.py | 2 +- synapse/config/room_directory.py | 2 +- synapse/config/server.py | 2 +- synapse/crypto/context_factory.py | 8 +++---- synapse/crypto/keyring.py | 8 +++---- synapse/events/__init__.py | 2 +- synapse/events/builder.py | 4 ++-- synapse/events/spamcheck.py | 2 +- synapse/events/third_party_rules.py | 2 +- synapse/events/utils.py | 2 +- synapse/events/validator.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/persistence.py | 2 +- synapse/federation/send_queue.py | 4 ++-- synapse/federation/sender/__init__.py | 2 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/federation/sender/transaction_manager.py | 2 +- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 4 ++-- synapse/groups/attestations.py | 4 ++-- synapse/groups/groups_server.py | 2 +- synapse/handlers/__init__.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/account_data.py | 2 +- synapse/handlers/account_validity.py | 2 +- synapse/handlers/acme.py | 2 +- synapse/handlers/acme_issuing_service.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 2 +- synapse/handlers/device.py | 2 +- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/e2e_keys.py | 4 ++-- synapse/handlers/e2e_room_keys.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/message.py | 4 ++-- synapse/handlers/pagination.py | 4 ++-- synapse/handlers/password_policy.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/room.py | 6 ++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/saml_handler.py | 4 ++-- synapse/handlers/state_deltas.py | 2 +- synapse/handlers/sync.py | 4 ++-- synapse/handlers/typing.py | 2 +- synapse/http/client.py | 8 +++---- synapse/http/connectproxyclient.py | 2 +- synapse/http/federation/matrix_federation_agent.py | 6 ++--- synapse/http/federation/srv_resolver.py | 4 ++-- synapse/http/federation/well_known_resolver.py | 4 ++-- synapse/http/matrixfederationclient.py | 6 ++--- synapse/http/request_metrics.py | 2 +- synapse/http/server.py | 2 +- synapse/http/servlet.py | 2 +- synapse/logging/_structured.py | 6 ++--- synapse/logging/_terse_json.py | 4 ++-- synapse/logging/context.py | 6 ++--- synapse/logging/opentracing.py | 2 +- synapse/metrics/__init__.py | 16 ++++++------- synapse/metrics/background_process_metrics.py | 4 ++-- synapse/module_api/__init__.py | 2 +- synapse/notifier.py | 6 ++--- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 4 ++-- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/mailer.py | 2 +- synapse/push/push_rule_evaluator.py | 2 +- synapse/push/pusher.py | 2 +- synapse/replication/http/_base.py | 2 +- .../slave/storage/_slaved_id_tracker.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/resource.py | 2 +- synapse/replication/tcp/streams/_base.py | 2 +- synapse/replication/tcp/streams/events.py | 4 ++-- synapse/rest/client/transactions.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/rest/media/v1/_base.py | 4 ++-- synapse/rest/media/v1/filepath.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/media_storage.py | 2 +- synapse/rest/media/v1/thumbnailer.py | 2 +- synapse/rest/well_known.py | 2 +- synapse/secrets.py | 2 +- synapse/server_notices/consent_server_notices.py | 2 +- .../resource_limits_server_notices.py | 2 +- synapse/server_notices/server_notices_manager.py | 2 +- synapse/server_notices/server_notices_sender.py | 2 +- .../server_notices/worker_server_notices_sender.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 8 +++---- synapse/storage/__init__.py | 2 +- synapse/storage/background_updates.py | 4 ++-- synapse/storage/database.py | 4 ++-- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/roommember.py | 2 +- synapse/storage/keys.py | 2 +- synapse/storage/persist_events.py | 4 ++-- synapse/storage/prepare_database.py | 2 +- synapse/storage/purge_events.py | 2 +- synapse/storage/relations.py | 6 ++--- synapse/storage/state.py | 4 ++-- synapse/storage/util/id_generators.py | 4 ++-- synapse/streams/config.py | 4 ++-- synapse/streams/events.py | 2 +- synapse/types.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/async_helpers.py | 8 +++---- synapse/util/caches/__init__.py | 2 +- synapse/util/caches/descriptors.py | 8 +++---- synapse/util/caches/dictionary_cache.py | 4 ++-- synapse/util/caches/expiringcache.py | 4 ++-- synapse/util/caches/lrucache.py | 4 ++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/treecache.py | 4 ++-- synapse/util/caches/ttlcache.py | 4 ++-- synapse/util/distributor.py | 4 ++-- synapse/util/file_consumer.py | 2 +- synapse/util/jsonobject.py | 2 +- synapse/util/metrics.py | 2 +- synapse/util/ratelimitutils.py | 4 ++-- synapse/util/retryutils.py | 2 +- synapse/util/wheel_timer.py | 4 ++-- tests/api/test_auth.py | 2 +- tests/crypto/test_keyring.py | 2 +- tests/federation/transport/test_server.py | 2 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_profile.py | 2 +- tests/http/__init__.py | 2 +- .../federation/test_matrix_federation_agent.py | 2 +- tests/logging/test_structured.py | 4 ++-- tests/push/test_email.py | 2 +- tests/rest/client/third_party_rules.py | 2 +- tests/rest/client/v1/utils.py | 2 +- tests/rest/media/v1/test_url_preview.py | 6 ++--- tests/server.py | 6 ++--- tests/state/test_v2.py | 4 ++-- tests/storage/test__base.py | 18 +++++++-------- tests/test_state.py | 4 ++-- tests/test_visibility.py | 2 +- tests/unittest.py | 2 +- tests/util/caches/test_descriptors.py | 20 ++++++++--------- tests/util/test_file_consumer.py | 4 ++-- tests/utils.py | 6 ++--- 168 files changed, 293 insertions(+), 292 deletions(-) create mode 100644 changelog.d/8249.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8249.misc b/changelog.d/8249.misc new file mode 100644 index 0000000000..6a42e8a4e6 --- /dev/null +++ b/changelog.d/8249.misc @@ -0,0 +1 @@ +Stop sub-classing from object. diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index e2534ee584..cd3260b27d 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -24,7 +24,7 @@ from twisted.web.client import Agent, readBody from twisted.web.http_headers import Headers -class HttpClient(object): +class HttpClient: """ Interface for talking json over http """ @@ -169,7 +169,7 @@ class TwistedHttpClient(HttpClient): return d -class _RawProducer(object): +class _RawProducer: def __init__(self, data): self.data = data self.body = data @@ -186,7 +186,7 @@ class _RawProducer(object): pass -class _JsonProducer(object): +class _JsonProducer: """ Used by the twisted http client to create the HTTP body from json """ diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index ffefe3bb39..15a22c3a0e 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -141,7 +141,7 @@ class CursesStdIO: curses.endwin() -class Callback(object): +class Callback: def __init__(self, stdio): self.stdio = stdio diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index a84ec4ecae..d4c35ff2fc 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -55,7 +55,7 @@ def excpetion_errback(failure): logging.exception(failure) -class InputOutput(object): +class InputOutput: """ This is responsible for basic I/O so that a user can interact with the example app. """ @@ -132,7 +132,7 @@ class IOLoggerHandler(logging.Handler): self.io.print_log(msg) -class Room(object): +class Room: """ Used to store (in memory) the current membership state of a room, and which home servers we should send PDUs associated with the room to. """ diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index bf3862a386..89acb52e6a 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -15,7 +15,7 @@ from synapse.storage.pdu import PduStore from synapse.storage.signatures import SignatureStore -class Store(object): +class Store: _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 7aab764360..75388643ee 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -58,7 +58,7 @@ class _InvalidMacaroonException(Exception): pass -class Auth(object): +class Auth: """ FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 49093bf181..d8fafd7cb8 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -22,7 +22,7 @@ from synapse.config.server import is_threepid_reserved logger = logging.getLogger(__name__) -class AuthBlocking(object): +class AuthBlocking: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6a6d32c302..46013cde15 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -28,7 +28,7 @@ MAX_ALIAS_LENGTH = 255 MAX_USERID_LENGTH = 255 -class Membership(object): +class Membership: """Represents the membership states of a user in a room.""" @@ -40,7 +40,7 @@ class Membership(object): LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) -class PresenceState(object): +class PresenceState: """Represents the presence state of a user.""" OFFLINE = "offline" @@ -48,14 +48,14 @@ class PresenceState(object): ONLINE = "online" -class JoinRules(object): +class JoinRules: PUBLIC = "public" KNOCK = "knock" INVITE = "invite" PRIVATE = "private" -class LoginType(object): +class LoginType: PASSWORD = "m.login.password" EMAIL_IDENTITY = "m.login.email.identity" MSISDN = "m.login.msisdn" @@ -65,7 +65,7 @@ class LoginType(object): DUMMY = "m.login.dummy" -class EventTypes(object): +class EventTypes: Member = "m.room.member" Create = "m.room.create" Tombstone = "m.room.tombstone" @@ -96,17 +96,17 @@ class EventTypes(object): Presence = "m.presence" -class RejectedReason(object): +class RejectedReason: AUTH_ERROR = "auth_error" -class RoomCreationPreset(object): +class RoomCreationPreset: PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" -class ThirdPartyEntityKind(object): +class ThirdPartyEntityKind: USER = "user" LOCATION = "location" @@ -115,7 +115,7 @@ ServerNoticeMsgType = "m.server_notice" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" -class UserTypes(object): +class UserTypes: """Allows for user type specific behaviour. With the benefit of hindsight 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ @@ -125,7 +125,7 @@ class UserTypes(object): ALL_USER_TYPES = (SUPPORT, BOT) -class RelationTypes(object): +class RelationTypes: """The types of relations known to this server. """ @@ -134,14 +134,14 @@ class RelationTypes(object): REFERENCE = "m.reference" -class LimitBlockingTypes(object): +class LimitBlockingTypes: """Reasons that a server may be blocked""" MONTHLY_ACTIVE_USER = "monthly_active_user" HS_DISABLED = "hs_disabled" -class EventContentFields(object): +class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 @@ -152,6 +152,6 @@ class EventContentFields(object): SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" -class RoomEncryptionAlgorithms(object): +class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 4888c0ec4d..94a9e58eae 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -31,7 +31,7 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -class Codes(object): +class Codes: UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a8937d2595..2a2c9e6f13 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -130,7 +130,7 @@ def matrix_user_id_validator(user_id_str): return UserID.from_string(user_id_str) -class Filtering(object): +class Filtering: def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() @@ -168,7 +168,7 @@ class Filtering(object): raise SynapseError(400, str(e)) -class FilterCollection(object): +class FilterCollection: def __init__(self, filter_json): self._filter_json = filter_json @@ -249,7 +249,7 @@ class FilterCollection(object): ) -class Filter(object): +class Filter: def __init__(self, filter_json): self.filter_json = filter_json diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index e62ae50ac2..5d9d5a228f 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -21,7 +21,7 @@ from synapse.types import Requester from synapse.util import Clock -class Ratelimiter(object): +class Ratelimiter: """ Ratelimit actions marked by arbitrary keys. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d7baf2bc39..f3ecbf36b6 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -18,7 +18,7 @@ from typing import Dict import attr -class EventFormatVersions(object): +class EventFormatVersions: """This is an internal enum for tracking the version of the event format, independently from the room version. """ @@ -35,20 +35,20 @@ KNOWN_EVENT_FORMAT_VERSIONS = { } -class StateResolutionVersions(object): +class StateResolutionVersions: """Enum to identify the state resolution algorithms""" V1 = 1 # room v1 state res V2 = 2 # MSC1442 state res: room v2 and later -class RoomDisposition(object): +class RoomDisposition: STABLE = "stable" UNSTABLE = "unstable" @attr.s(slots=True, frozen=True) -class RoomVersion(object): +class RoomVersion: """An object which describes the unique attributes of a room version.""" identifier = attr.ib() # str; the identifier for this version @@ -69,7 +69,7 @@ class RoomVersion(object): limit_notifications_power_levels = attr.ib(type=bool) -class RoomVersions(object): +class RoomVersions: V1 = RoomVersion( "1", RoomDisposition.STABLE, diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd03ebca5a..bbfccf955e 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -33,7 +33,7 @@ MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" -class ConsentURIBuilder(object): +class ConsentURIBuilder: def __init__(self, hs_config): """ Args: diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a43dc5b2c9..fb476ddaf5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -349,7 +349,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): reactor.installNameResolver(new_resolver) -class _LimitedHostnameResolver(object): +class _LimitedHostnameResolver: """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. """ @@ -409,7 +409,7 @@ class _LimitedHostnameResolver(object): yield deferred -class _DeferredResolutionReceiver(object): +class _DeferredResolutionReceiver: """Wraps a IResolutionReceiver and simply resolves the given deferred when resolution is complete """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 739b013d4c..f985810e88 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -745,7 +745,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): self.send_handler.wake_destination(server) -class FederationSenderHandler(object): +class FederationSenderHandler: """Processes the fedration replication stream This class is only instantiate on the worker responsible for sending outbound diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 69a7182ef4..13ec1f71a6 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -27,12 +27,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ApplicationServiceState(object): +class ApplicationServiceState: DOWN = "down" UP = "up" -class AppServiceTransaction(object): +class AppServiceTransaction: """Represents an application service transaction.""" def __init__(self, service, id, events): @@ -64,7 +64,7 @@ class AppServiceTransaction(object): await store.complete_appservice_txn(service=self.service, txn_id=self.id) -class ApplicationService(object): +class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index d5204b1314..8eb8c6f51c 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class ApplicationServiceScheduler(object): +class ApplicationServiceScheduler: """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. @@ -86,7 +86,7 @@ class ApplicationServiceScheduler(object): self.queuer.enqueue(service, event) -class _ServiceQueuer(object): +class _ServiceQueuer: """Queue of events waiting to be sent to appservices. Groups events into transactions per-appservice, and sends them on to the @@ -133,7 +133,7 @@ class _ServiceQueuer(object): self.requests_in_flight.discard(service.id) -class _TransactionController(object): +class _TransactionController: """Transaction manager. Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer @@ -209,7 +209,7 @@ class _TransactionController(object): return state == ApplicationServiceState.UP or state is None -class _Recoverer(object): +class _Recoverer: """Manages retries and backoff for a DOWN appservice. We have one of these for each appservice which is currently considered DOWN. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1417487427..ad5ab6ad62 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -88,7 +88,7 @@ def path_exists(file_path): return False -class Config(object): +class Config: """ A configuration section, containing configuration keys and values. @@ -283,7 +283,7 @@ def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: return mxc_to_http_filter -class RootConfig(object): +class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/cache.py b/synapse/config/cache.py index aff5b21ab2..8e03f14005 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -33,7 +33,7 @@ _DEFAULT_FACTOR_SIZE = 0.5 _DEFAULT_EVENT_CACHE_SIZE = "10K" -class CacheProperties(object): +class CacheProperties: def __init__(self): # The default factor size for all caches self.default_factor_size = float( diff --git a/synapse/config/key.py b/synapse/config/key.py index b529ea5da0..de964dff13 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -82,7 +82,7 @@ logger = logging.getLogger(__name__) @attr.s -class TrustedKeyServer(object): +class TrustedKeyServer: # string: name of the server. server_name = attr.ib() diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6aad0d37c0..dfd27e1523 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,7 +22,7 @@ from ._base import Config, ConfigError @attr.s -class MetricsFlags(object): +class MetricsFlags: known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) @classmethod diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index b2c78ac40c..14b8836197 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -17,7 +17,7 @@ from typing import Dict from ._base import Config -class RateLimitConfig(object): +class RateLimitConfig: def __init__( self, config: Dict[str, float], @@ -27,7 +27,7 @@ class RateLimitConfig(object): self.burst_count = config.get("burst_count", defaults["burst_count"]) -class FederationRateLimitConfig(object): +class FederationRateLimitConfig: _items_and_default = { "window_size": 1000, "sleep_limit": 10, diff --git a/synapse/config/room.py b/synapse/config/room.py index 52cf0b62fc..692d7a1936 100644 --- a/synapse/config/room.py +++ b/synapse/config/room.py @@ -22,7 +22,7 @@ from ._base import Config, ConfigError logger = logging.Logger(__name__) -class RoomDefaultEncryptionTypes(object): +class RoomDefaultEncryptionTypes: """Possible values for the encryption_enabled_by_default_for_room_type config option""" ALL = "all" diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 7ac7699676..6de1f9d103 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -149,7 +149,7 @@ class RoomDirectoryConfig(Config): return False -class _RoomDirectoryRule(object): +class _RoomDirectoryRule: """Helper class to test whether a room directory action is allowed, like creating an alias or publishing a room. """ diff --git a/synapse/config/server.py b/synapse/config/server.py index 526a90b26a..e85c6a0840 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -424,7 +424,7 @@ class ServerConfig(Config): self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) @attr.s - class LimitRemoteRoomsConfig(object): + class LimitRemoteRoomsConfig: enabled = attr.ib( validator=attr.validators.instance_of(bool), default=False ) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 777c0f00b1..2b03f5ac76 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -83,7 +83,7 @@ class ServerContextFactory(ContextFactory): @implementer(IPolicyForHTTPS) -class FederationPolicyForHTTPS(object): +class FederationPolicyForHTTPS: """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -152,7 +152,7 @@ class FederationPolicyForHTTPS(object): @implementer(IPolicyForHTTPS) -class RegularPolicyForHTTPS(object): +class RegularPolicyForHTTPS: """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers, for other than federation. @@ -189,7 +189,7 @@ def _context_info_cb(ssl_connection, where, ret): @implementer(IOpenSSLClientConnectionCreator) -class SSLClientConnectionCreator(object): +class SSLClientConnectionCreator: """Creates openssl connection objects for client connections. Replaces twisted.internet.ssl.ClientTLSOptions @@ -214,7 +214,7 @@ class SSLClientConnectionCreator(object): return connection -class ConnectionVerifier(object): +class ConnectionVerifier: """Set the SNI, and do cert verification This is a thing which is attached to the TLSMemoryBIOProtocol, and is called by diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 81c4b430b2..32c31b1cd1 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, cmp=False) -class VerifyJsonRequest(object): +class VerifyJsonRequest: """ A request to verify a JSON object. @@ -96,7 +96,7 @@ class KeyLookupError(ValueError): pass -class Keyring(object): +class Keyring: def __init__(self, hs, key_fetchers=None): self.clock = hs.get_clock() @@ -420,7 +420,7 @@ class Keyring(object): remaining_requests.difference_update(completed) -class KeyFetcher(object): +class KeyFetcher: async def get_keys(self, keys_to_fetch): """ Args: @@ -456,7 +456,7 @@ class StoreKeyFetcher(KeyFetcher): return keys -class BaseV2KeyFetcher(object): +class BaseV2KeyFetcher: def __init__(self, hs): self.store = hs.get_datastore() self.config = hs.get_config() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 62ea44fa49..bf800a3852 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -96,7 +96,7 @@ class DefaultDictProperty(DictProperty): return instance._dict.get(self.key, self.default) -class _EventInternalMetadata(object): +class _EventInternalMetadata: __slots__ = ["_dict"] def __init__(self, internal_metadata_dict: JsonDict): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7878cd7044..b6c47be646 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -36,7 +36,7 @@ from synapse.util.stringutils import random_string @attr.s(slots=True, cmp=False, frozen=True) -class EventBuilder(object): +class EventBuilder: """A format independent event builder used to build up the event content before signing the event. @@ -164,7 +164,7 @@ class EventBuilder(object): ) -class EventBuilderFactory(object): +class EventBuilderFactory: def __init__(self, hs): self.clock = hs.get_clock() self.hostname = hs.hostname diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index a7cddac974..b0fc859a47 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -25,7 +25,7 @@ if MYPY: import synapse.server -class SpamChecker(object): +class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): self.spam_checkers = [] # type: List[Any] diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 2956a64234..9d5310851c 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -18,7 +18,7 @@ from synapse.events.snapshot import EventContext from synapse.types import Requester -class ThirdPartyEventRules(object): +class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra set of rules to apply when processing events. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2d42e268c6..32c73d3413 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -322,7 +322,7 @@ def serialize_event( return d -class EventClientSerializer(object): +class EventClientSerializer: """Serializes events that are to be sent to clients. This is used for bundling extra information with any events to be sent to diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 5ce3874fba..9df35b54ba 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -20,7 +20,7 @@ from synapse.events.utils import validate_canonicaljson from synapse.types import EventID, RoomID, UserID -class EventValidator(object): +class EventValidator: def validate_new(self, event, config): """Validates the event has roughly the right format diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 420df2385f..38aa47963f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ from synapse.types import JsonDict, get_domain_from_id logger = logging.getLogger(__name__) -class FederationBase(object): +class FederationBase: def __init__(self, hs): self.hs = hs diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 630f571cd4..218df884b0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -785,7 +785,7 @@ def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: return regex.match(server_name) -class FederationHandlerRegistry(object): +class FederationHandlerRegistry: """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index de1fe7da38..079e2b2fe0 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -29,7 +29,7 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) -class TransactionActions(object): +class TransactionActions: """ Defines persistence actions that relate to handling Transactions. """ diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 4d65d4aeea..8e46957d15 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -46,7 +46,7 @@ from .units import Edu logger = logging.getLogger(__name__) -class FederationRemoteSendQueue(object): +class FederationRemoteSendQueue: """A drop in replacement for FederationSender""" def __init__(self, hs): @@ -365,7 +365,7 @@ class FederationRemoteSendQueue(object): ) -class BaseFederationRow(object): +class BaseFederationRow: """Base class for rows to be sent in the federation stream. Specifies how to identify, serialize and deserialize the different types. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 5276c1734f..552519e82c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -56,7 +56,7 @@ sent_pdus_destination_dist_total = Counter( ) -class FederationSender(object): +class FederationSender: def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index f1534d431d..defc228c23 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -53,7 +53,7 @@ sent_edus_by_type = Counter( ) -class PerDestinationQueue(object): +class PerDestinationQueue: """ Manages the per-destination transmission queues. diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 0ebc70d57d..c84072ab73 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TransactionManager(object): +class TransactionManager: """Helper class which handles building and sending transactions shared between PerDestinationQueue objects diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9ea821dbb2..17a10f622e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -30,7 +30,7 @@ from synapse.logging.utils import log_function logger = logging.getLogger(__name__) -class TransportLayerClient(object): +class TransportLayerClient: """Sends federation HTTP requests to other servers""" def __init__(self, hs): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e111aa902..9325e0f857 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -100,7 +100,7 @@ class NoAuthenticationError(AuthenticationError): pass -class Authenticator(object): +class Authenticator: def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() @@ -228,7 +228,7 @@ def _parse_auth_header(header_bytes): ) -class BaseFederationServlet(object): +class BaseFederationServlet: """Abstract base class for federation servlet classes. The servlet object should have a PATH attribute which takes the form of a regexp to diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index e674bf44a2..a86b3debc5 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -60,7 +60,7 @@ DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 -class GroupAttestationSigning(object): +class GroupAttestationSigning: """Creates and verifies group attestations. """ @@ -124,7 +124,7 @@ class GroupAttestationSigning(object): ) -class GroupAttestionRenewer(object): +class GroupAttestionRenewer: """Responsible for sending and receiving attestation updates. """ diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8cb922ddc7..1dd20ee4e1 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # TODO: Flairs -class GroupsServerWorkerHandler(object): +class GroupsServerWorkerHandler: def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 2dd183018a..286f0054be 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -20,7 +20,7 @@ from .identity import IdentityHandler from .search import SearchHandler -class Handlers(object): +class Handlers: """ Deprecated. A collection of handlers. diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index ba2bf99800..0206320e96 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -25,7 +25,7 @@ from synapse.types import UserID logger = logging.getLogger(__name__) -class BaseHandler(object): +class BaseHandler: """ Common base class for the event handlers. """ diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index a8d3fbc6de..9112a0ab86 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. -class AccountDataEventSource(object): +class AccountDataEventSource: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index b865bf5b48..4caf6d591a 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -29,7 +29,7 @@ from synapse.util import stringutils logger = logging.getLogger(__name__) -class AccountValidityHandler(object): +class AccountValidityHandler: def __init__(self, hs): self.hs = hs self.config = hs.config diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 7666d3abcd..8476256a59 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -34,7 +34,7 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC --------------------------------------------------------------------------------""" -class AcmeHandler(object): +class AcmeHandler: def __init__(self, hs): self.hs = hs self.reactor = hs.get_reactor() diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index e1d4224e74..69650ff221 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -78,7 +78,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou @attr.s @implementer(ICertificateStore) -class ErsatzStore(object): +class ErsatzStore: """ A store that only stores in memory. """ diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 506bb2b275..918d0e037c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -197,7 +197,7 @@ class AdminHandler(BaseHandler): return writer.finished() -class ExfiltrationWriter(object): +class ExfiltrationWriter: """Interface used to specify how to write exported data. """ diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index c9044a5019..9d4e87dad6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") -class ApplicationServicesHandler(object): +class ApplicationServicesHandler: def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f0b0a4d76a..90189869cc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1236,7 +1236,7 @@ class AuthHandler(BaseHandler): @attr.s -class MacaroonGenerator(object): +class MacaroonGenerator: hs = attr.ib() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ee4666337a..643d71a710 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -497,7 +497,7 @@ def _update_device_from_client_ips(device, client_ips): device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) -class DeviceListUpdater(object): +class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs, device_handler): diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index dcb4c82244..64ef7f63ab 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -31,7 +31,7 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class DeviceMessageHandler(object): +class DeviceMessageHandler: def __init__(self, hs): """ Args: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index dfd1c78549..d629c7c16c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -43,7 +43,7 @@ from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) -class E2eKeysHandler(object): +class E2eKeysHandler: def __init__(self, hs): self.store = hs.get_datastore() self.federation = hs.get_federation_client() @@ -1212,7 +1212,7 @@ class SignatureListItem: signature = attr.ib() -class SigningKeyEduUpdater(object): +class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs, e2e_keys_handler): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0bb983dc28..f01b090772 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -29,7 +29,7 @@ from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) -class E2eRoomKeysHandler(object): +class E2eRoomKeysHandler: """ Implements an optional realtime backup mechanism for encrypted E2E megolm room keys. This gives a way for users to store and recover their megolm keys if they lose all diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 0e2656ccb3..44df567983 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -52,7 +52,7 @@ def _create_rerouter(func_name): return f -class GroupsLocalWorkerHandler(object): +class GroupsLocalWorkerHandler: def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72bb638167..8a7b4916cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -64,7 +64,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class MessageHandler(object): +class MessageHandler: """Contains some read only APIs to get state about a room """ @@ -361,7 +361,7 @@ class MessageHandler(object): _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000 -class EventCreationHandler(object): +class EventCreationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 63d7edff87..34ed0e2921 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class PurgeStatus(object): +class PurgeStatus: """Object tracking the status of a purge request This class contains information on the progress of a purge request, for @@ -65,7 +65,7 @@ class PurgeStatus(object): return {"status": PurgeStatus.STATUS_TEXT[self.status]} -class PaginationHandler(object): +class PaginationHandler: """Handles pagination and purge history requests. These are in the same handler due to the fact we need to block clients diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index d06b110269..88e2f87200 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -22,7 +22,7 @@ from synapse.api.errors import Codes, PasswordRefusedError logger = logging.getLogger(__name__) -class PasswordPolicyHandler(object): +class PasswordPolicyHandler: def __init__(self, hs): self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1846068150..91a3aec1cc 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1010,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True): return content -class PresenceEventSource(object): +class PresenceEventSource: def __init__(self, hs): # We can't call get_presence_handler here because there's a cycle: # diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f922d8a545..2cc6c2eb68 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -123,7 +123,7 @@ class ReceiptsHandler(BaseHandler): await self.federation.send_read_receipt(receipt) -class ReceiptEventSource(object): +class ReceiptEventSource: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9d5b1828df..a29305f655 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -974,7 +974,7 @@ class RoomCreationHandler(BaseHandler): raise StoreError(500, "Couldn't generate a room ID.") -class RoomContextHandler(object): +class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() @@ -1084,7 +1084,7 @@ class RoomContextHandler(object): return results -class RoomEventSource(object): +class RoomEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1146,7 +1146,7 @@ class RoomEventSource(object): return self.store.get_room_events_max_id(room_id) -class RoomShutdownHandler(object): +class RoomShutdownHandler: DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a7962b0ada..32b7e323fa 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -51,7 +51,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RoomMemberHandler(object): +class RoomMemberHandler: # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index b426199aa6..66b063f991 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -360,12 +360,12 @@ MXID_MAPPER_MAP = { @attr.s -class SamlConfig(object): +class SamlConfig: mxid_source_attribute = attr.ib() mxid_mapper = attr.ib() -class DefaultSamlMappingProvider(object): +class DefaultSamlMappingProvider: __version__ = "0.0.1" def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi): diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 8590c1eff4..7a4ae0727a 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -class StateDeltasHandler(object): +class StateDeltasHandler: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8728403e62..e2ddb628ff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -246,7 +246,7 @@ class SyncResult: __bool__ = __nonzero__ # python3 -class SyncHandler(object): +class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config self.store = hs.get_datastore() @@ -2075,7 +2075,7 @@ class SyncResultBuilder: @attr.s -class RoomSyncResultBuilder(object): +class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1d828bd7be..3cbfc2d780 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -412,7 +412,7 @@ class TypingWriterHandler(FollowerTypingHandler): raise Exception("Typing writer instance got typing info over replication") -class TypingNotificationEventSource(object): +class TypingNotificationEventSource: def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() diff --git a/synapse/http/client.py b/synapse/http/client.py index dad01a8e56..13fcab3378 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -86,7 +86,7 @@ def _make_scheduler(reactor): return _scheduler -class IPBlacklistingResolver(object): +class IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. @@ -133,7 +133,7 @@ class IPBlacklistingResolver(object): r.resolutionComplete() @provider(IResolutionReceiver) - class EndpointReceiver(object): + class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress): pass @@ -192,7 +192,7 @@ class BlacklistingAgentWrapper(Agent): ) -class SimpleHttpClient(object): +class SimpleHttpClient: """ A simple, no-frills HTTP client with methods that wrap up common ways of using HTTP in Matrix @@ -244,7 +244,7 @@ class SimpleHttpClient(object): ) @implementer(IReactorPluggableNameResolver) - class Reactor(object): + class Reactor: def __getattr__(_self, attr): if attr == "nameResolver": return nameResolver diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index be7b2ceb8e..856e28454f 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -31,7 +31,7 @@ class ProxyConnectError(ConnectError): @implementer(IStreamClientEndpoint) -class HTTPConnectProxyEndpoint(object): +class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy Wraps an existing HostnameEndpoint for the proxy. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 782d39d4ca..83d6196d4a 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) @implementer(IAgent) -class MatrixFederationAgent(object): +class MatrixFederationAgent: """An Agent-like thing which provides a `request` method which correctly handles resolving matrix server names when using matrix://. Handles standard https URIs as normal. @@ -175,7 +175,7 @@ class MatrixFederationAgent(object): @implementer(IAgentEndpointFactory) -class MatrixHostnameEndpointFactory(object): +class MatrixHostnameEndpointFactory: """Factory for MatrixHostnameEndpoint for parsing to an Agent. """ @@ -198,7 +198,7 @@ class MatrixHostnameEndpointFactory(object): @implementer(IStreamClientEndpoint) -class MatrixHostnameEndpoint(object): +class MatrixHostnameEndpoint: """An endpoint that resolves matrix:// URLs using Matrix server name resolution (i.e. via SRV). Does not check for well-known delegation. diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 2ede90a9b1..d9620032d2 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -33,7 +33,7 @@ SERVER_CACHE = {} @attr.s(slots=True, frozen=True) -class Server(object): +class Server: """ Our record of an individual server which can be tried to reach a destination. @@ -96,7 +96,7 @@ def _sort_server_list(server_list): return results -class SrvResolver(object): +class SrvResolver: """Interface to the dns client to do SRV lookups, with result caching. The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index cdb6bec56e..e6f067ca29 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -71,11 +71,11 @@ _had_valid_well_known_cache = TTLCache("had-valid-well-known") @attr.s(slots=True, frozen=True) -class WellKnownLookupResult(object): +class WellKnownLookupResult: delegated_server = attr.ib() -class WellKnownResolver(object): +class WellKnownResolver: """Handles well-known lookups for matrix servers. """ diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 738be43f46..775fad3be4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -76,7 +76,7 @@ _next_id = 1 @attr.s(frozen=True) -class MatrixFederationRequest(object): +class MatrixFederationRequest: method = attr.ib() """HTTP method :type: str @@ -203,7 +203,7 @@ async def _handle_json_response( return body -class MatrixFederationHttpClient(object): +class MatrixFederationHttpClient: """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. @@ -226,7 +226,7 @@ class MatrixFederationHttpClient(object): ) @implementer(IReactorPluggableNameResolver) - class Reactor(object): + class Reactor: def __getattr__(_self, attr): if attr == "nameResolver": return nameResolver diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index b58ae3d9db..cd94e789e8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -145,7 +145,7 @@ LaterGauge( ) -class RequestMetrics(object): +class RequestMetrics: def start(self, time_sec, name, method): self.start = time_sec self.start_context = current_context() diff --git a/synapse/http/server.py b/synapse/http/server.py index 8d791bd2ca..996a31a9ec 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -174,7 +174,7 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer(object): +class HttpServer: """ Interface for registering callbacks on a HTTP server """ diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 53acba56cb..fd90ba7828 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -256,7 +256,7 @@ def assert_params_in_dict(body, required): raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) -class RestServlet(object): +class RestServlet: """ A Synapse REST Servlet. diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 7372450b45..144506c8f2 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -55,7 +55,7 @@ def stdlib_log_level_to_twisted(level: str) -> LogLevel: @attr.s @implementer(ILogObserver) -class LogContextObserver(object): +class LogContextObserver: """ An ILogObserver which adds Synapse-specific log context information. @@ -169,7 +169,7 @@ class OutputPipeType(Values): @attr.s -class DrainConfiguration(object): +class DrainConfiguration: name = attr.ib() type = attr.ib() location = attr.ib() @@ -177,7 +177,7 @@ class DrainConfiguration(object): @attr.s -class NetworkJSONTerseOptions(object): +class NetworkJSONTerseOptions: maximum_buffer = attr.ib(type=int) diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index c0b9384189..1b8916cfa2 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -152,7 +152,7 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb @attr.s @implementer(IPushProducer) -class LogProducer(object): +class LogProducer: """ An IPushProducer that writes logs from its buffer to its transport when it is resumed. @@ -190,7 +190,7 @@ class LogProducer(object): @attr.s @implementer(ILogObserver) -class TerseJSONToTCPLogObserver(object): +class TerseJSONToTCPLogObserver: """ An IObserver that writes JSON logs to a TCP target. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index cbeeb870cb..22598e02d2 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -74,7 +74,7 @@ except Exception: get_thread_id = threading.get_ident -class ContextResourceUsage(object): +class ContextResourceUsage: """Object for tracking the resources used by a log context Attributes: @@ -179,7 +179,7 @@ class ContextResourceUsage(object): LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] -class _Sentinel(object): +class _Sentinel: """Sentinel to represent the root context""" __slots__ = ["previous_context", "finished", "request", "scope", "tag"] @@ -226,7 +226,7 @@ class _Sentinel(object): SENTINEL_CONTEXT = _Sentinel() -class LoggingContext(object): +class LoggingContext: """Additional context for log formatting. Contexts are scoped within a "with" block. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index d39ac62168..7df0aa197d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -185,7 +185,7 @@ if TYPE_CHECKING: # Helper class -class _DummyTagNames(object): +class _DummyTagNames: """wrapper of opentracings tags. We need to have them if we want to reference them without opentracing around. Clearly they should never actually show up in a trace. `set_tags` overwrites diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 6035672698..2643380d9e 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -51,7 +51,7 @@ all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollec HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") -class RegistryProxy(object): +class RegistryProxy: @staticmethod def collect(): for metric in REGISTRY.collect(): @@ -60,7 +60,7 @@ class RegistryProxy(object): @attr.s(hash=True) -class LaterGauge(object): +class LaterGauge: name = attr.ib(type=str) desc = attr.ib(type=str) @@ -100,7 +100,7 @@ class LaterGauge(object): all_gauges[self.name] = self -class InFlightGauge(object): +class InFlightGauge: """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -206,7 +206,7 @@ class InFlightGauge(object): @attr.s(hash=True) -class BucketCollector(object): +class BucketCollector: """ Like a Histogram, but allows buckets to be point-in-time instead of incrementally added to. @@ -269,7 +269,7 @@ class BucketCollector(object): # -class CPUMetrics(object): +class CPUMetrics: def __init__(self): ticks_per_sec = 100 try: @@ -329,7 +329,7 @@ gc_time = Histogram( ) -class GCCounts(object): +class GCCounts: def collect(self): cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): @@ -347,7 +347,7 @@ if not running_on_pypy: # -class PyPyGCStats(object): +class PyPyGCStats: def collect(self): # @stats is a pretty-printer object with __str__() returning a nice table, @@ -482,7 +482,7 @@ build_info.labels( last_ticked = time.time() -class ReactorLastSeenMetric(object): +class ReactorLastSeenMetric: def collect(self): cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 4cd7932e5b..5b73463504 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -105,7 +105,7 @@ _background_processes_active_since_last_scrape = set() # type: Set[_BackgroundP _bg_metrics_lock = threading.Lock() -class _Collector(object): +class _Collector: """A custom metrics collector for the background process metrics. Ensures that all of the metrics are up-to-date with any in-flight processes @@ -140,7 +140,7 @@ class _Collector(object): REGISTRY.register(_Collector()) -class _BackgroundProcess(object): +class _BackgroundProcess: def __init__(self, desc, ctx): self.desc = desc self._context = ctx diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ae0e359a77..fcbd5378c4 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,7 +31,7 @@ __all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi" logger = logging.getLogger(__name__) -class ModuleApi(object): +class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index dfb096e589..b7f4041306 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -68,7 +68,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: return n -class _NotificationListener(object): +class _NotificationListener: """ This represents a single client connection to the events stream. The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. @@ -80,7 +80,7 @@ class _NotificationListener(object): self.deferred = deferred -class _NotifierUserStream(object): +class _NotifierUserStream: """This represents a user connected to the event stream. It tracks the most recent stream token for that user. At a given point a user may have a number of streams listening for @@ -168,7 +168,7 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): __bool__ = __nonzero__ # python3 -class Notifier(object): +class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 0d23142653..fabc9ba126 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -22,7 +22,7 @@ from .bulk_push_rule_evaluator import BulkPushRuleEvaluator logger = logging.getLogger(__name__) -class ActionGenerator(object): +class ActionGenerator: def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e7fa02b78b..1bb8e346b9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -95,7 +95,7 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False -class BulkPushRuleEvaluator(object): +class BulkPushRuleEvaluator: """Calculates the outcome of push rules for an event for all users in the room at once. """ @@ -263,7 +263,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return True -class RulesForRoom(object): +class RulesForRoom: """Caches push rules for users in a room. This efficiently handles users joining/leaving the room by not invalidating diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 568c13eaea..b7ea4438e0 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -45,7 +45,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 INCLUDE_ALL_UNREAD_NOTIFS = False -class EmailPusher(object): +class EmailPusher: """ A pusher that sends email notifications about events (approximately) when they happen. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 4c469efb20..f21fa9b659 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -49,7 +49,7 @@ http_badges_failed_counter = Counter( ) -class HttpPusher(object): +class HttpPusher: INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index c38e037281..6c57854018 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -92,7 +92,7 @@ ALLOWED_ATTRS = { # ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"] -class Mailer(object): +class Mailer: def __init__(self, hs, app_name, template_html, template_text): self.hs = hs self.template_html = template_html diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2d79ada189..709ace01e5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -105,7 +105,7 @@ def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: return tweaks -class PushRuleEvaluatorForEvent(object): +class PushRuleEvaluatorForEvent: def __init__( self, event: EventBase, diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f626797133..2a52e226e3 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -23,7 +23,7 @@ from .httppusher import HttpPusher logger = logging.getLogger(__name__) -class PusherFactory(object): +class PusherFactory: def __init__(self, hs): self.hs = hs self.config = hs.config diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 6a28c2db9d..ba16f22c91 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class ReplicationEndpoint(object): +class ReplicationEndpoint: """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 047f2c50f7..eb74903d68 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -16,7 +16,7 @@ from synapse.storage.util.id_generators import _load_current_id -class SlavedIdTracker(object): +class SlavedIdTracker: def __init__(self, db_conn, table, column, extra_tables=[], step=1): self.step = step self._current = _load_current_id(db_conn, table, column, step) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0350923898..0b0d204e64 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -113,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER -class ConnectionStates(object): +class ConnectionStates: CONNECTING = "connecting" ESTABLISHED = "established" PAUSED = "paused" diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 41569305df..04d894fb3d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): ) -class ReplicationStreamer(object): +class ReplicationStreamer: """Handles replication connections. This needs to be poked when new replication data may be available. When new diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8c3caf30c9..682d47f402 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] -class Stream(object): +class Stream: """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 16c63ff4ec..f929fc3954 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -49,14 +49,14 @@ data part are: @attr.s(slots=True, frozen=True) -class EventsStreamRow(object): +class EventsStreamRow: """A parsed row from the events replication stream""" type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow -class BaseEventsStreamRow(object): +class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. Specifies how to identify, serialize and deserialize the different types. diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 6da71dc46f..7be5c0fb88 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins -class HttpTransactionCache(object): +class HttpTransactionCache: def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 51372cdb5e..b6b90a8b30 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -658,7 +658,7 @@ class RegisterRestServlet(RestServlet): (object) params: registration parameters, from which we pull device_id, initial_device_name and inhibit_login Returns: - (object) dictionary for response from /register + dictionary for response from /register """ result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 20ddb9550b..6568e61829 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -235,7 +235,7 @@ async def respond_with_responder( finish_request(request) -class Responder(object): +class Responder: """Represents a response that can be streamed to the requester. Responder is a context manager which *must* be used, so that any resources @@ -260,7 +260,7 @@ class Responder(object): pass -class FileInfo(object): +class FileInfo: """Details about a requested/uploaded file. Attributes: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index e25c382c9c..d2826374a7 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -33,7 +33,7 @@ def _wrap_in_base_path(func): return _wrapped -class MediaFilePaths(object): +class MediaFilePaths: """Describes where files are stored on disk. Most of the functions have a `*_rel` variant which returns a file path that diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 6fb4039e98..9a1b7779f7 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -62,7 +62,7 @@ logger = logging.getLogger(__name__) UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 -class MediaRepository(object): +class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index ab1fa705bf..3a352b5631 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class MediaStorage(object): +class MediaStorage: """Responsible for storing/fetching files from local sources. Args: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 7126997134..d681bf7bf0 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -31,7 +31,7 @@ EXIF_TRANSPOSE_MAPPINGS = { } -class Thumbnailer(object): +class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index e15e13b756..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -23,7 +23,7 @@ from synapse.util import json_encoder logger = logging.getLogger(__name__) -class WellKnownBuilder(object): +class WellKnownBuilder: """Utility to construct the well-known response Args: diff --git a/synapse/secrets.py b/synapse/secrets.py index ff86950a54..fb6d90a3b7 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -37,7 +37,7 @@ else: import binascii import os - class Secrets(object): + class Secrets: def token_bytes(self, nbytes=32): return os.urandom(nbytes) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 089cfef0b3..3673e7f47e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -23,7 +23,7 @@ from synapse.types import get_localpart_from_id logger = logging.getLogger(__name__) -class ConsentServerNotices(object): +class ConsentServerNotices: """Keeps track of whether we need to send users server_notices about privacy policy consent, and sends one if we do. """ diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index c2faef6eab..2258d306d9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -27,7 +27,7 @@ from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG logger = logging.getLogger(__name__) -class ResourceLimitsServerNotices(object): +class ResourceLimitsServerNotices: """ Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index ed96aa8571..0422d4c7ce 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) SERVER_NOTICE_ROOM_TAG = "m.server_notice" -class ServerNoticesManager(object): +class ServerNoticesManager: def __init__(self, hs): """ diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index a754f75db4..6870b67ca0 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -20,7 +20,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ) -class ServerNoticesSender(object): +class ServerNoticesSender: """A centralised place which sends server notices automatically when Certain Events take place """ diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index e9390b19da..9273e61895 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -14,7 +14,7 @@ # limitations under the License. -class WorkerServerNoticesSender(object): +class WorkerServerNoticesSender: """Stub impl of ServerNoticesSender which does nothing""" def __init__(self, hs): diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9be92e2565..395ac5ab02 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -36,7 +36,7 @@ class RegistrationBehaviour(Enum): DENY = "deny" -class SpamCheckerApi(object): +class SpamCheckerApi: """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. """ diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9bf2ec368f..c7e3015b5d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -77,7 +77,7 @@ def _gen_state_id(): return s -class _StateCacheEntry(object): +class _StateCacheEntry: __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] def __init__( @@ -113,7 +113,7 @@ class _StateCacheEntry(object): return len(self.state) -class StateHandler(object): +class StateHandler: """Fetches bits of state from the stores, and does state resolution where necessary """ @@ -462,7 +462,7 @@ class StateHandler(object): return {key: state_map[ev_id] for key, ev_id in new_state.items()} -class StateResolutionHandler(object): +class StateResolutionHandler: """Responsible for doing state conflict resolution. Note that the storage layer depends on this handler, so all functions must @@ -679,7 +679,7 @@ def resolve_events_with_store( @attr.s -class StateResolutionStore(object): +class StateResolutionStore: """Interface that allows state resolution algorithms to access the database in well defined way. diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5ef3853559..8e5d78f6f7 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,7 +37,7 @@ from synapse.storage.state import StateGroupStorage __all__ = ["DataStores", "DataStore"] -class Storage(object): +class Storage: """The high level interfaces for talking to various storage layers. """ diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 67a89cd51a..810721ebe9 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -24,7 +24,7 @@ from . import engines logger = logging.getLogger(__name__) -class BackgroundUpdatePerformance(object): +class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" def __init__(self, name): @@ -71,7 +71,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdater(object): +class BackgroundUpdater: """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 78ca6d8346..8be943f589 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -248,7 +248,7 @@ class LoggingTransaction: self.txn.close() -class PerformanceCounters(object): +class PerformanceCounters: def __init__(self): self.current_counters = {} self.previous_counters = {} @@ -286,7 +286,7 @@ class PerformanceCounters(object): R = TypeVar("R") -class DatabasePool(object): +class DatabasePool: """Wraps a single physical database and connection pool. A single database may be used by multiple data stores. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 0ac854aee2..7f08bd8285 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -24,7 +24,7 @@ from synapse.storage.prepare_database import prepare_database logger = logging.getLogger(__name__) -class Databases(object): +class Databases: """The various databases. These are low level interfaces to physical databases. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c46f5cd524..91a8b43da3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -999,7 +999,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): await self.db_pool.runInteraction("forget_membership", f) -class _JoinedHostsCache(object): +class _JoinedHostsCache: """Cache for joined hosts in a room that is optimised to handle updates via state deltas. """ diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 4769b21529..afd10f7bae 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -22,6 +22,6 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True) -class FetchKeyResult(object): +class FetchKeyResult: verify_key = attr.ib() # VerifyKey: the key itself valid_until_ts = attr.ib() # int: how long we can use this key for diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index f15b95e633..dbaeef91dd 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -69,7 +69,7 @@ stale_forward_extremities_counter = Histogram( ) -class _EventPeristenceQueue(object): +class _EventPeristenceQueue: """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. """ @@ -172,7 +172,7 @@ class _EventPeristenceQueue(object): pass -class EventsPersistenceStorage(object): +class EventsPersistenceStorage: """High level interface for handling persisting newly received events. Takes care of batching up events by room, and calculating the necessary diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1c5f305132..964d8d9eb8 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -569,7 +569,7 @@ def _get_or_create_schema_state(txn, database_engine): @attr.s() -class _DirectoryListing(object): +class _DirectoryListing: """Helper class to store schema file name and the absolute path to it. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index 79d9f06e2e..bfa0a9fd06 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -20,7 +20,7 @@ from typing import Set logger = logging.getLogger(__name__) -class PurgeEventsStorage(object): +class PurgeEventsStorage: """High level interface for purging rooms and event history. """ diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index d471ec9860..d30e3f11e7 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) @attr.s -class PaginationChunk(object): +class PaginationChunk: """Returned by relation pagination APIs. Attributes: @@ -51,7 +51,7 @@ class PaginationChunk(object): @attr.s(frozen=True, slots=True) -class RelationPaginationToken(object): +class RelationPaginationToken: """Pagination token for relation pagination API. As the results are in topological order, we can use the @@ -82,7 +82,7 @@ class RelationPaginationToken(object): @attr.s(frozen=True, slots=True) -class AggregationPaginationToken(object): +class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. As the results are order by count and then MAX(stream_ordering) of the diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 96a1b59d64..8f68d968f0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -29,7 +29,7 @@ T = TypeVar("T") @attr.s(slots=True) -class StateFilter(object): +class StateFilter: """A filter used when querying for state. Attributes: @@ -326,7 +326,7 @@ class StateFilter(object): return member_filter, non_member_filter -class StateGroupStorage(object): +class StateGroupStorage: """High level interface to fetching state for event. """ diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9f3d23f0a5..76bc3afdfa 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator -class IdGenerator(object): +class IdGenerator: def __init__(self, db_conn, table, column): self._lock = threading.Lock() self._next_id = _load_current_id(db_conn, table, column) @@ -59,7 +59,7 @@ def _load_current_id(db_conn, table, column, step=1): return (max if step > 0 else min)(current_id, step) -class StreamIdGenerator(object): +class StreamIdGenerator: """Used to generate new stream ids when persisting events while keeping track of which transactions have been completed. diff --git a/synapse/streams/config.py b/synapse/streams/config.py index ca7c16ff65..d97dc4d101 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) MAX_LIMIT = 1000 -class SourcePaginationConfig(object): +class SourcePaginationConfig: """A configuration object which stores pagination parameters for a specific event source.""" @@ -45,7 +45,7 @@ class SourcePaginationConfig(object): ) -class PaginationConfig(object): +class PaginationConfig: """A configuration object which stores pagination parameters.""" diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 7ab46f42bf..92fd5d489f 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,7 +23,7 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.types import StreamToken -class EventSources(object): +class EventSources: SOURCE_TYPES = { "room": RoomEventSource, "presence": PresenceEventSource, diff --git a/synapse/types.py b/synapse/types.py index f8b9b03850..f7de48f148 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -529,7 +529,7 @@ class ThirdPartyInstanceID( @attr.s(slots=True) -class ReadReceipt(object): +class ReadReceipt: """Information about a read-receipt""" room_id = attr.ib() diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b2a22dbd5c..3ad4b28fc7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -46,7 +46,7 @@ def unwrapFirstError(failure): @attr.s -class Clock(object): +class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index dfefbd996d..bb57e27beb 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -36,7 +36,7 @@ from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) -class ObservableDeferred(object): +class ObservableDeferred: """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. @@ -188,7 +188,7 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) -class Linearizer(object): +class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. @@ -338,7 +338,7 @@ class Linearizer(object): return new_defer -class ReadWriteLock(object): +class ReadWriteLock: """An async read write lock. Example: @@ -502,7 +502,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): @attr.s(slots=True, frozen=True) -class DoneAwaitable(object): +class DoneAwaitable: """Simple awaitable that returns the provided value. """ diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index dd356bf156..237f588658 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -43,7 +43,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n @attr.s -class CacheMetric(object): +class CacheMetric: _cache = attr.ib() _cache_type = attr.ib(type=str) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 825810eb16..98b34f2223 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -64,7 +64,7 @@ cache_pending_metric = Gauge( _CacheSentinel = object() -class CacheEntry(object): +class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): @@ -80,7 +80,7 @@ class CacheEntry(object): self.callbacks.clear() -class Cache(object): +class Cache: __slots__ = ( "cache", "name", @@ -288,7 +288,7 @@ class Cache(object): self._pending_deferred_cache.clear() -class _CacheDescriptorBase(object): +class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig @@ -705,7 +705,7 @@ def cachedList( Example: - class Example(object): + class Example: @cached(num_args=2) def do_something(self, first_arg): ... diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6834e6f3ae..8592b93689 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -40,7 +40,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) -class DictionaryCache(object): +class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ @@ -53,7 +53,7 @@ class DictionaryCache(object): self.thread = None # caches_by_name[name] = self.cache - class Sentinel(object): + class Sentinel: __slots__ = [] self.sentinel = Sentinel() diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 89a3420f92..e15f7ee698 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class ExpiringCache(object): +class ExpiringCache: def __init__( self, cache_name, @@ -190,7 +190,7 @@ class ExpiringCache(object): return False -class _CacheEntry(object): +class _CacheEntry: __slots__ = ["time", "value"] def __init__(self, time, value): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index df4ea5901d..4bc1a67b58 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -30,7 +30,7 @@ def enumerate_leaves(node, depth): yield m -class _Node(object): +class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] def __init__(self, prev_node, next_node, key, value, callbacks=set()): @@ -41,7 +41,7 @@ class _Node(object): self.callbacks = callbacks -class LruCache(object): +class LruCache: """ Least-recently-used cache. Supports del_multi only if cache_type=TreeCache diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a6c60888e5..df1a721add 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -23,7 +23,7 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -class ResponseCache(object): +class ResponseCache: """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index ecd9948e79..eb4d98f683 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -3,7 +3,7 @@ from typing import Dict SENTINEL = object() -class TreeCache(object): +class TreeCache: """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. @@ -89,7 +89,7 @@ def iterate_tree_cache_entry(d): yield d -class _Entry(object): +class _Entry: __slots__ = ["value"] def __init__(self, value): diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 6437aa907e..3e180cafd3 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class TTLCache(object): +class TTLCache: """A key/value cache implementation where each entry has its own TTL""" def __init__(self, cache_name, timer=time.time): @@ -154,7 +154,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) -class _CacheEntry(object): +class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 22a857a306..a750261e77 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -34,7 +34,7 @@ def user_joined_room(distributor, user, room_id): distributor.fire("user_joined_room", user=user, room_id=room_id) -class Distributor(object): +class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. @@ -103,7 +103,7 @@ def maybeAwaitableDeferred(f, *args, **kw): return succeed(result) -class Signal(object): +class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 6a3f6177b1..733f5e26e6 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -20,7 +20,7 @@ from twisted.internet import threads from synapse.logging.context import make_deferred_yieldable, run_in_background -class BackgroundFileConsumer(object): +class BackgroundFileConsumer: """A consumer that writes to a file like object. Supports both push and pull producers diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 6dce03dd3a..50516926f3 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -14,7 +14,7 @@ # limitations under the License. -class JsonEncodedObject(object): +class JsonEncodedObject: """ A common base class for defining protocol units that are represented as JSON. diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 13775b43f9..6e57c1ee72 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -93,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: return wrapper -class Measure(object): +class Measure: __slots__ = [ "clock", "name", diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index e5efdfcd02..70d11e1ec3 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -29,7 +29,7 @@ from synapse.logging.context import ( logger = logging.getLogger(__name__) -class FederationRateLimiter(object): +class FederationRateLimiter: def __init__(self, clock, config): """ Args: @@ -60,7 +60,7 @@ class FederationRateLimiter(object): return self.ratelimiters[host].ratelimit() -class _PerHostRatelimiter(object): +class _PerHostRatelimiter: def __init__(self, clock, config): """ Args: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 919988d3bc..79869aaa44 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -114,7 +114,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k ) -class RetryDestinationLimiter(object): +class RetryDestinationLimiter: def __init__( self, destination, diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 023beb5ede..be3b22469d 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -14,7 +14,7 @@ # limitations under the License. -class _Entry(object): +class _Entry: __slots__ = ["end_key", "queue"] def __init__(self, end_key): @@ -22,7 +22,7 @@ class _Entry(object): self.queue = [] -class WheelTimer(object): +class WheelTimer: """Stores arbitrary objects that will be returned after their timers have expired. """ diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 5d45689c8c..8ab56ec94c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -36,7 +36,7 @@ from tests import unittest from tests.utils import mock_getRawHeaders, setup_test_homeserver -class TestHandlers(object): +class TestHandlers: def __init__(self, hs): self.auth_handler = synapse.handlers.auth.AuthHandler(hs) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d264653e74..2e6e7abf1f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -43,7 +43,7 @@ from tests import unittest from tests.test_utils import make_awaitable -class MockPerspectiveServer(object): +class MockPerspectiveServer: def __init__(self): self.server_name = "mock_server" self.key = signedjson.key.generate_signing_key(0) diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 27d83bb7d9..72e22d655f 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -26,7 +26,7 @@ from tests.unittest import override_config class RoomDirectoryFederationTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - class Authenticator(object): + class Authenticator: def authenticate_request(self, request, content): return defer.succeed("otherserver.nottld") diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 4b3fb018b1..c7efd3822d 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -28,7 +28,7 @@ from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class AuthHandlers(object): +class AuthHandlers: def __init__(self, hs): self.auth_handler = AuthHandler(hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60ebc95f3e..8e95e53d9e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -28,7 +28,7 @@ from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class ProfileHandlers(object): +class ProfileHandlers: def __init__(self, hs): self.profile_handler = MasterProfileHandler(hs) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2096ba3c91..5d41443293 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -133,7 +133,7 @@ def create_test_cert_file(sanlist): @implementer(IOpenSSLServerConnectionCreator) -class TestServerTLSConnectionFactory(object): +class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index eb78ab412a..8b5ad4574f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -1264,7 +1264,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) -class TrustingTLSPolicyForHTTPS(object): +class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py index 451d05c0f0..d36f5f426c 100644 --- a/tests/logging/test_structured.py +++ b/tests/logging/test_structured.py @@ -29,12 +29,12 @@ from synapse.logging.context import LoggingContext from tests.unittest import DEBUG, HomeserverTestCase -class FakeBeginner(object): +class FakeBeginner: def beginLoggingTo(self, observers, **kwargs): self.observers = observers -class StructuredLoggingTestBase(object): +class StructuredLoggingTestBase: """ Test base that registers a cleanup handler to reset the stdlib log handler to 'unset'. diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 227b0d32d0..3224568640 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase @attr.s -class _User(object): +class _User: "Helper wrapper for user ID and access token" id = attr.ib() token = attr.ib() diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py index 7167fc56b6..8c24add530 100644 --- a/tests/rest/client/third_party_rules.py +++ b/tests/rest/client/third_party_rules.py @@ -19,7 +19,7 @@ from synapse.rest.client.v1 import login, room from tests import unittest -class ThirdPartyRulesTestModule(object): +class ThirdPartyRulesTestModule: def __init__(self, config): pass diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index e66c9a4c4c..afaf9f7b85 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -30,7 +30,7 @@ from tests.server import make_request, render @attr.s -class RestHelper(object): +class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 74765a582b..c00a7b9114 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -32,7 +32,7 @@ from tests.server import FakeTransport @attr.s -class FakeResponse(object): +class FakeResponse: version = attr.ib() code = attr.ib() phrase = attr.ib() @@ -43,7 +43,7 @@ class FakeResponse(object): @property def request(self): @attr.s - class FakeTransport(object): + class FakeTransport: absoluteURI = self.absoluteURI return FakeTransport() @@ -111,7 +111,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups = {} - class Resolver(object): + class Resolver: def resolveHostName( _self, resolutionReceiver, diff --git a/tests/server.py b/tests/server.py index b6e0b14e78..48e45c6c8b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -35,7 +35,7 @@ class TimedOutException(Exception): @attr.s -class FakeChannel(object): +class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). @@ -242,7 +242,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): lookups = self.lookups = {} @implementer(IResolverSimple) - class FakeResolver(object): + class FakeResolver: def getHostByName(self, name, timeout=None): if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) @@ -371,7 +371,7 @@ def get_clock(): @attr.s(cmp=False) -class FakeTransport(object): +class FakeTransport: """ A twisted.internet.interfaces.ITransport implementation which sends all its data straight into an IProtocol object: it exists to connect two IProtocols together. diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f2955a9c69..ad9bbef9d2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -49,7 +49,7 @@ class FakeClock: return defer.succeed(None) -class FakeEvent(object): +class FakeEvent: """A fake event we use as a convenience. NOTE: Again as a convenience we use "node_ids" rather than event_ids to @@ -595,7 +595,7 @@ def pairwise(iterable): @attr.s -class TestStateResolutionStore(object): +class TestStateResolutionStore: event_map = attr.ib() def get_events(self, event_ids, allow_rejected=False): diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 319e2c2325..f5afed017c 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase): class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): - class A(object): + class A: @cached() def func(self, key): return key @@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_hit(self): callcount = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_invalidate(self): callcount = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount[0], 2) def test_invalidate_missing(self): - class A(object): + class A: @cached() def func(self, key): return key @@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_max_entries(self): callcount = [0] - class A(object): + class A: @cached(max_entries=10) def func(self, key): callcount[0] += 1 @@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): d = defer.succeed(123) - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached(max_entries=2) def func(self, key): callcount[0] += 1 @@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 diff --git a/tests/test_state.py b/tests/test_state.py index 56ba0fecf5..2d58467932 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -71,7 +71,7 @@ def create_event( return event -class StateGroupStore(object): +class StateGroupStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -129,7 +129,7 @@ class DictObj(dict): self.__dict__ = self -class Graph(object): +class Graph: def __init__(self, nodes, edges): events = {} clobbered = set(events.keys()) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 4a4483ba12..510b630114 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -294,7 +294,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): test_large_room.skip = "Disabled by default because it's slow" -class _TestStore(object): +class _TestStore: """Implements a few methods of the DataStore, so that we can test filter_events_for_server diff --git a/tests/unittest.py b/tests/unittest.py index 7b80999a74..3cb55a7e96 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -614,7 +614,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - class Authenticator(object): + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 0363735d4f..677e925477 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -88,7 +88,7 @@ class CacheTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -122,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase): def test_cache_num_args(self): """Only the first num_args arguments should matter to the cache""" - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -156,7 +156,7 @@ class DescriptorTestCase(unittest.TestCase): """If the wrapped function throws synchronously, things should continue to work """ - class Cls(object): + class Cls: @cached() def fn(self, arg1): raise SynapseError(100, "mai spoon iz too big!!1") @@ -180,7 +180,7 @@ class DescriptorTestCase(unittest.TestCase): complete_lookup = defer.Deferred() - class Cls(object): + class Cls: @descriptors.cached() def fn(self, arg1): @defer.inlineCallbacks @@ -223,7 +223,7 @@ class DescriptorTestCase(unittest.TestCase): """Check that the cache sets and restores logcontexts correctly when the lookup function throws an exception""" - class Cls(object): + class Cls: @descriptors.cached() def fn(self, arg1): @defer.inlineCallbacks @@ -263,7 +263,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache_default_args(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -300,7 +300,7 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() def test_cache_iterable(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -336,7 +336,7 @@ class DescriptorTestCase(unittest.TestCase): """If the wrapped function throws synchronously, things should continue to work """ - class Cls(object): + class Cls: @descriptors.cached(iterable=True) def fn(self, arg1): raise SynapseError(100, "mai spoon iz too big!!1") @@ -358,7 +358,7 @@ class DescriptorTestCase(unittest.TestCase): class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -408,7 +408,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def test_invalidate(self): """Make sure that invalidation callbacks are called.""" - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 8d6627ec33..2012263184 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -112,7 +112,7 @@ class FileConsumerTests(unittest.TestCase): self.assertTrue(string_file.closed) -class DummyPullProducer(object): +class DummyPullProducer: def __init__(self): self.consumer = None self.deferred = defer.Deferred() @@ -134,7 +134,7 @@ class DummyPullProducer(object): return d -class BlockingStringWrite(object): +class BlockingStringWrite: def __init__(self): self.buffer = "" self.closed = False diff --git a/tests/utils.py b/tests/utils.py index a61cbdef44..4673872f88 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -472,7 +472,7 @@ class MockHttpResource(HttpServer): self.callbacks.append((method, path_pattern, callback)) -class MockKey(object): +class MockKey: alg = "mock_alg" version = "mock_version" signature = b"\x9a\x87$" @@ -491,7 +491,7 @@ class MockKey(object): return b"" -class MockClock(object): +class MockClock: now = 1000 def __init__(self): @@ -568,7 +568,7 @@ def _format_call(args, kwargs): ) -class DeferredMockCallable(object): +class DeferredMockCallable: """A callable instance that stores a set of pending call expectations and return values for them. It allows a unit test to assert that the given set of function calls are eventually made, by awaiting on them to be called. -- cgit 1.5.1 From 2ea1c682490a19e0e2df069702c1dbe419389fa4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Sep 2020 12:22:00 -0400 Subject: Remove some unused distributor signals (#8216) Removes the `user_joined_room` and stops calling it since there are no observers. Also cleans-up some other unused signals and related code. --- changelog.d/8216.misc | 1 + synapse/handlers/events.py | 4 --- synapse/handlers/federation.py | 43 +---------------------------- synapse/handlers/room_member.py | 42 +++------------------------- synapse/handlers/room_member_worker.py | 9 ------ synapse/replication/http/membership.py | 10 +++---- synapse/util/distributor.py | 50 ++++++---------------------------- 7 files changed, 18 insertions(+), 141 deletions(-) create mode 100644 changelog.d/8216.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8216.misc b/changelog.d/8216.misc new file mode 100644 index 0000000000..b38911b0e5 --- /dev/null +++ b/changelog.d/8216.misc @@ -0,0 +1 @@ +Simplify the distributor code to avoid unnecessary work. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b05e32f457..fdce54c5c3 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -39,10 +39,6 @@ class EventStreamHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super(EventStreamHandler, self).__init__(hs) - self.distributor = hs.get_distributor() - self.distributor.declare("started_user_eventstream") - self.distributor.declare("stopped_user_eventstream") - self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 74d7ac8a67..be9b0701a0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -69,7 +69,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnInviteRestServlet, ) -from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -80,7 +79,6 @@ from synapse.types import ( get_domain_from_id, ) from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server @@ -141,9 +139,6 @@ class FederationHandler(BaseHandler): self._replication = hs.get_replication_data_handler() self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) - self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( - hs - ) self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs ) @@ -704,31 +699,10 @@ class FederationHandler(BaseHandler): logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = await self._handle_new_event(origin, event, state=state) + await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. - newly_joined = True - - prev_state_ids = await context.get_prev_state_ids() - - prev_state_id = prev_state_ids.get((event.type, event.state_key)) - if prev_state_id: - prev_state = await self.store.get_event( - prev_state_id, allow_none=True - ) - if prev_state and prev_state.membership == Membership.JOIN: - newly_joined = False - - if newly_joined: - user = UserID.from_string(event.state_key) - await self.user_joined_room(user, room_id) - # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -1550,11 +1524,6 @@ class FederationHandler(BaseHandler): event.signatures, ) - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.JOIN: - user = UserID.from_string(event.state_key) - await self.user_joined_room(user, event.room_id) - prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) @@ -2984,16 +2953,6 @@ class FederationHandler(BaseHandler): else: await self.store.clean_room_for_join(room_id) - async def user_joined_room(self, user: UserID, room_id: str) -> None: - """Called when a new user has joined the room - """ - if self.config.worker_app: - await self._notify_user_membership_change( - room_id=room_id, user_id=user.to_string(), change="joined" - ) - else: - user_joined_room(self.distributor, user, room_id) - async def get_room_complexity( self, remote_room_hosts: List[str], room_id: str ) -> Optional[dict]: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 32b7e323fa..100f335b80 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer -from synapse.util.distributor import user_joined_room, user_left_room +from synapse.util.distributor import user_left_room from ._base import BaseHandler @@ -148,17 +148,6 @@ class RoomMemberHandler: """ raise NotImplementedError() - @abc.abstractmethod - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Notifies distributor on master process that the user has joined the - room. - - Args: - target - room_id - """ - raise NotImplementedError() - @abc.abstractmethod async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the @@ -221,7 +210,6 @@ class RoomMemberHandler: prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - newly_joined = False if event.membership == Membership.JOIN: newly_joined = True if prev_member_event_id: @@ -246,12 +234,7 @@ class RoomMemberHandler: requester, event, context, extra_users=[target], ratelimit=ratelimit, ) - if event.membership == Membership.JOIN and newly_joined: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. - await self._user_joined_room(target, room_id) - elif event.membership == Membership.LEAVE: + if event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: @@ -726,17 +709,7 @@ class RoomMemberHandler: (EventTypes.Member, event.state_key), None ) - if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. - newly_joined = True - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - newly_joined = prev_member_event.membership != Membership.JOIN - if newly_joined: - await self._user_joined_room(target_user, room_id) - elif event.membership == Membership.LEAVE: + if event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: @@ -1002,10 +975,9 @@ class RoomMemberHandler: class RoomMemberMasterHandler(RoomMemberHandler): def __init__(self, hs): - super(RoomMemberMasterHandler, self).__init__(hs) + super().__init__(hs) self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") async def _is_remote_room_too_complex( @@ -1085,7 +1057,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) - await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -1228,11 +1199,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) return event.event_id, stream_id - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_joined_room - """ - user_joined_room(self.distributor, target, room_id) - async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 897338fd54..e7f34737c6 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - await self._user_joined_room(user, room_id) - return ret["event_id"], ret["stream_id"] async def remote_reject_invite( @@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) return ret["event_id"], ret["stream_id"] - async def _user_joined_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_joined_room - """ - await self._notify_change_client( - user_id=target.to_string(), room_id=room_id, change="joined" - ) - async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 741329ab5f..08095fdf7d 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID -from synapse.util.distributor import user_joined_room, user_left_room +from synapse.util.distributor import user_left_room if TYPE_CHECKING: from synapse.server import HomeServer @@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): Args: room_id (str) user_id (str) - change (str): Either "joined" or "left" + change (str): "left" """ - assert change in ("joined", "left") + assert change == "left" return {} @@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): user = UserID.from_string(user_id) - if change == "joined": - user_joined_room(self.distributor, user, room_id) - elif change == "left": + if change == "left": user_left_room(self.distributor, user, room_id) else: raise Exception("Unrecognized change: %r", change) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index a750261e77..f73e95393c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -16,8 +16,6 @@ import inspect import logging from twisted.internet import defer -from twisted.internet.defer import Deferred, fail, succeed -from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id): distributor.fire("user_left_room", user=user, room_id=room_id) -# XXX: this is no longer used. We should probably kill it. -def user_joined_room(distributor, user, room_id): - distributor.fire("user_joined_room", user=user, room_id=room_id) - - class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. @@ -81,28 +74,6 @@ class Distributor: run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -def maybeAwaitableDeferred(f, *args, **kw): - """ - Invoke a function that may or may not return a Deferred or an Awaitable. - - This is a modified version of twisted.internet.defer.maybeDeferred. - """ - try: - result = f(*args, **kw) - except Exception: - return fail(failure.Failure(captureVars=Deferred.debug)) - - if isinstance(result, Deferred): - return result - # Handle the additional case of an awaitable being returned. - elif inspect.isawaitable(result): - return defer.ensureDeferred(result) - elif isinstance(result, failure.Failure): - return fail(result) - else: - return succeed(result) - - class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -132,22 +103,17 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - def do(observer): - def eb(failure): + async def do(observer): + try: + result = observer(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + return result + except Exception as e: logger.warning( - "%s signal observer %s failed: %r", - self.name, - observer, - failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject(), - ), + "%s signal observer %s failed: %r", self.name, observer, e, ) - return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [run_in_background(do, o) for o in self.observers] return make_deferred_yieldable( -- cgit 1.5.1 From 04cc249b43e8716513f788b2a4eeb8ede24d19df Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 14 Sep 2020 10:16:41 +0100 Subject: Add experimental support for sharding event persister. Again. (#8294) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. --- changelog.d/8294.feature | 1 + synapse/config/_base.py | 21 ++++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++---- synapse/handlers/federation.py | 44 ++++++++++----- synapse/handlers/message.py | 14 +++-- synapse/handlers/room.py | 14 +++-- synapse/handlers/room_member.py | 7 --- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 12 ++-- synapse/storage/databases/main/events_worker.py | 66 +++++++++++++++------- .../schema/delta/58/14events_instance_name.sql | 16 ++++++ .../delta/58/14events_instance_name.sql.postgres | 26 +++++++++ synapse/storage/util/id_generators.py | 10 ++-- 18 files changed, 211 insertions(+), 80 deletions(-) create mode 100644 changelog.d/8294.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8294.feature b/changelog.d/8294.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8294.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ad5ab6ad62..bb9bf8598d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c195eba830..a5734bebab 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -896,7 +896,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1189,7 +1190,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1336,15 +1337,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1593,7 +1594,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1620,7 +1621,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1868,7 +1871,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1881,6 +1884,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1912,6 +1916,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1922,6 +1927,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1936,6 +1942,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2010,17 +2017,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2871,6 +2881,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2878,14 +2889,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e54e2b322b..a8fe5cf4e2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler: self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -902,9 +901,10 @@ class EventCreationHandler: try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -972,8 +972,10 @@ class EventCreationHandler: This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer assert self.storage.persistence is not None + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 53d85ab97d..eeade6ad3f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1259,10 +1261,10 @@ class RoomShutdownHandler: # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1292,7 +1294,9 @@ class RoomShutdownHandler: # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 100f335b80..01a6e88262 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler: self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f929fc3954..ccc7ca30d8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 985b12df91..aa5d490624 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -75,7 +75,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9cd1403b38..9a80f419e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -32,7 +32,7 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter @@ -97,18 +97,21 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + self._backfill_id_gen = ( + self.store._backfill_id_gen + ) # type: MultiWriterIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -809,6 +812,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2a66b3ad4e..1de2b91587 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -240,8 +240,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -398,9 +402,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From a3f124b821f0faf53af9e6c890870ec8cbb47ce5 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 16 Sep 2020 21:15:55 +0200 Subject: Switch metaclass initialization to python 3-compatible syntax (#8326) --- changelog.d/8326.misc | 1 + synapse/handlers/room_member.py | 4 +--- synapse/replication/http/_base.py | 4 +--- synapse/storage/databases/main/account_data.py | 8 +++----- synapse/storage/databases/main/push_rule.py | 7 +++---- synapse/storage/databases/main/receipts.py | 8 +++----- synapse/storage/databases/main/stream.py | 4 +--- synapse/types.py | 6 +++--- 8 files changed, 16 insertions(+), 26 deletions(-) create mode 100644 changelog.d/8326.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/8326.misc b/changelog.d/8326.misc new file mode 100644 index 0000000000..985d2c027a --- /dev/null +++ b/changelog.d/8326.misc @@ -0,0 +1 @@ +Update outdated usages of `metaclass` to python 3 syntax. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 01a6e88262..8feba8c90a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -51,14 +51,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RoomMemberHandler: +class RoomMemberHandler(metaclass=abc.ABCMeta): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns # ought to be separated out a lot better. - __metaclass__ = abc.ABCMeta - def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index ba16f22c91..b448da6710 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class ReplicationEndpoint: +class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` @@ -72,8 +72,6 @@ class ReplicationEndpoint: is received. """ - __metaclass__ = abc.ABCMeta - NAME = abc.abstractproperty() # type: str # type: ignore PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore METHOD = "POST" diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 4436b1a83d..5f1a2b9aa6 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -29,15 +29,13 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 9790a31998..b7a8d34ce1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -61,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False): return rules +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, @@ -68,15 +70,12 @@ class PushRulesWorkerStore( RoomMemberWorkerStore, EventsWorkerStore, SQLBaseStore, + metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 4a0d5a320e..6568bddd81 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -31,15 +31,13 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class ReceiptsWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_receipt_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2e95518752..7dbe11513b 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -259,14 +259,12 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: return " AND ".join(clauses), args -class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super(StreamWorkerStore, self).__init__(database, db_conn, hs) diff --git a/synapse/types.py b/synapse/types.py index dc09448bdc..a6fc7df22c 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -165,7 +165,9 @@ def get_localpart_from_id(string): DS = TypeVar("DS", bound="DomainSpecificString") -class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): +class DomainSpecificString( + namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta +): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -175,8 +177,6 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ - __metaclass__ = abc.ABCMeta - SIGIL = abc.abstractproperty() # type: str # type: ignore # Deny iteration because it will bite you if you try to create a singleton -- cgit 1.5.1