From b9391c957572224c3a7c22870102fcbd24dea4e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Feb 2020 18:05:44 +0000 Subject: Add typing to SyncHandler (#6821) Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- tests/storage/test_redaction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index feb1c07cb2..b9ee6ec1ec 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -238,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): built_event = yield self._base_builder.build(prev_event_ids) - built_event.event_id = self._event_id + + built_event._event_id = self._event_id built_event._event_dict["event_id"] = self._event_id + assert built_event.event_id == self._event_id + return built_event @property -- cgit 1.5.1 From 928edef9793bf10fa6156a42c4babbfaaaa17f88 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 31 Jan 2020 16:50:13 +0000 Subject: Pass room_version into `event_from_pdu_json` It's called from all over the shop, so this one's a bit messy. --- changelog.d/6856.misc | 1 + synapse/federation/federation_base.py | 28 ++++++++++++---------- synapse/federation/federation_client.py | 35 +++++++++++++--------------- synapse/federation/federation_server.py | 41 +++++++++++---------------------- tests/handlers/test_federation.py | 6 +++-- 5 files changed, 51 insertions(+), 60 deletions(-) create mode 100644 changelog.d/6856.misc (limited to 'tests') diff --git a/changelog.d/6856.misc b/changelog.d/6856.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6856.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 0e22183280..ebe8b8e9fe 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,9 +23,13 @@ from twisted.internet.defer import DeferredList from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersion, +) from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import event_type_from_format_version +from synapse.events import EventBase, event_type_from_format_version from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -33,7 +38,7 @@ from synapse.logging.context import ( make_deferred_yieldable, preserve_fn, ) -from synapse.types import get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -342,16 +347,15 @@ def _is_invite_via_3pid(event): ) -def event_from_pdu_json(pdu_json, event_format_version, outlier=False): - """Construct a FrozenEvent from an event json received over federation +def event_from_pdu_json( + pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False +) -> EventBase: + """Construct an EventBase from an event json received over federation Args: - pdu_json (object): pdu as received over federation - event_format_version (int): The event format version - outlier (bool): True to mark this event as an outlier - - Returns: - FrozenEvent + pdu_json: pdu as received over federation + room_version: The version of the room this event belongs to + outlier: True to mark this event as an outlier Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -370,7 +374,7 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)(pdu_json) + event = event_type_from_format_version(room_version.event_format)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5fb4bd414c..4870e39652 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -49,7 +49,7 @@ from synapse.api.room_versions import ( RoomVersion, RoomVersions, ) -from synapse.events import EventBase, builder, room_version_to_event_format +from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function @@ -209,18 +209,18 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) pdus = [ - event_from_pdu_json(p, format_ver, outlier=False) + event_from_pdu_json(p, room_version, outlier=False) for p in transaction_data["pdus"] ] # FIXME: We should handle signature failures more gracefully. 
pdus[:] = await make_deferred_yieldable( defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + self._check_sigs_and_hashes(room_version.identifier, pdus), + consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -262,8 +262,6 @@ class FederationClient(FederationBase): pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) - format_ver = room_version.event_format - signed_pdu = None for destination in destinations: now = self._clock.time_msec() @@ -284,7 +282,7 @@ class FederationClient(FederationBase): ) pdu_list = [ - event_from_pdu_json(p, format_ver, outlier=outlier) + event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] ] @@ -350,15 +348,15 @@ class FederationClient(FederationBase): async def get_event_auth(self, destination, room_id, event_id): res = await self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] + event_from_pdu_json(p, room_version, outlier=True) + for p in res["auth_chain"] ] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version + destination, auth_chain, outlier=True, room_version=room_version.identifier ) signed_auth.sort(key=lambda e: e.depth) @@ -547,12 +545,12 @@ class FederationClient(FederationBase): logger.debug("Got content: %s", content) state = [ - event_from_pdu_json(p, room_version.event_format, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("state", []) ] auth_chain = [ - event_from_pdu_json(p, room_version.event_format, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("auth_chain", []) ] @@ -677,7 +675,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - pdu = event_from_pdu_json(pdu_dict, room_version.event_format) + pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. 
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) @@ -865,15 +863,14 @@ class FederationClient(FederationBase): timeout=timeout, ) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) events = [ - event_from_pdu_json(e, format_ver) for e in content.get("events", []) + event_from_pdu_json(e, room_version) for e in content.get("events", []) ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version + destination, events, outlier=False, room_version=room_version.identifier ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e3933b6c5..2489832a11 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -38,7 +38,6 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction @@ -234,24 +233,17 @@ class FederationServer(FederationBase): continue try: - room_version = await self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue - - try: - format_ver = room_version_to_event_format(room_version) - except UnsupportedRoomVersionError: + except UnsupportedRoomVersionError as e: # this can happen if support for a given room version is withdrawn, # so that we still get events for said room. 
- logger.info( - "Ignoring PDU for room %s with unknown version %s", - room_id, - room_version, - ) + logger.info("Ignoring PDU: %s", e) continue - event = event_from_pdu_json(p, format_ver) + event = event_from_pdu_json(p, room_version) pdus_by_room.setdefault(room_id, []).append(event) pdu_results = {} @@ -407,9 +399,7 @@ class FederationServer(FederationBase): Codes.UNSUPPORTED_ROOM_VERSION, ) - format_ver = room_version.event_format - - pdu = event_from_pdu_json(content, format_ver) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) @@ -420,16 +410,15 @@ class FederationServer(FederationBase): async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -451,16 +440,15 @@ class FederationServer(FederationBase): async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) await self.handler.on_send_leave_request(origin, pdu) return {} @@ -498,15 +486,14 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) auth_chain = [ - event_from_pdu_json(e, format_ver) for e in content["auth_chain"] + event_from_pdu_json(e, room_version) for e in content["auth_chain"] ] signed_auth = await self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version + origin, auth_chain, outlier=True, room_version=room_version.identifier ) ret = await self.handler.on_query_auth( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index b4d92cf732..132e35651d 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -99,6 +99,7 @@ class FederationTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") room_id = 
self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) # pretend that another server has joined join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) @@ -120,7 +121,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "auth_events": [], "origin_server_ts": self.clock.time_msec(), }, - join_event.format_version, + room_version, ) with LoggingContext(request="send_rejected"): @@ -149,6 +150,7 @@ class FederationTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) # pretend that another server has joined join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) @@ -171,7 +173,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "auth_events": [], "origin_server_ts": self.clock.time_msec(), }, - join_event.format_version, + room_version, ) with LoggingContext(request="send_rejected"): -- cgit 1.5.1 From 56ca93ef5941b5dfcda368f373a6bcd80d177acd Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 7 Feb 2020 11:29:36 +0100 Subject: Admin api to add an email address (#6789) --- changelog.d/6769.feature | 1 + docs/admin_api/user_admin_api.rst | 11 +++++++++++ synapse/handlers/admin.py | 2 ++ synapse/handlers/auth.py | 8 ++++++++ synapse/rest/admin/users.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_user.py | 19 +++++++++++++++++-- 6 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6769.feature (limited to 'tests') diff --git a/changelog.d/6769.feature b/changelog.d/6769.feature new file mode 100644 index 0000000000..8a60e12907 --- /dev/null +++ b/changelog.d/6769.feature @@ -0,0 +1 @@ +Admin API to add or modify threepids of user accounts. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 0b3d09d694..eb146095de 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -15,6 +15,16 @@ with a body of: { "password": "user_password", "displayname": "User", + "threepids": [ + { + "medium": "email", + "address": "" + }, + { + "medium": "email", + "address": "" + } + ], "avatar_url": "", "admin": false, "deactivated": false @@ -23,6 +33,7 @@ with a body of: including an ``access_token`` of a server admin. The parameter ``displayname`` is optional and defaults to ``user_id``. +The parameter ``threepids`` is optional. The parameter ``avatar_url`` is optional. The parameter ``admin`` is optional and defaults to 'false'. The parameter ``deactivated`` is optional and defaults to 'false'. 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 9205865231..f3c0aeceb6 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -58,8 +58,10 @@ class AdminHandler(BaseHandler): ret = await self.store.get_user_by_id(user.to_string()) if ret: profile = await self.store.get_profileinfo(user.localpart) + threepids = await self.store.user_get_threepids(user.to_string()) ret["displayname"] = profile.display_name ret["avatar_url"] = profile.avatar_url + ret["threepids"] = threepids return ret async def export_user_data(self, user_id, writer): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 54a71c49d2..48a88d3c2a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -816,6 +816,14 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): + # check if medium has a valid value + if medium not in ["email", "msisdn"]: + raise SynapseError( + code=400, + msg=("'%s' is not a valid value for 'medium'" % (medium,)), + errcode=Codes.INVALID_PARAM, + ) + # 'Canonicalise' email addresses down to lower case. # We've now moving towards the homeserver being the entity that # is responsible for validating threepids used for resetting passwords diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f1c4434f5c..e75c5f1370 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -136,6 +136,8 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_handlers().admin_handler + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler() @@ -163,6 +165,7 @@ class UserRestServletV2(RestServlet): raise SynapseError(400, "This endpoint can only be used with local users") user = await self.admin_handler.get_user(target_user) + user_id = target_user.to_string() if user: # modify user if "displayname" in body: @@ -170,6 +173,29 @@ class UserRestServletV2(RestServlet): target_user, requester, body["displayname"], True ) + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + # remove old threepids from user + threepids = await self.store.user_get_threepids(user_id) + for threepid in threepids: + try: + await self.auth_handler.delete_threepid( + user_id, threepid["medium"], threepid["address"], None + ) + except Exception: + logger.exception("Failed to remove threepids") + raise SynapseError(500, "Failed to remove threepids") + + # add new threepids to user + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + if "avatar_url" in body: await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True @@ -221,6 +247,7 @@ class UserRestServletV2(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) + threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") @@ -232,6 +259,18 @@ class UserRestServletV2(RestServlet): default_display_name=displayname, 
user_type=user_type, ) + + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + if "avatar_url" in body: await self.profile_handler.set_avatar_url( user_id, requester, body["avatar_url"], True diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 8f09f51c61..3b5169b38d 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -407,7 +407,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - body = json.dumps({"password": "abc123", "admin": True}) + body = json.dumps( + { + "password": "abc123", + "admin": True, + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) # Create user request, channel = self.make_request( @@ -421,6 +427,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) # Get user request, channel = self.make_request( @@ -449,7 +457,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Modify user - body = json.dumps({"displayname": "foobar", "deactivated": True}) + body = json.dumps( + { + "displayname": "foobar", + "deactivated": True, + "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], + } + ) request, channel = self.make_request( "PUT", @@ -463,6 +477,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) self.assertEqual(True, channel.json_body["deactivated"]) + # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( -- cgit 1.5.1 From b08b0a22d505b1555f511e3f38935a62930ea25d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Feb 2020 13:56:38 +0000 Subject: Add typing to synapse.federation.sender (#6871) --- changelog.d/6871.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/federation/sender/__init__.py | 99 +++++++++++----------- synapse/federation/sender/per_destination_queue.py | 88 +++++++++---------- synapse/federation/sender/transaction_manager.py | 16 ++-- synapse/federation/units.py | 23 ++++- synapse/server.pyi | 2 + tests/handlers/test_typing.py | 8 +- tox.ini | 1 + 9 files changed, 138 insertions(+), 107 deletions(-) create mode 100644 changelog.d/6871.misc (limited to 'tests') diff --git a/changelog.d/6871.misc b/changelog.d/6871.misc new file mode 100644 index 0000000000..5161af9983 --- /dev/null +++ b/changelog.d/6871.misc @@ -0,0 +1 @@ +Add typing to `synapse.federation.sender` and port to async/await. 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2489832a11..a6c966a393 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -294,7 +294,12 @@ class FederationServer(FederationBase): async def _process_edu(edu_dict): received_edus_counter.inc() - edu = Edu(**edu_dict) + edu = Edu( + origin=origin, + destination=self.server_name, + edu_type=edu_dict["edu_type"], + content=edu_dict["content"], + ) await self.registry.on_edu(edu.edu_type, origin, edu.content) await concurrently_execute( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 36c83c3027..233cb33daf 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Hashable, Iterable, List, Optional, Set from six import itervalues @@ -23,6 +24,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu @@ -39,6 +41,8 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.presence import UserPresenceState +from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -68,7 +72,7 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -84,7 +88,7 @@ class FederationSender(object): # Map of user_id -> UserPresenceState for all the pending presence # to be sent out by user_id. Entries here get processed and put in # pending_presence_by_dest - self.pending_presence = {} + self.pending_presence = {} # type: Dict[str, UserPresenceState] LaterGauge( "synapse_federation_transaction_queue_pending_pdus", @@ -116,20 +120,17 @@ class FederationSender(object): # and that there is a pending call to _flush_rrs_for_room in the system. self._queues_awaiting_rr_flush_by_room = ( {} - ) # type: dict[str, set[PerDestinationQueue]] + ) # type: Dict[str, Set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( - 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second + 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) - def _get_per_destination_queue(self, destination): + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination Args: - destination (str): server_name of remote server - - Returns: - PerDestinationQueue + destination: server_name of remote server """ queue = self._per_destination_queues.get(destination) if not queue: @@ -137,7 +138,7 @@ class FederationSender(object): self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id): + def notify_new_events(self, current_id: int) -> None: """This gets called when we have some new events we might want to send out to other servers. 
""" @@ -151,13 +152,12 @@ class FederationSender(object): "process_event_queue_for_federation", self._process_event_queue_loop ) - @defer.inlineCallbacks - def _process_event_queue_loop(self): + async def _process_event_queue_loop(self) -> None: try: self._is_processing = True while True: - last_token = yield self.store.get_federation_out_pos("events") - next_token, events = yield self.store.get_all_new_events_stream( + last_token = await self.store.get_federation_out_pos("events") + next_token, events = await self.store.get_all_new_events_stream( last_token, self._last_poked_id, limit=100 ) @@ -166,8 +166,7 @@ class FederationSender(object): if not events and next_token >= self._last_poked_id: break - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) @@ -184,7 +183,7 @@ class FederationSender(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - destinations = yield self.state.get_hosts_in_room_at_events( + destinations = await self.state.get_hosts_in_room_at_events( event.room_id, event_ids=event.prev_event_ids() ) except Exception: @@ -206,17 +205,16 @@ class FederationSender(object): self._send_pdu(event, destinations) - @defer.inlineCallbacks - def handle_room_events(events): + async def handle_room_events(events: Iterable[EventBase]) -> None: with Measure(self.clock, "handle_room_events"): for event in events: - yield handle_event(event) + await handle_event(event) - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) @@ -226,11 +224,11 @@ class FederationSender(object): ) ) - yield self.store.update_federation_out_pos("events", next_token) + await self.store.update_federation_out_pos("events", next_token) if events: now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( "federation_sender" @@ -254,7 +252,7 @@ class FederationSender(object): finally: self._is_processing = False - def _send_pdu(self, pdu, destinations): + def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -276,11 +274,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_pdu(pdu, order) @defer.inlineCallbacks - def send_read_receipt(self, receipt): + def send_read_receipt(self, receipt: ReadReceipt): """Send a RR to any other servers in the room Args: - receipt (synapse.types.ReadReceipt): receipt to be sent + receipt: receipt to be sent """ # Some background on the rate-limiting going on here. 
@@ -343,7 +341,7 @@ class FederationSender(object): else: queue.flush_read_receipts_for_room(room_id) - def _schedule_rr_flush_for_room(self, room_id, n_domains): + def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None: # that is going to cause approximately len(domains) transactions, so now back # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM backoff_ms = self._rr_txn_interval_per_room_ms * n_domains @@ -352,7 +350,7 @@ class FederationSender(object): self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) self._queues_awaiting_rr_flush_by_room[room_id] = set() - def _flush_rrs_for_room(self, room_id): + def _flush_rrs_for_room(self, room_id: str) -> None: queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) logger.debug("Flushing RRs in %s to %s", room_id, queues) @@ -368,14 +366,11 @@ class FederationSender(object): @preserve_fn # the caller should not yield on this @defer.inlineCallbacks - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and triggers a background task to process them and send out the transactions. - - Args: - states (list(UserPresenceState)) """ if not self.hs.config.use_presence: # No-op if presence is disabled. @@ -412,11 +407,10 @@ class FederationSender(object): finally: self._processing_pending_presence = False - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: List[UserPresenceState], destinations: List[str] + ) -> None: """Send the given presence states to the given destinations. - - Args: - states (list[UserPresenceState]) destinations (list[str]) """ @@ -431,12 +425,9 @@ class FederationSender(object): @measure_func("txnqueue._process_presence") @defer.inlineCallbacks - def _process_presence_inner(self, states): + def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination - - Args: - states (list(UserPresenceState)) """ hosts_and_states = yield get_interested_remotes(self.store, states, self.state) @@ -446,14 +437,20 @@ class FederationSender(object): continue self._get_per_destination_queue(destination).send_presence(states) - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: dict, + key: Optional[Hashable] = None, + ): """Construct an Edu object, and queue it for sending Args: - destination (str): name of server to send to - edu_type (str): type of EDU to send - content (dict): content of EDU - key (Any|None): clobbering key for this edu + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu """ if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -468,12 +465,12 @@ class FederationSender(object): self.send_edu(edu, key) - def send_edu(self, edu, key): + def send_edu(self, edu: Edu, key: Optional[Hashable]): """Queue an EDU for sending Args: - edu (Edu): edu to send - key (Any|None): clobbering key for this edu + edu: edu to send + key: clobbering key for this edu """ queue = self._get_per_destination_queue(edu.destination) if key: @@ -481,7 +478,7 @@ class FederationSender(object): else: queue.send_edu(edu) - def send_device_messages(self, 
destination): + def send_device_messages(self, destination: str): if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -501,5 +498,5 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self): + def get_current_token(self) -> int: return 0 diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5012aaea35..e13cd20ffa 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,11 @@ # limitations under the License. import datetime import logging +from typing import Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -from twisted.internet import defer - +import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import StateMap +from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -56,13 +56,18 @@ class PerDestinationQueue(object): Manages the per-destination transmission queues. Args: - hs (synapse.HomeServer): - transaction_sender (TransactionManager): - destination (str): the server_name of the destination that we are managing + hs + transaction_sender + destination: the server_name of the destination that we are managing transmission for. """ - def __init__(self, hs, transaction_manager, destination): + def __init__( + self, + hs: "synapse.server.HomeServer", + transaction_manager: "synapse.federation.sender.TransactionManager", + destination: str, + ): self._server_name = hs.hostname self._clock = hs.get_clock() self._store = hs.get_datastore() @@ -72,20 +77,20 @@ class PerDestinationQueue(object): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: List[Tuple[EventBase, int]] + self._pending_edus = [] # type: List[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: StateMap[Edu] + self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: Dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} + self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -95,50 +100,50 @@ class PerDestinationQueue(object): # stream_id of last successfully sent device list update. 
self._last_device_list_stream_id = 0 - def __str__(self): + def __str__(self) -> str: return "PerDestinationQueue[%s]" % self._destination - def pending_pdu_count(self): + def pending_pdu_count(self) -> int: return len(self._pending_pdus) - def pending_edu_count(self): + def pending_edu_count(self) -> int: return ( len(self._pending_edus) + len(self._pending_presence) + len(self._pending_edus_keyed) ) - def send_pdu(self, pdu, order): + def send_pdu(self, pdu: EventBase, order: int) -> None: """Add a PDU to the queue, and start the transmission loop if neccessary Args: - pdu (EventBase): pdu to send - order (int): + pdu: pdu to send + order """ self._pending_pdus.append((pdu, order)) self.attempt_new_transaction() - def send_presence(self, states): + def send_presence(self, states: Iterable[UserPresenceState]) -> None: """Add presence updates to the queue. Start the transmission loop if neccessary. Args: - states (iterable[UserPresenceState]): presence to send + states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() - def queue_read_receipt(self, receipt): + def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet (see flush_read_receipts_for_room) Args: - receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued + receipt: receipt to be queued """ self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} - def flush_read_receipts_for_room(self, room_id): + def flush_read_receipts_for_room(self, room_id: str) -> None: # if we don't have any read-receipts for this room, it may be that we've already # sent them out, so we don't need to flush. if room_id not in self._pending_rrs: @@ -146,15 +151,15 @@ class PerDestinationQueue(object): self._rrs_pending_flush = True self.attempt_new_transaction() - def send_keyed_edu(self, edu, key): + def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu): + def send_edu(self, edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() - def attempt_new_transaction(self): + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination If there is already a transaction in progress to this destination, @@ -177,23 +182,22 @@ class PerDestinationQueue(object): self._transaction_transmission_loop, ) - @defer.inlineCallbacks - def _transaction_transmission_loop(self): - pending_pdus = [] + async def _transaction_transmission_loop(self) -> None: + pending_pdus = [] # type: List[Tuple[EventBase, int]] try: self.transmission_loop_running = True # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. 
- yield get_retry_limiter(self._destination, self._clock, self._store) + await get_retry_limiter(self._destination, self._clock, self._store) pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = yield self._get_device_update_edus( + device_update_edus, dev_list_id = await self._get_device_update_edus( limit ) @@ -202,7 +206,7 @@ class PerDestinationQueue(object): ( to_device_edus, device_stream_id, - ) = yield self._get_to_device_message_edus(limit) + ) = await self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -269,7 +273,7 @@ class PerDestinationQueue(object): # END CRITICAL SECTION - success = yield self._transaction_manager.send_new_transaction( + success = await self._transaction_manager.send_new_transaction( self._destination, pending_pdus, pending_edus ) if success: @@ -280,7 +284,7 @@ class PerDestinationQueue(object): # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if to_device_edus: - yield self._store.delete_device_msgs_for_remote( + await self._store.delete_device_msgs_for_remote( self._destination, device_stream_id ) @@ -289,7 +293,7 @@ class PerDestinationQueue(object): logger.info( "Marking as sent %r %r", self._destination, dev_list_id ) - yield self._store.mark_as_sent_devices_by_remote( + await self._store.mark_as_sent_devices_by_remote( self._destination, dev_list_id ) @@ -334,7 +338,7 @@ class PerDestinationQueue(object): # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False - def _get_rr_edus(self, force_flush): + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return if not force_flush and not self._rrs_pending_flush: @@ -351,17 +355,16 @@ class PerDestinationQueue(object): self._rrs_pending_flush = False yield edu - def _pop_pending_edus(self, limit): + def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:] return pending_edus - @defer.inlineCallbacks - def _get_device_update_edus(self, limit): + async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_device_updates_by_remote( + now_stream_id, results = await self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ @@ -378,11 +381,10 @@ class PerDestinationQueue(object): return (edus, now_stream_id) - @defer.inlineCallbacks - def _get_to_device_message_edus(self, limit): + async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + contents, stream_id = await self._store.get_new_device_msgs_for_remote( self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5fed626d5b..3c2a02a3b3 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,14 +13,15 @@ # See the 
License for the specific language governing permissions and # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - +import synapse.server from synapse.api.errors import HttpResponseException +from synapse.events import EventBase from synapse.federation.persistence import TransactionActions -from synapse.federation.units import Transaction +from synapse.federation.units import Edu, Transaction from synapse.logging.opentracing import ( extract_text_map, set_tag, @@ -39,7 +40,7 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() @@ -50,8 +51,9 @@ class TransactionManager(object): self._next_txn_id = int(self.clock.time_msec()) @measure_func("_send_new_transaction") - @defer.inlineCallbacks - def send_new_transaction(self, destination, pending_pdus, pending_edus): + async def send_new_transaction( + self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + ): # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is @@ -127,7 +129,7 @@ class TransactionManager(object): return data try: - response = yield self._transport_layer.send_transaction( + response = await self._transport_layer.send_transaction( transaction, json_data_cb ) code = 200 diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b4d743cde7..6b32e0dcbf 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -19,11 +19,15 @@ server protocol. import logging +import attr + +from synapse.types import JsonDict from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) +@attr.s(slots=True) class Edu(JsonEncodedObject): """ An Edu represents a piece of data sent from one homeserver to another. @@ -32,11 +36,24 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. 
""" - valid_keys = ["origin", "destination", "edu_type", "content"] + edu_type = attr.ib(type=str) + content = attr.ib(type=dict) + origin = attr.ib(type=str) + destination = attr.ib(type=str) - required_keys = ["edu_type"] + def get_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + } - internal_keys = ["origin", "destination"] + def get_internal_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + "origin": self.origin, + "destination": self.destination, + } def get_context(self): return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") diff --git a/synapse/server.pyi b/synapse/server.pyi index 90347ac23e..40eabfe5d9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -107,3 +107,5 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def is_mine_id(self, domain_id: str) -> bool: + pass diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 68b9847bd2..2767b0497a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -111,7 +111,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + (0, []) + ) def get_received_txn_response(*args): return defer.succeed(None) @@ -144,7 +146,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + ([], 0) + ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( None diff --git a/tox.ini b/tox.ini index ef22368cf1..f8229eba88 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ extras = all commands = mypy \ synapse/api \ synapse/config/ \ + synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.5.1 From 799001f2c0b31d72b95a252a3808da25987e1ed3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 7 Feb 2020 15:30:04 +0000 Subject: Add a `make_event_from_dict` method (#6858) ... and use it in places where it's trivial to do so. This will make it easier to pass room versions into the FrozenEvent constructors. 
--- changelog.d/6858.misc | 1 + synapse/events/__init__.py | 16 ++++++++++++++-- synapse/events/builder.py | 10 +++------- synapse/federation/federation_base.py | 5 ++--- tests/api/test_filtering.py | 4 ++-- tests/crypto/test_event_signing.py | 6 +++--- tests/events/test_utils.py | 9 +++++---- tests/federation/test_federation_server.py | 4 ++-- tests/replication/slave/storage/test_events.py | 12 ++++++++---- tests/state/test_v2.py | 4 ++-- tests/test_event_auth.py | 10 +++++----- tests/test_federation.py | 6 +++--- tests/test_state.py | 4 ++-- 13 files changed, 52 insertions(+), 39 deletions(-) create mode 100644 changelog.d/6858.misc (limited to 'tests') diff --git a/changelog.d/6858.misc b/changelog.d/6858.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6858.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 89d41d82b6..a842661a90 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -16,12 +16,13 @@ import os from distutils.util import strtobool +from typing import Optional, Type import six from unpaddedbase64 import encode_base64 -from synapse.api.room_versions import EventFormatVersions +from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -407,7 +408,7 @@ class FrozenEventV3(FrozenEventV2): return self._event_id -def event_type_from_format_version(format_version): +def event_type_from_format_version(format_version: int) -> Type[EventBase]: """Returns the python type to use to construct an Event object for the given event format version. 
@@ -427,3 +428,14 @@ def event_type_from_format_version(format_version): return FrozenEventV3 else: raise Exception("No event format %r" % (format_version,)) + + +def make_event_from_dict( + event_dict: JsonDict, + room_version: RoomVersion = RoomVersions.V1, + internal_metadata_dict: JsonDict = {}, + rejected_reason: Optional[str] = None, +) -> EventBase: + """Construct an EventBase from the given event dict""" + event_type = event_type_from_format_version(room_version.event_format) + return event_type(event_dict, internal_metadata_dict, rejected_reason) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 8d63ad6dc3..a0c4a40c27 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,11 +28,7 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import ( - EventBase, - _EventInternalMetadata, - event_type_from_format_version, -) +from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -256,8 +252,8 @@ def create_local_event_from_event_dict( event_dict.setdefault("signatures", {}) add_hashes_and_signatures(room_version, event_dict, hostname, signing_key) - return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict + return make_event_from_dict( + event_dict, room_version, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index ebe8b8e9fe..eea64c1c9f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -29,7 +29,7 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import EventBase, event_type_from_format_version +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -374,8 +374,7 @@ def event_from_pdu_json( elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(room_version.event_format)(pdu_json) - + event = make_event_from_dict(pdu_json, room_version) event.internal_metadata.outlier = outlier return event diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 63d8633582..4e67503cf0 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver @@ -38,7 +38,7 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class FilteringTestCase(unittest.TestCase): diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 6143a50ab2..62f639a18d 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -19,7 
+19,7 @@ from unpaddedbase64 import decode_base64 from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest @@ -54,7 +54,7 @@ class EventSigningTestCase(unittest.TestCase): RoomVersions.V1, event_dict, HOSTNAME, self.signing_key ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) @@ -88,7 +88,7 @@ class EventSigningTestCase(unittest.TestCase): RoomVersions.V1, event_dict, HOSTNAME, self.signing_key ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 2b13980dfd..45d55b9e94 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.events.utils import ( copy_power_levels_contents, prune_event, @@ -30,7 +29,7 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class PruneEventTestCase(unittest.TestCase): @@ -38,7 +37,9 @@ class PruneEventTestCase(unittest.TestCase): `matchdict` when it is redacted. """ def run_test(self, evdict, matchdict): - self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict) + self.assertEquals( + prune_event(make_event_from_dict(evdict)).get_dict(), matchdict + ) def test_minimal(self): self.run_test( diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 1ec8c40901..e7d8699040 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -15,7 +15,7 @@ # limitations under the License. 
import logging -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -105,7 +105,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): def _create_acl_event(content): - return FrozenEvent( + return make_event_from_dict( { "room_id": "!a:b", "event_id": "$a:b", diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index b1b037006d..d31210fbe4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,7 +15,7 @@ import logging from canonicaljson import encode_canonical_json -from synapse.events import FrozenEvent, _EventInternalMetadata +from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore @@ -90,7 +90,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_backfilled_redactions(self): @@ -110,7 +112,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_invites(self): @@ -345,7 +349,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if redacts is not None: event_dict["redacts"] = redacts - event = FrozenEvent(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal) self.event_id += 1 diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 0f341d3ac3..5bafad9f19 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -22,7 +22,7 @@ import attr from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.types import EventID @@ -89,7 +89,7 @@ class FakeEvent(object): if self.state_key is not None: event_dict["state_key"] = self.state_key - return FrozenEvent(event_dict) + return make_event_from_dict(event_dict) # All graphs start with this set of events diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index ca20b085a2..bfa5d6f510 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -18,7 +18,7 @@ import unittest from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import 
make_event_from_dict class EventAuthTestCase(unittest.TestCase): @@ -94,7 +94,7 @@ TEST_ROOM_ID = "!test:room" def _create_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -106,7 +106,7 @@ def _create_event(user_id): def _join_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -119,7 +119,7 @@ def _join_event(user_id): def _power_levels_event(sender, content): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -132,7 +132,7 @@ def _power_levels_event(sender, content): def _random_state_event(sender): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), diff --git a/tests/test_federation.py b/tests/test_federation.py index 68684460c6..9b5cf562f3 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -2,7 +2,7 @@ from mock import Mock from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID from synapse.util import Clock @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.TestCase): ) )[0] - join_event = FrozenEvent( + join_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", @@ -105,7 +105,7 @@ class MessageAcceptTests(unittest.TestCase): )[0] # Now lie about an event - lying_event = FrozenEvent( + lying_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", diff --git a/tests/test_state.py b/tests/test_state.py index 1e4449fa1c..d1578fe581 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler @@ -66,7 +66,7 @@ def create_event( d.update(kwargs) - event = FrozenEvent(d) + event = make_event_from_dict(d) return event -- cgit 1.5.1 From a92e703ab9d78aecc062e797f941bb7e206650a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 10 Feb 2020 16:35:26 -0500 Subject: Reject device display names that are too long (#6882) * Reject device display names that are too long. Too long is currently defined as 100 characters in length. * Add a regression test for rejecting a too long device display name. --- changelog.d/6882.misc | 1 + synapse/handlers/device.py | 14 +++++++++++++- tests/handlers/test_device.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6882.misc (limited to 'tests') diff --git a/changelog.d/6882.misc b/changelog.d/6882.misc new file mode 100644 index 0000000000..e8382e36ae --- /dev/null +++ b/changelog.d/6882.misc @@ -0,0 +1 @@ +Reject device display names over 100 characters in length. 
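The check added to synapse/handlers/device.py below is a plain length guard on the optional display_name field of the update payload. The following standalone sketch restates it outside of Synapse: the helper name is invented for illustration, and a plain ValueError stands in for SynapseError(400, ...), but the constant and message mirror the patch.

```python
# Illustrative sketch only; not part of the patch.
MAX_DEVICE_DISPLAY_NAME_LEN = 100


def validate_device_display_name(content: dict) -> None:
    """Reject an update payload whose display_name exceeds the maximum length."""
    new_display_name = content.get("display_name")
    if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
        raise ValueError(
            "Device display name is too long (max %i)" % (MAX_DEVICE_DISPLAY_NAME_LEN,)
        )


validate_device_display_name({"display_name": "a" * 100})  # exactly at the limit: accepted
try:
    validate_device_display_name({"display_name": "a" * 101})  # over the limit: rejected
except ValueError as exc:
    print(exc)
```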
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6d8e48ed39..50cea3f378 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -26,6 +26,7 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, RequestSendFailed, + SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.types import RoomStreamToken, get_domain_from_id @@ -39,6 +40,8 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +MAX_DEVICE_DISPLAY_NAME_LEN = 100 + class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): @@ -404,9 +407,18 @@ class DeviceHandler(DeviceWorkerHandler): defer.Deferred: """ + # Reject a new displayname which is too long. + new_display_name = content.get("display_name") + if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + ) + try: yield self.store.update_device( - user_id, device_id, new_display_name=content.get("display_name") + user_id, device_id, new_display_name=new_display_name ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a3aa0a1cf2..62b47f6574 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") + def test_update_device_too_long_display_name(self): + """Update a device with a display name that is invalid (too long).""" + self._record_users() + + # Request to update a device display name with a new value that is longer than allowed. + update = { + "display_name": "a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) + } + self.get_failure( + self.handler.update_device(user1, "abc", update), + synapse.api.errors.SynapseError, + ) + + # Ensure the display name was not updated. + res = self.get_success(self.handler.get_device(user1, "abc")) + self.assertEqual(res["display_name"], "display 2") + def test_update_unknown_device(self): update = {"display_name": "new_display"} res = self.handler.update_device("user_id", "unknown_device_id", update) -- cgit 1.5.1 From d8994942f28f5028e560f6aba52512fae3ca1a6a Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 12 Feb 2020 18:14:10 +0000 Subject: Return a 404 for admin api user lookup if user not found (#6901) --- changelog.d/6901.misc | 1 + synapse/rest/admin/users.py | 5 ++++- tests/rest/admin/test_user.py | 16 ++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6901.misc (limited to 'tests') diff --git a/changelog.d/6901.misc b/changelog.d/6901.misc new file mode 100644 index 0000000000..b2f12bbe86 --- /dev/null +++ b/changelog.d/6901.misc @@ -0,0 +1 @@ +Return a 404 instead of 200 for querying information of a non-existant user through the admin API. 
\ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e75c5f1370..2107b5dc56 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -21,7 +21,7 @@ from six import text_type from six.moves import http_client from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -152,6 +152,9 @@ class UserRestServletV2(RestServlet): ret = await self.admin_handler.get_user(target_user) + if not ret: + raise NotFoundError("User not found") + return 200, ret async def on_PUT(self, request, user_id): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 3b5169b38d..490ce8f55d 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -401,6 +401,22 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + self.hs.config.registration_shared_secret = None + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v2/users/@unknown_person:test", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + def test_requester_is_admin(self): """ If the user is a server admin, a new user is created. -- cgit 1.5.1 From 49f877d32efc79cb40b2766cb052cf35bad31de5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Feb 2020 07:17:54 -0500 Subject: Filter the results of user directory searching via the spam checker (#6888) Add a method to the spam checker to filter the user directory results. --- changelog.d/6888.feature | 1 + docs/spam_checker.md | 3 ++ synapse/events/spamcheck.py | 27 ++++++++++ synapse/handlers/user_directory.py | 14 +++++- tests/handlers/test_user_directory.py | 92 +++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6888.feature (limited to 'tests') diff --git a/changelog.d/6888.feature b/changelog.d/6888.feature new file mode 100644 index 0000000000..1b7ac0c823 --- /dev/null +++ b/changelog.d/6888.feature @@ -0,0 +1 @@ +The result of a user directory search can now be filtered via the spam checker. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 97ff17f952..5b5f5000b7 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -54,6 +54,9 @@ class ExampleSpamChecker: def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms + + def check_username_for_spam(self, user_profile): + return False # allow all usernames ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 5a907718d6..0a13fca9a4 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,6 +15,7 @@ # limitations under the License. 
import inspect +from typing import Dict from synapse.spam_checker_api import SpamCheckerApi @@ -125,3 +126,29 @@ class SpamChecker(object): return True return self.spam_checker.user_may_publish_room(userid, room_id) + + def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + """Checks if a user ID or display name are considered "spammy" by this server. + + If the server considers a username spammy, then it will not be included in + user directory results. + + Args: + user_profile: The user information to check, it contains the keys: + * user_id + * display_name + * avatar_url + + Returns: + True if the user is spammy. + """ + if self.spam_checker is None: + return False + + # For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering. + checker = getattr(self.spam_checker, "check_username_for_spam", None) + if not checker: + return False + # Make a copy of the user profile object to ensure the spam checker + # cannot modify it. + return checker(user_profile.copy()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 81aa58dc8c..722760c59d 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -52,6 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory self.search_all_users = hs.config.user_directory_search_all_users + self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream self.pos = None @@ -65,7 +66,7 @@ class UserDirectoryHandler(StateDeltasHandler): # we start populating the user directory self.clock.call_later(0, self.notify_new_event) - def search_users(self, user_id, search_term, limit): + async def search_users(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -82,7 +83,16 @@ class UserDirectoryHandler(StateDeltasHandler): ] } """ - return self.store.search_user_dir(user_id, search_term, limit) + results = await self.store.search_user_dir(user_id, search_term, limit) + + # Remove any spammy users from the results. + results["results"] = [ + user + for user in results["results"] + if not self.spam_checker.check_username_for_spam(user) + ] + + return results def notify_new_event(self): """Called when there may be more deltas to process diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 26071059d2..0a4765fff4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -147,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_spam_checker(self): + """ + A user which fails to the spam checks will not appear in search results. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. 
+ shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + ) + self.assertEqual(public_users, []) + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_spam_checker() + + class AllowAll(object): + def check_username_for_spam(self, user_profile): + # Allow all users. + return False + + spam_checker.spam_checker = AllowAll() + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + class BlockAll(object): + def check_username_for_spam(self, user_profile): + # All users are spammy. + return True + + spam_checker.spam_checker = BlockAll() + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + def test_legacy_spam_checker(self): + """ + A spam checker without the expected method should be ignored. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + ) + self.assertEqual(public_users, []) + + # Configure a spam checker. + spam_checker = self.hs.get_spam_checker() + # The spam checker doesn't need any methods, so create a bare object. + spam_checker.spam_checker = object() + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + def _compress_shared(self, shared): """ Compress a list of users who share rooms dicts to a list of tuples. -- cgit 1.5.1 From 02e89021f58f931068ab0337de039181cc7f6569 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Feb 2020 09:05:43 -0500 Subject: Convert the directory handler tests to use HomeserverTestCase (#6919) Convert directory handler tests to use HomeserverTestCase. --- changelog.d/6919.misc | 1 + tests/handlers/test_directory.py | 41 +++++++++++++++++----------------------- 2 files changed, 18 insertions(+), 24 deletions(-) create mode 100644 changelog.d/6919.misc (limited to 'tests') diff --git a/changelog.d/6919.misc b/changelog.d/6919.misc new file mode 100644 index 0000000000..aa2cd89998 --- /dev/null +++ b/changelog.d/6919.misc @@ -0,0 +1 @@ +Convert the directory handler tests to use HomeserverTestCase. 
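The diff that follows mechanically swaps the @defer.inlineCallbacks / yield test style for HomeserverTestCase's synchronous helpers. As a schematic sketch of that pattern (not taken from the patch; in the real tests the awaited calls are store.create_room_alias_association() and handler.get_association() rather than the trivial Deferred used here):

```python
# Schematic before/after of the conversion pattern; defer.succeed() stands in
# for the handler/store calls exercised by the real tests.
from twisted.internet import defer

from tests import unittest


class OldStyleTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_something(self):
        result = yield defer.succeed("expected")
        self.assertEqual(result, "expected")


class NewStyleTestCase(unittest.HomeserverTestCase):
    def test_something(self):
        # get_success() pumps the reactor until the Deferred resolves and
        # returns its result, so the test body reads synchronously.
        result = self.get_success(defer.succeed("expected"))
        self.assertEqual(result, "expected")
```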
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 91c7a17070..ee88cf5a4b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,24 +19,16 @@ from mock import Mock from twisted.internet import defer from synapse.config.room_directory import RoomDirectoryConfig -from synapse.handlers.directory import DirectoryHandler from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest -from tests.utils import setup_test_homeserver -class DirectoryHandlers(object): - def __init__(self, hs): - self.directory_handler = DirectoryHandler(hs) - - -class DirectoryTestCase(unittest.TestCase): +class DirectoryTestCase(unittest.HomeserverTestCase): """ Tests the directory service. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -47,14 +39,12 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) - hs.handlers = DirectoryHandlers(hs) self.handler = hs.get_handlers().directory_handler @@ -64,23 +54,25 @@ class DirectoryTestCase(unittest.TestCase): self.your_room = RoomAlias.from_string("#your-room:test") self.remote_room = RoomAlias.from_string("#another:remote") - @defer.inlineCallbacks + return hs + def test_get_local_association(self): - yield self.store.create_room_alias_association( - self.my_room, "!8765qwer:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.my_room, "!8765qwer:test", ["test"] + ) ) - result = yield self.handler.get_association(self.my_room) + result = self.get_success(self.handler.get_association(self.my_room)) self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - @defer.inlineCallbacks def test_get_remote_association(self): self.mock_federation.make_query.return_value = defer.succeed( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) - result = yield self.handler.get_association(self.remote_room) + result = self.get_success(self.handler.get_association(self.remote_room)) self.assertEquals( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result @@ -93,14 +85,15 @@ class DirectoryTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) ) - response = yield self.query_handlers["directory"]( - {"room_alias": "#your-room:test"} + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) -- cgit 1.5.1 From 3404ad289b1d2e5bc5c7f277f519b9698dbdaa15 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Feb 2020 13:23:37 +0000 Subject: Raise the default power levels for invites, tombstones and server acls (#6834) --- changelog.d/6834.misc | 1 + synapse/handlers/room.py | 10 +++++++++- tests/rest/client/v1/test_rooms.py | 4 +++- 3 files changed, 13 insertions(+), 2 
deletions(-) create mode 100644 changelog.d/6834.misc (limited to 'tests') diff --git a/changelog.d/6834.misc b/changelog.d/6834.misc new file mode 100644 index 0000000000..79acebe516 --- /dev/null +++ b/changelog.d/6834.misc @@ -0,0 +1 @@ +Change the default power levels of invites, tombstones and server ACLs for new rooms. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ab07edd2fc..033083acac 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -64,18 +64,21 @@ class RoomCreationHandler(BaseHandler): "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": True, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": False, + "power_level_content_override": {}, }, } @@ -829,19 +832,24 @@ class RoomCreationHandler(BaseHandler): # This will be reudundant on pre-MSC2260 rooms, since the # aliases event is special-cased. EventTypes.Aliases: 0, + EventTypes.Tombstone: 100, + EventTypes.ServerACL: 100, }, "events_default": 0, "state_default": 50, "ban": 50, "kick": 50, "redact": 50, - "invite": 0, + "invite": 50, } if config["original_invitees_have_ops"]: for invitee in invite_list: power_level_content["users"][invitee] = 100 + # Power levels overrides are defined per chat preset + power_level_content.update(config["power_level_content_override"]) + if power_level_content_override: power_level_content.update(power_level_content_override) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e3af280ba6..fb681a1db9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1612,7 +1612,9 @@ class ContextTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) self.other_user_id = self.register_user("user2", "password") self.other_tok = self.login("user2", "password") -- cgit 1.5.1 From fe3941f6e33a17fa7cdf209a4370f4e805341db4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Feb 2020 07:29:44 -0500 Subject: Stop sending events when creating or deleting aliases (#6904) Stop sending events when creating or deleting associations (room aliases). Send an updated canonical alias event if one of the alt_aliases is deleted. --- changelog.d/6904.removal | 1 + synapse/handlers/directory.py | 75 ++++++++++--------- synapse/handlers/room.py | 6 +- tests/handlers/test_directory.py | 154 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 194 insertions(+), 42 deletions(-) create mode 100644 changelog.d/6904.removal (limited to 'tests') diff --git a/changelog.d/6904.removal b/changelog.d/6904.removal new file mode 100644 index 0000000000..a5cc0c3605 --- /dev/null +++ b/changelog.d/6904.removal @@ -0,0 +1 @@ +Stop sending alias events during adding / removing aliases. Check alt_aliases in the latest canonical aliases event when deleting an alias. 
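The core of the change below is that creating or deleting an alias no longer sends m.room.aliases events; instead, when an alias is deleted, it is pruned from the room's m.room.canonical_alias content (both the alias field and the alt_aliases list). The following self-contained sketch restates that pruning with an invented function name and plain dicts, whereas the real _update_canonical_alias in the diff reads and sends state events:

```python
# Illustrative sketch only; a standalone restatement of the pruning done by
# _update_canonical_alias in the diff below.
from typing import Optional


def prune_canonical_alias_content(content: dict, removed_alias: str) -> Optional[dict]:
    """Return the updated canonical-alias content, or None if no update is needed."""
    new_content = dict(content)
    send_update = False

    # Drop the main alias if it matches the removed one.
    if new_content.get("alias", "") == removed_alias:
        new_content.pop("alias", "")
        send_update = True

    # Filter the removed alias out of alt_aliases, if it is a list.
    alt_aliases = new_content.pop("alt_aliases", None)
    if isinstance(alt_aliases, list):
        send_update = True
        alt_aliases = [alias for alias in alt_aliases if alias != removed_alias]
        if alt_aliases:
            new_content["alt_aliases"] = alt_aliases

    return new_content if send_update else None


print(
    prune_canonical_alias_content(
        {"alias": "#test:test", "alt_aliases": ["#test:test", "#test2:test"]},
        "#test:test",
    )
)
# -> {'alt_aliases': ['#test2:test']}
```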
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8c5980cb0c..f718388884 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -81,13 +81,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, - requester, - room_alias, - room_id, - servers=None, - send_event=True, - check_membership=True, + self, requester, room_alias, room_id, servers=None, check_membership=True, ): """Attempt to create a new alias @@ -97,7 +91,6 @@ class DirectoryHandler(BaseHandler): room_id (str) servers (list[str]|None): List of servers that others servers should try and join via - send_event (bool): Whether to send an updated m.room.aliases event check_membership (bool): Whether to check if the user is in the room before the alias can be set (if the server's config requires it). @@ -150,16 +143,9 @@ class DirectoryHandler(BaseHandler): ) yield self._create_association(room_alias, room_id, servers, creator=user_id) - if send_event: - try: - yield self.send_room_alias_update_event(requester, room_id) - except AuthError as e: - # sending the aliases event may fail due to the user not having - # permission in the room; this is permitted. - logger.info("Skipping updating aliases event due to auth error %s", e) @defer.inlineCallbacks - def delete_association(self, requester, room_alias, send_event=True): + def delete_association(self, requester, room_alias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -168,9 +154,6 @@ class DirectoryHandler(BaseHandler): Args: requester (Requester): room_alias (RoomAlias): - send_event (bool): Whether to send an updated m.room.aliases event. - Note that, if we delete the canonical alias, we will always attempt - to send an m.room.canonical_alias event Returns: Deferred[unicode]: room id that the alias used to point to @@ -206,9 +189,6 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - if send_event: - yield self.send_room_alias_update_event(requester, room_id) - yield self._update_canonical_alias( requester, requester.user.to_string(), room_id, room_alias ) @@ -319,25 +299,50 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + """ + Send an updated canonical alias event if the removed alias was set as + the canonical alias or listed in the alt_aliases field. + """ alias_event = yield self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) - alias_str = room_alias.to_string() - if not alias_event or alias_event.content.get("alias", "") != alias_str: + # There is no canonical alias, nothing to do. + if not alias_event: return - yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.CanonicalAlias, - "state_key": "", - "room_id": room_id, - "sender": user_id, - "content": {}, - }, - ratelimit=False, - ) + # Obtain a mutable version of the event content. + content = dict(alias_event.content) + send_update = False + + # Remove the alias property if it matches the removed alias. + alias_str = room_alias.to_string() + if alias_event.content.get("alias", "") == alias_str: + send_update = True + content.pop("alias", "") + + # Filter alt_aliases for the removed alias. + alt_aliases = content.pop("alt_aliases", None) + # If the aliases are not a list (or not found) do not attempt to modify + # the list. 
+ if isinstance(alt_aliases, list): + send_update = True + alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: + content["alt_aliases"] = alt_aliases + + if send_update: + yield self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": content, + }, + ratelimit=False, + ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 033083acac..49ec2f48bc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -478,9 +478,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association( - requester, alias, send_event=False - ) + yield directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -511,7 +509,6 @@ class RoomCreationHandler(BaseHandler): RoomAlias.from_string(alias), new_room_id, servers=(self.hs.hostname,), - send_event=False, check_membership=False, ) logger.info("Moved alias %s to new room", alias) @@ -664,7 +661,6 @@ class RoomCreationHandler(BaseHandler): room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], - send_event=False, check_membership=False, ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ee88cf5a4b..27b916aed4 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,9 +18,11 @@ from mock import Mock from twisted.internet import defer +import synapse.api.errors +from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, room -from synapse.types import RoomAlias +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias, create_requester from tests import unittest @@ -85,6 +87,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) + def test_delete_alias_not_allowed(self): + room_id = "!8765qwer:test" + self.get_success( + self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + ) + + self.get_failure( + self.handler.delete_association( + create_requester("@user:test"), self.my_room + ), + synapse.api.errors.AuthError, + ) + + def test_delete_alias(self): + room_id = "!8765qwer:test" + user_id = "@user:test" + self.get_success( + self.store.create_room_alias_association( + self.my_room, room_id, ["test"], user_id + ) + ) + + result = self.get_success( + self.handler.delete_association(create_requester(user_id), self.my_room) + ) + self.assertEquals(room_id, result) + + # The alias should not be found. + self.get_failure( + self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + ) + def test_incoming_fed_query(self): self.get_success( self.store.create_room_alias_association( @@ -99,6 +133,122 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class CanonicalAliasTestCase(unittest.HomeserverTestCase): + """Test modifications of the canonical alias when delete aliases. 
+ """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a new alias to this room. + self.get_success( + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], self.admin_user + ) + ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self.helper.send_state( + self.room_id, + "m.room.canonical_alias", + {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, + tok=self.admin_user_tok, + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + # Finally, delete the alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertNotIn("alias", data["content"]) + self.assertNotIn("alt_aliases", data["content"]) + + def test_remove_other_alias(self): + """Removing an alias listed as in alt_aliases should remove it there too.""" + # Create a second alias. + other_test_alias = "#test2:test" + other_room_alias = RoomAlias.from_string(other_test_alias) + self.get_success( + self.store.create_room_alias_association( + other_room_alias, self.room_id, ["test"], self.admin_user + ) + ) + + # Set the alias as the canonical alias for this room. + self.helper.send_state( + self.room_id, + "m.room.canonical_alias", + { + "alias": self.test_alias, + "alt_aliases": [self.test_alias, other_test_alias], + }, + tok=self.admin_user_tok, + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual( + data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + ) + + # Delete the second alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), other_room_alias + ) + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + class TestCreateAliasACL(unittest.HomeserverTestCase): user_id = "@test:test" -- cgit 1.5.1