From f02663c4ddfa259c96aebde848a83156540c9fb3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 30 Mar 2021 12:12:44 +0100 Subject: Replace `room_invite_state_types` with `room_prejoin_state` (#9700) `room_invite_state_types` was inconvenient as a configuration setting, because anyone that ever set it would not receive any new types that were added to the defaults. Here, we deprecate the old setting, and replace it with a couple of new settings under `room_prejoin_state`. --- synapse/config/api.py | 135 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 112 insertions(+), 23 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/api.py b/synapse/config/api.py index 74cd53a8ed..91387a7f0e 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,38 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import Iterable + from synapse.api.constants import EventTypes +from synapse.config._base import Config, ConfigError +from synapse.config._util import validate_config +from synapse.types import JsonDict -from ._base import Config +logger = logging.getLogger(__name__) class ApiConfig(Config): section = "api" - def read_config(self, config, **kwargs): - self.room_invite_state_types = config.get( - "room_invite_state_types", - [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ], + def read_config(self, config: JsonDict, **kwargs): + validate_config(_MAIN_SCHEMA, config, ()) + self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + + def generate_config_section(cls, **kwargs) -> str: + formatted_default_state_types = "\n".join( + " # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES ) - def generate_config_section(cls, **kwargs): return """\ ## API Configuration ## - # A list of event types that will be included in the room_invite_state + # Controls for the state that is shared with users who receive an invite + # to a room # - #room_invite_state_types: - # - "{JoinRules}" - # - "{CanonicalAlias}" - # - "{RoomAvatar}" - # - "{RoomEncryption}" - # - "{Name}" - """.format( - **vars(EventTypes) - ) + room_prejoin_state: + # By default, the following state event types are shared with users who + # receive invites to the room: + # +%(formatted_default_state_types)s + # + # Uncomment the following to disable these defaults (so that only the event + # types listed in 'additional_event_types' are shared). Defaults to 'false'. + # + #disable_default_event_types: true + + # Additional state event types to share with users when they are invited + # to a room. + # + # By default, this list is empty (so only the default event types are shared). + # + #additional_event_types: + # - org.example.custom.event.type + """ % { + "formatted_default_state_types": formatted_default_state_types + } + + def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: + """Get the event types to include in the prejoin state + + Parses the config and returns an iterable of the event types to be included. + """ + room_prejoin_state_config = config.get("room_prejoin_state") or {} + + # backwards-compatibility support for room_invite_state_types + if "room_invite_state_types" in config: + # if both "room_invite_state_types" and "room_prejoin_state" are set, then + # we don't really know what to do. + if room_prejoin_state_config: + raise ConfigError( + "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' " + "in config" + ) + + logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) + + yield from config["room_invite_state_types"] + return + + if not room_prejoin_state_config.get("disable_default_event_types"): + yield from _DEFAULT_PREJOIN_STATE_TYPES + + yield from room_prejoin_state_config.get("additional_event_types", []) + + +_ROOM_INVITE_STATE_TYPES_WARNING = """\ +WARNING: The 'room_invite_state_types' configuration setting is now deprecated, +and replaced with 'room_prejoin_state'. New features may not work correctly +unless 'room_invite_state_types' is removed. See the sample configuration file for +details of 'room_prejoin_state'. +-------------------------------------------------------------------------------- +""" + +_DEFAULT_PREJOIN_STATE_TYPES = [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, +] + + +# room_prejoin_state can either be None (as it is in the default config), or +# an object containing other config settings +_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { + "oneOf": [ + { + "type": "object", + "properties": { + "disable_default_event_types": {"type": "boolean"}, + "additional_event_types": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + {"type": "null"}, + ] +} + +# the legacy room_invite_state_types setting +_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}} + +_MAIN_SCHEMA = { + "type": "object", + "properties": { + "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA, + "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA, + }, +} -- cgit 1.5.1 From 4dabcf026e3a8f480268451332b6bf3a7e671480 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 30 Mar 2021 14:03:17 +0100 Subject: Include m.room.create in invite_room_state for Spaces (#9710) --- changelog.d/9710.feature | 1 + synapse/config/api.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/9710.feature (limited to 'synapse/config') diff --git a/changelog.d/9710.feature b/changelog.d/9710.feature new file mode 100644 index 0000000000..fce308cc41 --- /dev/null +++ b/changelog.d/9710.feature @@ -0,0 +1 @@ +Experimental Spaces support: include `m.room.create` in the room state sent with room-invites. diff --git a/synapse/config/api.py b/synapse/config/api.py index 91387a7f0e..55c038c0c4 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -88,6 +88,10 @@ class ApiConfig(Config): if not room_prejoin_state_config.get("disable_default_event_types"): yield from _DEFAULT_PREJOIN_STATE_TYPES + if self.spaces_enabled: + # MSC1772 suggests adding m.room.create to the prejoin state + yield EventTypes.Create + yield from room_prejoin_state_config.get("additional_event_types", []) -- cgit 1.5.1 From 5ff8eb97c646f9f8de74915e4b2926789695d4af Mon Sep 17 00:00:00 2001 From: Denis Kasak Date: Wed, 31 Mar 2021 12:27:20 +0000 Subject: Make sample config allowed_local_3pids regex stricter. (#9719) The regex should be terminated so that subdomain matches of another domain are not accepted. Just ensuring that someone doesn't shoot themselves in the foot by copying our example. Signed-off-by: Denis Kasak --- changelog.d/9719.doc | 1 + docs/sample_config.yaml | 4 ++-- synapse/config/registration.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/9719.doc (limited to 'synapse/config') diff --git a/changelog.d/9719.doc b/changelog.d/9719.doc new file mode 100644 index 0000000000..f018606dd6 --- /dev/null +++ b/changelog.d/9719.doc @@ -0,0 +1 @@ +Make the allowed_local_3pids regex example in the sample config stricter. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c73ea6b161..b0bf987740 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1246,9 +1246,9 @@ account_validity: # #allowed_local_3pids: # - medium: email -# pattern: '.*@matrix\.org' +# pattern: '^[^@]+@matrix\.org$' # - medium: email -# pattern: '.*@vector\.im' +# pattern: '^[^@]+@vector\.im$' # - medium: msisdn # pattern: '\+44' diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ead007ba5a..f27d1e14ac 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -298,9 +298,9 @@ class RegistrationConfig(Config): # #allowed_local_3pids: # - medium: email - # pattern: '.*@matrix\\.org' + # pattern: '^[^@]+@matrix\\.org$' # - medium: email - # pattern: '.*@vector\\.im' + # pattern: '^[^@]+@vector\\.im$' # - medium: msisdn # pattern: '\\+44' -- cgit 1.5.1 From 35c5ef2d24734889a20a0cf334bb971a9329806f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 31 Mar 2021 16:39:08 -0400 Subject: Add an experimental room version to support restricted join rules. (#9717) Per MSC3083. --- changelog.d/9717.feature | 1 + synapse/api/constants.py | 2 + synapse/api/room_versions.py | 24 +++- synapse/config/experimental.py | 7 +- synapse/event_auth.py | 28 ++++- tests/test_event_auth.py | 246 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 297 insertions(+), 11 deletions(-) create mode 100644 changelog.d/9717.feature (limited to 'synapse/config') diff --git a/changelog.d/9717.feature b/changelog.d/9717.feature new file mode 100644 index 0000000000..c2c74f13d5 --- /dev/null +++ b/changelog.d/9717.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8f37d2cf3b..6856dab06c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -59,6 +59,8 @@ class JoinRules: KNOCK = "knock" INVITE = "invite" PRIVATE = "private" + # As defined for MSC3083. + MSC3083_RESTRICTED = "restricted" class LoginType: diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index de2cc15d33..87038d436d 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -57,7 +57,7 @@ class RoomVersion: state_res = attr.ib(type=int) # one of the StateResolutionVersions enforce_key_validity = attr.ib(type=bool) - # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules + # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool) # Strictly enforce canonicaljson, do not allow: # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] @@ -69,6 +69,8 @@ class RoomVersion: limit_notifications_power_levels = attr.ib(type=bool) # MSC2174/MSC2176: Apply updated redaction rules algorithm. msc2176_redaction_rules = attr.ib(type=bool) + # MSC3083: Support the 'restricted' join_rule. + msc3083_join_rules = attr.ib(type=bool) class RoomVersions: @@ -82,6 +84,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V2 = RoomVersion( "2", @@ -93,6 +96,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V3 = RoomVersion( "3", @@ -104,6 +108,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V4 = RoomVersion( "4", @@ -115,6 +120,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V5 = RoomVersion( "5", @@ -126,6 +132,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) V6 = RoomVersion( "6", @@ -137,6 +144,7 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, + msc3083_join_rules=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -148,6 +156,19 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=True, + msc3083_join_rules=False, + ) + MSC3083 = RoomVersion( + "org.matrix.msc3083", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, ) @@ -162,4 +183,5 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V6, RoomVersions.MSC2176, ) + # Note that we do not include MSC3083 here unless it is enabled in the config. } # type: Dict[str, RoomVersion] diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 86f4d9af9d..eb96ecda74 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config._base import Config from synapse.types import JsonDict @@ -27,7 +28,11 @@ class ExperimentalConfig(Config): # MSC2858 (multiple SSO identity providers) self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool - # Spaces (MSC1772, MSC2946, etc) + + # Spaces (MSC1772, MSC2946, MSC3083, etc) self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool + if self.spaces_enabled: + KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083 + # MSC3026 (busy presence state) self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 91ad5b3d3c..9863953f5c 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -162,7 +162,7 @@ def check( logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) if event.type == EventTypes.Member: - _is_membership_change_allowed(event, auth_events) + _is_membership_change_allowed(room_version_obj, event, auth_events) logger.debug("Allowing! %s", event) return @@ -220,8 +220,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def _is_membership_change_allowed( - event: EventBase, auth_events: StateMap[EventBase] + room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] ) -> None: + """ + Confirms that the event which changes membership is an allowed change. + + Args: + room_version: The version of the room. + event: The event to check. + auth_events: The current auth events of the room. + + Raises: + AuthError if the event is not allowed. + """ membership = event.content["membership"] # Check if this is the room creator joining: @@ -315,14 +326,19 @@ def _is_membership_change_allowed( if user_level < invite_level: raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP + # Joins are valid iff caller == target and: + # * They are not banned. + # * They are accepting a previously sent invitation. + # * They are already joined (it's a NOOP). + # * The room is public or restricted. if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") elif target_banned: raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: + elif join_rule == JoinRules.PUBLIC or ( + room_version.msc3083_join_rules + and join_rule == JoinRules.MSC3083_RESTRICTED + ): pass elif join_rule == JoinRules.INVITE: if not caller_in_room and not caller_invited: diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 3f2691ee6b..b5f18344dc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) + def test_join_rules_public(self): + """ + Test joining a public room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "public"), + } + + # Check join. + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_invite(self): + """ + Test joining an invite only room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "invite"), + } + + # A join without an invite is rejected. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left cannot re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_msc3083_restricted(self): + """ + Test joining a restricted room from MSC3083. + + This is pretty much the same test as public. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), + } + + # Older room versions don't understand this join rule + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # Check join. + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -225,19 +445,24 @@ def _create_event(user_id): ) -def _join_event(user_id): +def _member_event(user_id, membership, sender=None): return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), "type": "m.room.member", - "sender": user_id, + "sender": sender or user_id, "state_key": user_id, - "content": {"membership": "join"}, + "content": {"membership": membership}, + "prev_events": [], } ) +def _join_event(user_id): + return _member_event(user_id, "join") + + def _power_levels_event(sender, content): return make_event_from_dict( { @@ -277,6 +502,21 @@ def _random_state_event(sender): ) +def _join_rules_event(sender, join_rule): + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.join_rules", + "sender": sender, + "state_key": "", + "content": { + "join_rule": join_rule, + }, + } + ) + + event_count = 0 -- cgit 1.5.1 From 04819239bae2b39ee42bfdb6f9b83c6d9fe34169 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 6 Apr 2021 14:38:30 +0100 Subject: Add a Synapse Module for configuring presence update routing (#9491) At the moment, if you'd like to share presence between local or remote users, those users must be sharing a room together. This isn't always the most convenient or useful situation though. This PR adds a module to Synapse that will allow deployments to set up extra logic on where presence updates should be routed. The module must implement two methods, `get_users_for_states` and `get_interested_users`. These methods are given presence updates or user IDs and must return information that Synapse will use to grant passing presence updates around. A method is additionally added to `ModuleApi` which allows triggering a set of users to receive the current, online presence information for all users they are considered interested in. This is the equivalent of that user receiving presence information during an initial sync. The goal of this module is to be fairly generic and useful for a variety of applications, with hard requirements being: * Sending state for a specific set or all known users to a defined set of local and remote users. * The ability to trigger an initial sync for specific users, so they receive all current state. --- README.rst | 7 +- changelog.d/9491.feature | 1 + docs/presence_router_module.md | 235 +++++++++++++++++++++ docs/sample_config.yaml | 23 +- synapse/app/generic_worker.py | 3 +- synapse/config/server.py | 39 +++- synapse/events/presence_router.py | 104 +++++++++ synapse/federation/sender/__init__.py | 19 +- synapse/handlers/presence.py | 278 ++++++++++++++++++++---- synapse/module_api/__init__.py | 50 +++++ synapse/server.py | 5 + tests/events/test_presence_router.py | 386 ++++++++++++++++++++++++++++++++++ tests/handlers/test_sync.py | 21 +- tests/module_api/test_api.py | 175 ++++++++++++++- 14 files changed, 1282 insertions(+), 64 deletions(-) create mode 100644 changelog.d/9491.feature create mode 100644 docs/presence_router_module.md create mode 100644 synapse/events/presence_router.py create mode 100644 tests/events/test_presence_router.py (limited to 'synapse/config') diff --git a/README.rst b/README.rst index 655a2bf3be..1a5503572e 100644 --- a/README.rst +++ b/README.rst @@ -393,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion indicate that your server is also issuing far more outgoing federation requests than can be accounted for by your users' activity, this is a likely cause. The misbehavior can be worked around by setting -``use_presence: false`` in the Synapse config file. +the following in the Synapse config file: + +.. code-block:: yaml + + presence: + enabled: false People can't accept room invitations from me -------------------------------------------- diff --git a/changelog.d/9491.feature b/changelog.d/9491.feature new file mode 100644 index 0000000000..8b56a95a44 --- /dev/null +++ b/changelog.d/9491.feature @@ -0,0 +1 @@ +Add a Synapse module for routing presence updates between users. diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md new file mode 100644 index 0000000000..d6566d978d --- /dev/null +++ b/docs/presence_router_module.md @@ -0,0 +1,235 @@ +# Presence Router Module + +Synapse supports configuring a module that can specify additional users +(local or remote) to should receive certain presence updates from local +users. + +Note that routing presence via Application Service transactions is not +currently supported. + +The presence routing module is implemented as a Python class, which will +be imported by the running Synapse. + +## Python Presence Router Class + +The Python class is instantiated with two objects: + +* A configuration object of some type (see below). +* An instance of `synapse.module_api.ModuleApi`. + +It then implements methods related to presence routing. + +Note that one method of `ModuleApi` that may be useful is: + +```python +async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None +``` + +which can be given a list of local or remote MXIDs to broadcast known, online user +presence to (for those users that the receiving user is considered interested in). +It does not include state for users who are currently offline, and it can only be +called on workers that support sending federation. + +### Module structure + +Below is a list of possible methods that can be implemented, and whether they are +required. + +#### `parse_config` + +```python +def parse_config(config_dict: dict) -> Any +``` + +**Required.** A static method that is passed a dictionary of config options, and + should return a validated config object. This method is described further in + [Configuration](#configuration). + +#### `get_users_for_states` + +```python +async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], +) -> Dict[str, Set[UserPresenceState]]: +``` + +**Required.** An asynchronous method that is passed an iterable of user presence +state. This method can determine whether a given presence update should be sent to certain +users. It does this by returning a dictionary with keys representing local or remote +Matrix User IDs, and values being a python set +of `synapse.handlers.presence.UserPresenceState` instances. + +Synapse will then attempt to send the specified presence updates to each user when +possible. + +#### `get_interested_users` + +```python +async def get_interested_users(self, user_id: str) -> Union[Set[str], str] +``` + +**Required.** An asynchronous method that is passed a single Matrix User ID. This +method is expected to return the users that the passed in user may be interested in the +presence of. Returned users may be local or remote. The presence routed as a result of +what this method returns is sent in addition to the updates already sent between users +that share a room together. Presence updates are deduplicated. + +This method should return a python set of Matrix User IDs, or the object +`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed +user should receive presence information for *all* known users. + +For clarity, if the user `@alice:example.org` is passed to this method, and the Set +`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice +should receive presence updates sent by Bob and Charlie, regardless of whether these +users share a room. + +### Example + +Below is an example implementation of a presence router class. + +```python +from typing import Dict, Iterable, Set, Union +from synapse.events.presence_router import PresenceRouter +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi + +class PresenceRouterConfig: + def __init__(self): + # Config options with their defaults + # A list of users to always send all user presence updates to + self.always_send_to_users = [] # type: List[str] + + # A list of users to ignore presence updates for. Does not affect + # shared-room presence relationships + self.blacklisted_users = [] # type: List[str] + +class ExamplePresenceRouter: + """An example implementation of synapse.presence_router.PresenceRouter. + Supports routing all presence to a configured set of users, or a subset + of presence from certain users to members of certain rooms. + + Args: + config: A configuration object. + module_api: An instance of Synapse's ModuleApi. + """ + def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterConfig() + always_send_to_users = config_dict.get("always_send_to_users") + blacklisted_users = config_dict.get("blacklisted_users") + + # Do some validation of config options... otherwise raise a + # synapse.config.ConfigError. + config.always_send_to_users = always_send_to_users + config.blacklisted_users = blacklisted_users + + return config + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """Given an iterable of user presence updates, determine where each one + needs to go. Returned results will not affect presence updates that are + sent between users who share a room. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState that the user should + receive. + """ + destination_users = {} # type: Dict[str, Set[UserPresenceState] + + # Ignore any updates for blacklisted users + desired_updates = set() + for update in state_updates: + if update.state_key not in self._config.blacklisted_users: + desired_updates.add(update) + + # Send all presence updates to specific users + for user_id in self._config.always_send_to_users: + destination_users[user_id] = desired_updates + + return destination_users + + async def get_interested_users( + self, + user_id: str, + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return additional presence updates for, or + PresenceRouter.ALL_USERS to return presence updates for all other users. + """ + if user_id in self._config.always_send_to_users: + return PresenceRouter.ALL_USERS + + return set() +``` + +#### A note on `get_users_for_states` and `get_interested_users` + +Both of these methods are effectively two different sides of the same coin. The logic +regarding which users should receive updates for other users should be the same +between them. + +`get_users_for_states` is called when presence updates come in from either federation +or local users, and is used to either direct local presence to remote users, or to +wake up the sync streams of local users to collect remote presence. + +In contrast, `get_interested_users` is used to determine the users that presence should +be fetched for when a local user is syncing. This presence is then retrieved, before +being fed through `get_users_for_states` once again, with only the syncing user's +routing information pulled from the resulting dictionary. + +Their routing logic should thus line up, else you may run into unintended behaviour. + +## Configuration + +Once you've crafted your module and installed it into the same Python environment as +Synapse, amend your homeserver config file with the following. + +```yaml +presence: + routing_module: + module: my_module.ExamplePresenceRouter + config: + # Any configuration options for your module. The below is an example. + # of setting options for ExamplePresenceRouter. + always_send_to_users: ["@presence_gobbler:example.org"] + blacklisted_users: + - "@alice:example.com" + - "@bob:example.com" + ... +``` + +The contents of `config` will be passed as a Python dictionary to the static +`parse_config` method of your class. The object returned by this method will +then be passed to the `__init__` method of your module as `config`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b0bf987740..9182dcd987 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid # #soft_file_limit: 0 -# Set to false to disable presence tracking on this homeserver. +# Presence tracking allows users to see the state (e.g online/offline) +# of other local and remote users. # -#use_presence: false +presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3df2aa5c2b..d1c2079233 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler): self.hs = hs self.is_mine_id = hs.is_mine_id + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence # The number of ongoing syncs on this process, by user id. @@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler): return _user_syncing() async def notify_from_replication(self, states, stream_id): - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( diff --git a/synapse/config/server.py b/synapse/config/server.py index 5f8910b6e1..8decc9d10d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -27,6 +27,7 @@ import yaml from netaddr import AddrFormatError, IPNetwork, IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError @@ -238,7 +239,20 @@ class ServerConfig(Config): self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. - self.use_presence = config.get("use_presence", True) + presence_config = config.get("presence") or {} + self.use_presence = presence_config.get("enabled") + if self.use_presence is None: + self.use_presence = config.get("use_presence", True) + + # Custom presence router module + self.presence_router_module_class = None + self.presence_router_config = None + presence_router_config = presence_config.get("presence_router") + if presence_router_config: + ( + self.presence_router_module_class, + self.presence_router_config, + ) = load_module(presence_router_config, ("presence", "presence_router")) # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker @@ -834,9 +848,28 @@ class ServerConfig(Config): # #soft_file_limit: 0 - # Set to false to disable presence tracking on this homeserver. + # Presence tracking allows users to see the state (e.g online/offline) + # of other local and remote users. # - #use_presence: false + presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py new file mode 100644 index 0000000000..24cd389d80 --- /dev/null +++ b/synapse/events/presence_router.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Iterable, Set, Union + +from synapse.api.presence import UserPresenceState + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class PresenceRouter: + """ + A module that the homeserver will call upon to help route user presence updates to + additional destinations. If a custom presence router is configured, calls will be + passed to that instead. + """ + + ALL_USERS = "ALL" + + def __init__(self, hs: "HomeServer"): + self.custom_presence_router = None + + # Check whether a custom presence router module has been configured + if hs.config.presence_router_module_class: + # Initialise the module + self.custom_presence_router = hs.config.presence_router_module_class( + config=hs.config.presence_router_config, module_api=hs.get_module_api() + ) + + # Ensure the module has implemented the required methods + required_methods = ["get_users_for_states", "get_interested_users"] + for method_name in required_methods: + if not hasattr(self.custom_presence_router, method_name): + raise Exception( + "PresenceRouter module '%s' must implement all required methods: %s" + % ( + hs.config.presence_router_module_class.__name__, + ", ".join(required_methods), + ) + ) + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """ + Given an iterable of user presence updates, determine where each one + needs to go. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState, indicating which + presence updates each user should receive. + """ + if self.custom_presence_router is not None: + # Ask the custom module + return await self.custom_presence_router.get_users_for_states( + state_updates=state_updates + ) + + # Don't include any extra destinations for presence updates + return {} + + async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users, but can return users + that are local or remote. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return presence updates for, or ALL_USERS to return all + known updates. + """ + if self.custom_presence_router is not None: + # Ask the custom module for interested users + return await self.custom_presence_router.get_interested_users( + user_id=user_id + ) + + # A custom presence router is not defined. + # Don't report any additional interested users + return set() diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8babb1ebbe..98bfce22ff 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func if TYPE_CHECKING: + from synapse.events.presence_router import PresenceRouter from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender): self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id + self._presence_router = None # type: Optional[PresenceRouter] self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() @@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = await get_interested_remotes(self.store, states, self.state) + # We pull the presence router here instead of __init__ + # to prevent a dependency cycle: + # + # AuthHandler -> Notifier -> FederationSender + # -> PresenceRouter -> ModuleApi -> AuthHandler + if self._presence_router is None: + self._presence_router = self.hs.get_presence_router() + + assert self._presence_router is not None + + hosts_and_states = await get_interested_remotes( + self.store, + self._presence_router, + states, + self.state, + ) for destinations, states in hosts_and_states: for destination in destinations: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index da92feacc9..c817f2952d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,17 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from prometheus_client import Counter from typing_extensions import ContextManager @@ -34,6 +44,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge @@ -42,7 +53,7 @@ from synapse.state import StateHandler from synapse.storage.databases.main import DataStore from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() @@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler): """ stream_id, max_token = await self.store.update_presence(states) - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -1041,7 +1053,12 @@ class PresenceEventSource: # # Presence -> Notifier -> PresenceEventSource -> Presence # + # Same with get_module_api, get_presence_router + # + # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler self.get_presence_handler = hs.get_presence_handler + self.get_module_api = hs.get_module_api + self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -1055,7 +1072,7 @@ class PresenceEventSource: include_offline=True, explicit_room_id=None, **kwargs - ): + ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1068,7 +1085,17 @@ class PresenceEventSource: # We don't try and limit the presence updates by the current token, as # sending down the rare duplicate is not a concern. + user_id = user.to_string() + stream_change_cache = self.store.presence_stream_cache + with Measure(self.clock, "presence.get_new_events"): + if user_id in self.get_module_api()._send_full_presence_to_local_users: + # This user has been specified by a module to receive all current, online + # user presence. Removing from_key and setting include_offline to false + # will do effectively this. + from_key = None + include_offline = False + if from_key is not None: from_key = int(from_key) @@ -1091,59 +1118,209 @@ class PresenceEventSource: # doesn't return. C.f. #5503. return [], max_token - presence = self.get_presence_handler() - stream_change_cache = self.store.presence_stream_cache - + # Figure out which other users this user should receive updates for users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() # type: Collection[str] - changed = None - if from_key: - changed = stream_change_cache.get_all_entities_changed(from_key) + # We have a set of users that we're interested in the presence of. We want to + # cross-reference that with the users that have actually changed their presence. - if changed is not None and len(changed) < 500: - assert isinstance(user_ids_changed, set) + # Check whether this user should see all user updates - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list - get_updates_counter.labels("stream").inc() - for other_user_id in changed: - if other_user_id in users_interested_in: - user_ids_changed.add(other_user_id) - else: - # Too many possible updates. Find all users we can see and check - # if any of them have changed. - get_updates_counter.labels("full").inc() + if users_interested_in == PresenceRouter.ALL_USERS: + # Provide presence state for all users + presence_updates = await self._filter_all_presence_updates_for_user( + user_id, include_offline, from_key + ) - if from_key: - user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove( + user_id ) + + return presence_updates, max_token + + # Make mypy happy. users_interested_in should now be a set + assert not isinstance(users_interested_in, str) + + # The set of users that we're interested in and that have had a presence update. + # We'll actually pull the presence updates for these users at the end. + interested_and_updated_users = ( + set() + ) # type: Union[Set[str], FrozenSet[str]] + + if from_key: + # First get all users that have had a presence update + updated_users = stream_change_cache.get_all_entities_changed(from_key) + + # Cross-reference users we're interested in with those that have had updates. + # Use a slightly-optimised method for processing smaller sets of updates. + if updated_users is not None and len(updated_users) < 500: + # For small deltas, it's quicker to get all changes and then + # cross-reference with the users we're interested in + get_updates_counter.labels("stream").inc() + for other_user_id in updated_users: + if other_user_id in users_interested_in: + # mypy thinks this variable could be a FrozenSet as it's possibly set + # to one in the `get_entities_changed` call below, and `add()` is not + # method on a FrozenSet. That doesn't affect us here though, as + # `interested_and_updated_users` is clearly a set() above. + interested_and_updated_users.add(other_user_id) # type: ignore else: - user_ids_changed = users_interested_in + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + get_updates_counter.labels("full").inc() - updates = await presence.current_state_for_users(user_ids_changed) + interested_and_updated_users = ( + stream_change_cache.get_entities_changed( + users_interested_in, from_key + ) + ) + else: + # No from_key has been specified. Return the presence for all users + # this user is interested in + interested_and_updated_users = users_interested_in + + # Retrieve the current presence state for each user + users_to_state = await self.get_presence_handler().current_state_for_users( + interested_and_updated_users + ) + presence_updates = list(users_to_state.values()) - if include_offline: - return (list(updates.values()), max_token) + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove(user_id) + + if not include_offline: + # Filter out offline presence states + presence_updates = self._filter_offline_presence_state(presence_updates) + + return presence_updates, max_token + + async def _filter_all_presence_updates_for_user( + self, + user_id: str, + include_offline: bool, + from_key: Optional[int] = None, + ) -> List[UserPresenceState]: + """ + Computes the presence updates a user should receive. + + First pulls presence updates from the database. Then consults PresenceRouter + for whether any updates should be excluded by user ID. + + Args: + user_id: The User ID of the user to compute presence updates for. + include_offline: Whether to include offline presence states from the results. + from_key: The minimum stream ID of updates to pull from the database + before filtering. + + Returns: + A list of presence states for the given user to receive. + """ + if from_key: + # Only return updates since the last sync + updated_users = self.store.presence_stream_cache.get_all_entities_changed( + from_key + ) + if not updated_users: + updated_users = [] + + # Get the actual presence update for each change + users_to_state = await self.get_presence_handler().current_state_for_users( + updated_users + ) + presence_updates = list(users_to_state.values()) + + if not include_offline: + # Filter out offline states + presence_updates = self._filter_offline_presence_state(presence_updates) else: - return ( - [s for s in updates.values() if s.state != PresenceState.OFFLINE], - max_token, + users_to_state = await self.store.get_presence_for_all_users( + include_offline=include_offline ) + presence_updates = list(users_to_state.values()) + + # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the + # module for information on a number of users when we then only take the info + # for a single user + + # Filter through the presence router + users_to_state_set = await self.get_presence_router().get_users_for_states( + presence_updates + ) + + # We only want the mapping for the syncing user + presence_updates = list(users_to_state_set[user_id]) + + # Return presence information for all users + return presence_updates + + def _filter_offline_presence_state( + self, presence_updates: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: + """Given an iterable containing user presence updates, return a list with any offline + presence states removed. + + Args: + presence_updates: Presence states to filter + + Returns: + A new list with any offline presence states removed. + """ + return [ + update + for update in presence_updates + if update.state != PresenceState.OFFLINE + ] + def get_current_key(self): return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) - async def _get_interested_in(self, user, explicit_room_id, cache_context): + async def _get_interested_in( + self, + user: UserID, + explicit_room_id: Optional[str] = None, + cache_context: Optional[_CacheContext] = None, + ) -> Union[Set[str], str]: """Returns the set of users that the given user should see presence - updates for + updates for. + + Args: + user: The user to retrieve presence updates for. + explicit_room_id: The users that are in the room will be returned. + + Returns: + A set of user IDs to return presence updates for, or "ALL" to return all + known updates. """ user_id = user.to_string() users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence + # cache_context isn't likely to ever be None due to the @cached decorator, + # but we can't have a non-optional argument after the optional argument + # explicit_room_id either. Assert cache_context is not None so we can use it + # without mypy complaining. + assert cache_context + + # Check with the presence router whether we should poll additional users for + # their presence information + additional_users = await self.get_presence_router().get_interested_users( + user.to_string() + ) + if additional_users == PresenceRouter.ALL_USERS: + # If the module requested that this user see the presence updates of *all* + # users, then simply return that instead of calculating what rooms this + # user shares + return PresenceRouter.ALL_USERS + + # Add the additional users from the router + users_interested_in.update(additional_users) + + # Find the users who share a room with this user users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) @@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): async def get_interested_parties( - store: DataStore, states: List[UserPresenceState] + store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - store - states + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. Returns: A 2-tuple of `(room_ids_to_states, users_to_states)`, @@ -1337,11 +1515,22 @@ async def get_interested_parties( # Always notify self users_to_states.setdefault(state.user_id, []).append(state) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) + + # Update the dictionaries with additional destinations and state to send + for user_id, user_states in router_users_to_states.items(): + users_to_states.setdefault(user_id, []).extend(user_states) + return room_ids_to_states, users_to_states async def get_interested_remotes( - store: DataStore, states: List[UserPresenceState], state_handler: StateHandler + store: DataStore, + presence_router: PresenceRouter, + states: List[UserPresenceState], + state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1349,9 +1538,10 @@ async def get_interested_remotes( All the presence states should be for local users only. Args: - store - states - state_handler + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. + state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1363,7 +1553,9 @@ async def get_interested_remotes( # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties( + store, presence_router, states + ) for room_id, states in room_ids_to_states.items(): hosts = await state_handler.get_current_hosts_in_room(room_id) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 781e02fbbb..3ecd46c038 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -50,11 +50,20 @@ class ModuleApi: self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname + self._presence_stream = hs.get_event_sources().sources["presence"] # We expose these as properties below in order to attach a helpful docstring. self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._public_room_list_manager = PublicRoomListManager(hs) + # The next time these users sync, they will receive the current presence + # state of all local users. Users are added by send_local_online_presence_to, + # and removed after a successful sync. + # + # We make this a private variable to deter modules from accessing it directly, + # though other classes in Synapse will still do so. + self._send_full_presence_to_local_users = set() + @property def http_client(self): """Allows making outbound HTTP requests to remote resources. @@ -385,6 +394,47 @@ class ModuleApi: return event + async def send_local_online_presence_to(self, users: Iterable[str]) -> None: + """ + Forces the equivalent of a presence initial_sync for a set of local or remote + users. The users will receive presence for all currently online users that they + are considered interested in. + + Updates to remote users will be sent immediately, whereas local users will receive + them on their next sync attempt. + + Note that this method can only be run on the main or federation_sender worker + processes. + """ + if not self._hs.should_send_federation(): + raise Exception( + "send_local_online_presence_to can only be run " + "on processes that send federation", + ) + + for user in users: + if self._hs.is_mine_id(user): + # Modify SyncHandler._generate_sync_entry_for_presence to call + # presence_source.get_new_events with an empty `from_key` if + # that user's ID were in a list modified by ModuleApi somewhere. + # That user would then get all presence state on next incremental sync. + + # Force a presence initial_sync for this user next time + self._send_full_presence_to_local_users.add(user) + else: + # Retrieve presence state for currently online users that this user + # is considered interested in + presence_events, _ = await self._presence_stream.get_new_events( + UserID.from_string(user), from_key=None, include_offline=False + ) + + # Send to remote destinations + await make_deferred_yieldable( + # We pull the federation sender here as we can only do so on workers + # that support sending presence + self._hs.get_federation_sender().send_presence(presence_events) + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/server.py b/synapse/server.py index e42f7b1a18..cfb55c230d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -51,6 +51,7 @@ from synapse.crypto import context_factory from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.events.presence_router import PresenceRouter from synapse.events.spamcheck import SpamChecker from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.events.utils import EventClientSerializer @@ -425,6 +426,10 @@ class HomeServer(metaclass=abc.ABCMeta): else: raise Exception("Workers cannot write typing") + @cache_in_self + def get_presence_router(self) -> PresenceRouter: + return PresenceRouter(self) + @cache_in_self def get_typing_handler(self) -> FollowerTypingHandler: if self.config.worker.writers.typing == self.get_instance_name(): diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py new file mode 100644 index 0000000000..c6e547f11c --- /dev/null +++ b/tests/events/test_presence_router.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +from mock import Mock + +import attr + +from synapse.api.constants import EduTypes +from synapse.events.presence_router import PresenceRouter +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, presence, room +from synapse.types import JsonDict, StreamToken, create_requester + +from tests.handlers.test_sync import generate_sync_config +from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config + + +@attr.s +class PresenceRouterTestConfig: + users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) + + +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + +class PresenceRouterTestCase(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + presence.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, homeserver): + self.sync_handler = self.hs.get_sync_handler() + self.module_api = homeserver.get_module_api() + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_receiving_all_presence(self): + """Test that a user that does not share a room with another other can receive + presence for them, due to presence routing. + """ + # Create a user who should receive all presence of others + self.presence_receiving_user_id = self.register_user( + "presence_gobbler", "monkey" + ) + self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey") + + # And two users who should not have any special routing + self.other_user_one_id = self.register_user("other_user_one", "monkey") + self.other_user_one_tok = self.login("other_user_one", "monkey") + self.other_user_two_id = self.register_user("other_user_two", "monkey") + self.other_user_two_tok = self.login("other_user_two", "monkey") + + # Put the other two users in a room with each other + room_id = self.helper.create_room_as( + self.other_user_one_id, tok=self.other_user_one_tok + ) + + self.helper.invite( + room_id, + self.other_user_one_id, + self.other_user_two_id, + tok=self.other_user_one_tok, + ) + self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok) + # User one sends some presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "boop", + ) + + # Check that the presence receiving user gets user one's presence when syncing + presence_updates, sync_token = sync_presence( + self, self.presence_receiving_user_id + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_one_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "boop") + + # Have all three users send presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "user_one", + ) + send_presence_update( + self, + self.other_user_two_id, + self.other_user_two_tok, + "online", + "user_two", + ) + send_presence_update( + self, + self.presence_receiving_user_id, + self.presence_receiving_user_tok, + "online", + "presence_gobbler", + ) + + # Check that the presence receiving user gets everyone's presence + presence_updates, _ = sync_presence( + self, self.presence_receiving_user_id, sync_token + ) + self.assertEqual(len(presence_updates), 3) + + # But that User One only get itself and User Two's presence + presence_updates, _ = sync_presence(self, self.other_user_one_id) + self.assertEqual(len(presence_updates), 2) + + found = False + for update in presence_updates: + if update.user_id == self.other_user_two_id: + self.assertEqual(update.state, "online") + self.assertEqual(update.status_msg, "user_two") + found = True + + self.assertTrue(found) + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_send_local_online_presence_to_with_module(self): + """Tests that send_local_presence_to_users sends local online presence to a set + of specified local and remote users, with a custom PresenceRouter module enabled. + """ + # Create a user who will send presence updates + self.other_user_id = self.register_user("other_user", "monkey") + self.other_user_tok = self.login("other_user", "monkey") + + # And another two users that will also send out presence updates, as well as receive + # theirs and everyone else's + self.presence_receiving_user_one_id = self.register_user( + "presence_gobbler1", "monkey" + ) + self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey") + self.presence_receiving_user_two_id = self.register_user( + "presence_gobbler2", "monkey" + ) + self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey") + + # Have all three users send some presence updates + send_presence_update( + self, + self.other_user_id, + self.other_user_tok, + "online", + "I'm online!", + ) + send_presence_update( + self, + self.presence_receiving_user_one_id, + self.presence_receiving_user_one_tok, + "online", + "I'm also online!", + ) + send_presence_update( + self, + self.presence_receiving_user_two_id, + self.presence_receiving_user_two_tok, + "unavailable", + "I'm in a meeting!", + ) + + # Mark each presence-receiving user for receiving all user presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + ) + ) + + # Perform a sync for each user + + # The other user should only receive their own presence + presence_updates, _ = sync_presence(self, self.other_user_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "I'm online!") + + # Whereas both presence receiving users should receive everyone's presence updates + presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id) + self.assertEqual(len(presence_updates), 3) + presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) + self.assertEqual(len(presence_updates), 3) + + # Test that sending to a remote user works + remote_user_id = "@far_away_person:island" + + # Note that due to the remote user being in our module's + # users_who_should_receive_all_presence config, they would have + # received user presence updates already. + # + # Thus we reset the mock, and try sending all online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that the expected presence updates were sent + expected_users = [ + self.other_user_id, + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + # Check for presence updates that contain the user IDs we're after + expected_users.remove(presence_update["user_id"]) + + # Ensure that no offline states are being sent out + self.assertNotEqual(presence_update["presence"], "offline") + + self.assertEqual(len(expected_users), 0) + + +def send_presence_update( + testcase: TestCase, + user_id: str, + access_token: str, + presence_state: str, + status_message: Optional[str] = None, +) -> JsonDict: + # Build the presence body + body = {"presence": presence_state} + if status_message: + body["status_msg"] = status_message + + # Update the user's presence state + channel = testcase.make_request( + "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token + ) + testcase.assertEqual(channel.code, 200) + + return channel.json_body + + +def sync_presence( + testcase: TestCase, + user_id: str, + since_token: Optional[StreamToken] = None, +) -> Tuple[List[UserPresenceState], StreamToken]: + """Perform a sync request for the given user and return the user presence updates + they've received, as well as the next_batch token. + + This method assumes testcase.sync_handler points to the homeserver's sync handler. + + Args: + testcase: The testcase that is currently being run. + user_id: The ID of the user to generate a sync response for. + since_token: An optional token to indicate from at what point to sync from. + + Returns: + A tuple containing a list of presence updates, and the sync response's + next_batch token. + """ + requester = create_requester(user_id) + sync_config = generate_sync_config(requester.user.to_string()) + sync_result = testcase.get_success( + testcase.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token + ) + ) + + return sync_result.presence, sync_result.next_batch diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e62586142e..8e950f25c5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = self._generate_sync_config(user_id1) + sync_config = generate_sync_config(user_id1) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = self._generate_sync_config(user_id2) + sync_config = generate_sync_config(user_id2) requester = create_requester(user_id2) e = self.get_failure( @@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def _generate_sync_config(self, user_id): - return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), - filter_collection=DEFAULT_FILTER_COLLECTION, - is_guest=False, - request_key="request_key", - device_id="device_id", - ) + +def generate_sync_config(user_id: str) -> SyncConfig: + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + filter_collection=DEFAULT_FILTER_COLLECTION, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index edacd1b566..1d1fceeecf 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -14,25 +14,37 @@ # limitations under the License. from mock import Mock +from synapse.api.constants import EduTypes from synapse.events import EventBase +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client.v1 import login, presence, room from synapse.types import create_requester -from tests.unittest import HomeserverTestCase +from tests.events.test_presence_router import send_presence_update, sync_presence +from tests.test_utils.event_injection import inject_member_event +from tests.unittest import FederatingHomeserverTestCase, override_config -class ModuleApiTestCase(HomeserverTestCase): +class ModuleApiTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, room.register_servlets, + presence.register_servlets, ] def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) def test_can_register_user(self): """Tests that an external module can register a user""" @@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase): ) ) self.assertFalse(is_in_public_rooms) + + # The ability to send federation is required by send_local_online_presence_to. + @override_config({"send_federation": True}) + def test_send_local_online_presence_to(self): + """Tests that send_local_presence_to_users sends local online presence to local users.""" + # Create a user who will send presence updates + self.presence_receiver_id = self.register_user("presence_receiver", "monkey") + self.presence_receiver_tok = self.login("presence_receiver", "monkey") + + # And another user that will send presence updates out + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # Put them in a room together so they will receive each other's presence updates + room_id = self.helper.create_room_as( + self.presence_receiver_id, + tok=self.presence_receiver_tok, + ) + self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok) + + # Presence sender comes online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Presence receiver should have received it + presence_updates, sync_token = sync_presence(self, self.presence_receiver_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Syncing again should result in no presence updates + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should have received online presence again + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Presence sender goes offline + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "offline", + "I slink back into the darkness.", + ) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should *not* have received offline state + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + @override_config({"send_federation": True}) + def test_send_local_online_presence_to_federation(self): + """Tests that send_local_presence_to_users sends local online presence to remote users.""" + # Create a user who will send presence updates + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # And a room they're a part of + room_id = self.helper.create_room_as( + self.presence_sender_id, + tok=self.presence_sender_tok, + ) + + # Mark them as online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Make up a remote user to send presence to + remote_user_id = "@far_away_person:island" + + # Create a join membership event for the remote user into the room. + # This allows presence information to flow from one user to the other. + self.get_success( + inject_member_event( + self.hs, + room_id, + sender=remote_user_id, + target=remote_user_id, + membership="join", + ) + ) + + # The remote user would have received the existing room members' presence + # when they joined the room. + # + # Thus we reset the mock, and try sending online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that a presence update was sent as part of a federation transaction + found_update = False + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + if presence_update["user_id"] == self.presence_sender_id: + found_update = True + + self.assertTrue(found_update) -- cgit 1.5.1 From 2ca4e349e9d0c606d802ae15c06089080fa4f27e Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Apr 2021 23:38:54 +0200 Subject: Bugbear: Add Mutable Parameter fixes (#9682) Part of #9366 Adds in fixes for B006 and B008, both relating to mutable parameter lint errors. Signed-off-by: Jonathan de Jong --- changelog.d/9682.misc | 1 + contrib/cmdclient/console.py | 5 ++++- contrib/cmdclient/http.py | 24 +++++++++++++++----- setup.cfg | 4 ++-- synapse/appservice/scheduler.py | 6 ++--- synapse/config/ratelimiting.py | 6 +++-- synapse/events/__init__.py | 14 ++++++++---- synapse/federation/units.py | 5 +++-- synapse/handlers/appservice.py | 4 ++-- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 11 ++++++--- synapse/handlers/register.py | 4 +++- synapse/handlers/sync.py | 8 +++---- synapse/http/client.py | 4 ++-- synapse/http/proxyagent.py | 6 +++-- synapse/logging/opentracing.py | 3 ++- synapse/module_api/__init__.py | 12 +++++----- synapse/notifier.py | 19 +++++++++------- synapse/storage/database.py | 20 ++++++++++++----- synapse/storage/databases/main/events.py | 7 ++++-- synapse/storage/databases/main/group_server.py | 4 +++- synapse/storage/databases/main/state.py | 6 +++-- synapse/storage/databases/state/bg_updates.py | 5 ++++- synapse/storage/databases/state/store.py | 5 +++-- synapse/storage/state.py | 26 +++++++++++++--------- synapse/storage/util/id_generators.py | 11 +++++++-- synapse/util/caches/lrucache.py | 14 +++++++----- .../federation/test_matrix_federation_agent.py | 15 +++++++++---- tests/replication/_base.py | 4 ++-- tests/replication/slave/storage/test_events.py | 16 ++++++++----- tests/rest/client/v1/test_rooms.py | 5 ++++- tests/rest/client/v1/utils.py | 14 ++++++++---- tests/rest/client/v2_alpha/test_relations.py | 5 +++-- tests/storage/test_id_generators.py | 14 +++++++----- tests/storage/test_redaction.py | 10 +++++++-- tests/test_state.py | 5 +++-- tests/test_visibility.py | 7 ++++-- tests/util/test_ratelimitutils.py | 6 +++-- 38 files changed, 224 insertions(+), 113 deletions(-) create mode 100644 changelog.d/9682.misc (limited to 'synapse/config') diff --git a/changelog.d/9682.misc b/changelog.d/9682.misc new file mode 100644 index 0000000000..428a466fac --- /dev/null +++ b/changelog.d/9682.misc @@ -0,0 +1 @@ +Introduce flake8-bugbear to the test suite and fix some of its lint violations. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 67e032244e..856dd437db 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -24,6 +24,7 @@ import sys import time import urllib from http import TwistedHttpClient +from typing import Optional import nacl.encoding import nacl.signing @@ -718,7 +719,7 @@ class SynapseCmd(cmd.Cmd): method, path, data=None, - query_params={"access_token": None}, + query_params: Optional[dict] = None, alt_text=None, ): """Runs an HTTP request and pretty prints the output. @@ -729,6 +730,8 @@ class SynapseCmd(cmd.Cmd): data: Raw JSON data if any query_params: dict of query parameters to add to the url """ + query_params = query_params or {"access_token": None} + url = self._url() + path if "access_token" in query_params: query_params["access_token"] = self._tok() diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 851e80c25b..1cf913756e 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -16,6 +16,7 @@ import json import urllib from pprint import pformat +from typing import Optional from twisted.internet import defer, reactor from twisted.web.client import Agent, readBody @@ -85,8 +86,9 @@ class TwistedHttpClient(HttpClient): body = yield readBody(response) defer.returnValue(json.loads(body)) - def _create_put_request(self, url, json_data, headers_dict={}): + def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None): """Wrapper of _create_request to issue a PUT request""" + headers_dict = headers_dict or {} if "Content-Type" not in headers_dict: raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) @@ -95,14 +97,22 @@ class TwistedHttpClient(HttpClient): "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) - def _create_get_request(self, url, headers_dict={}): + def _create_get_request(self, url, headers_dict: Optional[dict] = None): """Wrapper of _create_request to issue a GET request""" - return self._create_request("GET", url, headers_dict=headers_dict) + return self._create_request("GET", url, headers_dict=headers_dict or {}) @defer.inlineCallbacks def do_request( - self, method, url, data=None, qparams=None, jsonreq=True, headers={} + self, + method, + url, + data=None, + qparams=None, + jsonreq=True, + headers: Optional[dict] = None, ): + headers = headers or {} + if qparams: url = "%s?%s" % (url, urllib.urlencode(qparams, True)) @@ -123,8 +133,12 @@ class TwistedHttpClient(HttpClient): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def _create_request(self, method, url, producer=None, headers_dict={}): + def _create_request( + self, method, url, producer=None, headers_dict: Optional[dict] = None + ): """Creates and sends a request to the given url""" + headers_dict = headers_dict or {} + headers_dict["User-Agent"] = ["Synapse Cmd Client"] retries_left = 5 diff --git a/setup.cfg b/setup.cfg index 7329eed213..5fdb51ac73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,8 +18,8 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes) -ignore=W503,W504,E203,E731,E501,B006,B007,B008 +# B007: Subsection of the bugbear suite (TODO: add in remaining fixes) +ignore=W503,W504,E203,E731,E501,B007 [isort] line_length = 88 diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 366c476f80..5203ffe90f 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -49,7 +49,7 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import List +from typing import List, Optional from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.events import EventBase @@ -191,11 +191,11 @@ class _TransactionController: self, service: ApplicationService, events: List[EventBase], - ephemeral: List[JsonDict] = [], + ephemeral: Optional[List[JsonDict]] = None, ): try: txn = await self.store.create_appservice_txn( - service=service, events=events, ephemeral=ephemeral + service=service, events=events, ephemeral=ephemeral or [] ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 3f3997f4e5..7a8d5851c4 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Dict, Optional from ._base import Config @@ -21,8 +21,10 @@ class RateLimitConfig: def __init__( self, config: Dict[str, float], - defaults={"per_second": 0.17, "burst_count": 3.0}, + defaults: Optional[Dict[str, float]] = None, ): + defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} + self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = int(config.get("burst_count", defaults["burst_count"])) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8f6b955d17..f9032e3697 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -330,9 +330,11 @@ class FrozenEvent(EventBase): self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ): + internal_metadata_dict = internal_metadata_dict or {} + event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -386,9 +388,11 @@ class FrozenEventV2(EventBase): self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ): + internal_metadata_dict = internal_metadata_dict or {} + event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]: def make_event_from_dict( event_dict: JsonDict, room_version: RoomVersion = RoomVersions.V1, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ) -> EventBase: """Construct an EventBase from the given event dict""" event_type = _event_type_from_format_version(room_version.event_format) - return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason) + return event_type( + event_dict, room_version, internal_metadata_dict or {}, rejected_reason + ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b662c42621..0f8bf000ac 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -18,6 +18,7 @@ server protocol. """ import logging +from typing import Optional import attr @@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject): "pdus", ] - def __init__(self, transaction_id=None, pdus=[], **kwargs): + def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): """If we include a list of pdus then we decode then as PDU's automatically. """ @@ -107,7 +108,7 @@ class Transaction(JsonEncodedObject): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) + super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 996f9e5deb..9fb7ee335d 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -182,7 +182,7 @@ class ApplicationServicesHandler: self, stream_key: str, new_token: Optional[int], - users: Collection[Union[str, UserID]] = [], + users: Optional[Collection[Union[str, UserID]]] = None, ): """This is called by the notifier in the background when a ephemeral event handled by the homeserver. @@ -215,7 +215,7 @@ class ApplicationServicesHandler: # We only start a new background process if necessary rather than # optimistically (to cut down on overhead). self._notify_interested_services_ephemeral( - services, stream_key, new_token, users + services, stream_key, new_token, users or [] ) @wrap_as_background_process("notify_interested_services_ephemeral") diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5ea8a7b603..67888898ff 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1790,7 +1790,7 @@ class FederationHandler(BaseHandler): room_id: str, user_id: str, membership: str, - content: JsonDict = {}, + content: JsonDict, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6069968f7f..125dae6d25 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -137,7 +137,7 @@ class MessageHandler: self, user_id: str, room_id: str, - state_filter: StateFilter = StateFilter.all(), + state_filter: Optional[StateFilter] = None, at_token: Optional[StreamToken] = None, is_guest: bool = False, ) -> List[dict]: @@ -164,6 +164,8 @@ class MessageHandler: AuthError (403) if the user doesn't have permission to view members of this room. """ + state_filter = state_filter or StateFilter.all() + if at_token: # FIXME this claims to get the state at a stream position, but # get_recent_events_for_room operates by topo ordering. This therefore @@ -874,7 +876,7 @@ class EventCreationHandler: event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = [], + extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, ) -> EventBase: """Processes a new event. @@ -902,6 +904,7 @@ class EventCreationHandler: Raises: ShadowBanError if the requester has been shadow-banned. """ + extra_users = extra_users or [] # we don't apply shadow-banning to membership events here. Invites are blocked # higher up the stack, and we allow shadow-banned users to send join and leave @@ -1071,7 +1074,7 @@ class EventCreationHandler: event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = [], + extra_users: Optional[List[UserID]] = None, ) -> EventBase: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1083,6 +1086,8 @@ class EventCreationHandler: it was de-duplicated (e.g. because we had already persisted an event with the same transaction ID.) """ + extra_users = extra_users or [] + assert self.storage.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9701b76d0f..3b6660c873 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -169,7 +169,7 @@ class RegistrationHandler(BaseHandler): user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: Iterable[str] = [], + bind_emails: Optional[Iterable[str]] = None, by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, @@ -204,6 +204,8 @@ class RegistrationHandler(BaseHandler): Raises: SynapseError if there was a problem registering. """ + bind_emails = bind_emails or [] + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ff11266c67..f8d88ef77b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -548,7 +548,7 @@ class SyncHandler: ) async def get_state_after_event( - self, event: EventBase, state_filter: StateFilter = StateFilter.all() + self, event: EventBase, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the room state after the given event @@ -558,7 +558,7 @@ class SyncHandler: state_filter: The state filter used to fetch state from the database. """ state_ids = await self.state_store.get_state_ids_for_event( - event.event_id, state_filter=state_filter + event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): state_ids = dict(state_ids) @@ -569,7 +569,7 @@ class SyncHandler: self, room_id: str, stream_position: StreamToken, - state_filter: StateFilter = StateFilter.all(), + state_filter: Optional[StateFilter] = None, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -589,7 +589,7 @@ class SyncHandler: if last_events: last_event = last_events[-1] state = await self.get_state_after_event( - last_event, state_filter=state_filter + last_event, state_filter=state_filter or StateFilter.all() ) else: diff --git a/synapse/http/client.py b/synapse/http/client.py index e691ba6d88..f7a07f0466 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -297,7 +297,7 @@ class SimpleHttpClient: def __init__( self, hs: "HomeServer", - treq_args: Dict[str, Any] = {}, + treq_args: Optional[Dict[str, Any]] = None, ip_whitelist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None, use_proxy: bool = False, @@ -317,7 +317,7 @@ class SimpleHttpClient: self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - self._extra_treq_args = treq_args + self._extra_treq_args = treq_args or {} self.user_agent = hs.version_string self.clock = hs.get_clock() diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 16ec850064..ea5ad14cb0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -27,7 +27,7 @@ from twisted.python.failure import Failure from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IPolicyForHTTPS from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint @@ -88,12 +88,14 @@ class ProxyAgent(_AgentBase): self, reactor, proxy_reactor=None, - contextFactory=BrowserLikePolicyForHTTPS(), + contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout=None, bindAddress=None, pool=None, use_proxy=False, ): + contextFactory = contextFactory or BrowserLikePolicyForHTTPS() + _AgentBase.__init__(self, reactor, pool) if proxy_reactor is None: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index b8081f197e..bfe9136fd8 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -486,7 +486,7 @@ def start_active_span_from_request( def start_active_span_from_edu( edu_content, operation_name, - references=[], + references: Optional[list] = None, tags=None, start_time=None, ignore_active_span=False, @@ -501,6 +501,7 @@ def start_active_span_from_edu( For the other args see opentracing.tracer """ + references = references or [] if opentracing is None: return noop_context_manager() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 3ecd46c038..ca1bd4cdc9 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple from twisted.internet import defer @@ -127,7 +127,7 @@ class ModuleApi: return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register(self, localpart, displayname=None, emails=[]): + def register(self, localpart, displayname=None, emails: Optional[List[str]] = None): """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -147,11 +147,13 @@ class ModuleApi: logger.warning( "Using deprecated ModuleApi.register which creates a dummy user device." ) - user_id = yield self.register_user(localpart, displayname, emails) + user_id = yield self.register_user(localpart, displayname, emails or []) _, access_token = yield self.register_device(user_id) return user_id, access_token - def register_user(self, localpart, displayname=None, emails=[]): + def register_user( + self, localpart, displayname=None, emails: Optional[List[str]] = None + ): """Registers a new user with given localpart and optional displayname, emails. Args: @@ -170,7 +172,7 @@ class ModuleApi: self._hs.get_registration_handler().register_user( localpart=localpart, default_display_name=displayname, - bind_emails=emails, + bind_emails=emails or [], ) ) diff --git a/synapse/notifier.py b/synapse/notifier.py index c178db57e3..7ce34380af 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -276,7 +276,7 @@ class Notifier: event: EventBase, event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = [], + extra_users: Optional[Collection[UserID]] = None, ): """Unwraps event and calls `on_new_room_event_args`.""" self.on_new_room_event_args( @@ -286,7 +286,7 @@ class Notifier: state_key=event.get("state_key"), membership=event.content.get("membership"), max_room_stream_token=max_room_stream_token, - extra_users=extra_users, + extra_users=extra_users or [], ) def on_new_room_event_args( @@ -297,7 +297,7 @@ class Notifier: membership: Optional[str], event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = [], + extra_users: Optional[Collection[UserID]] = None, ): """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -313,7 +313,7 @@ class Notifier: self.pending_new_room_events.append( _PendingRoomEventEntry( event_pos=event_pos, - extra_users=extra_users, + extra_users=extra_users or [], room_id=room_id, type=event_type, state_key=state_key, @@ -382,14 +382,14 @@ class Notifier: self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = [], + users: Optional[Collection[Union[str, UserID]]] = None, ): try: stream_token = None if isinstance(new_token, int): stream_token = new_token self.appservice_handler.notify_interested_services_ephemeral( - stream_key, stream_token, users + stream_key, stream_token, users or [] ) except Exception: logger.exception("Error notifying application services of event") @@ -404,13 +404,16 @@ class Notifier: self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = [], - rooms: Collection[str] = [], + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[Collection[str]] = None, ): """Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ + users = users or [] + rooms = rooms or [] + with Measure(self.clock, "on_new_event"): user_streams = set() diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b302cd5786..fa15b0ce5b 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -900,7 +900,7 @@ class DatabasePool: table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", lock: bool = True, ) -> Optional[bool]: @@ -927,6 +927,8 @@ class DatabasePool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + attempts = 0 while True: try: @@ -964,7 +966,7 @@ class DatabasePool: table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, ) -> Optional[bool]: """ @@ -982,6 +984,8 @@ class DatabasePool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values @@ -1003,7 +1007,7 @@ class DatabasePool: table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, ) -> bool: """ @@ -1017,6 +1021,8 @@ class DatabasePool: Returns True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + # We need to lock the table :(, unless we're *really* careful if lock: self.engine.lock_table(txn, table) @@ -1077,7 +1083,7 @@ class DatabasePool: table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Optional[Dict[str, Any]] = None, ) -> None: """ Use the native UPSERT functionality in recent PostgreSQL versions. @@ -1090,7 +1096,7 @@ class DatabasePool: """ allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) - allvalues.update(insertion_values) + allvalues.update(insertion_values or {}) if not values: latter = "NOTHING" @@ -1513,7 +1519,7 @@ class DatabasePool: column: str, iterable: Iterable[Any], retcols: Iterable[str], - keyvalues: Dict[str, Any] = {}, + keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, ) -> List[Any]: @@ -1531,6 +1537,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics batch_size: the number of rows for each select query """ + keyvalues = keyvalues or {} + results = [] # type: List[Dict[str, Any]] if not iterable: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 98dac19a95..ad17123915 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -320,8 +320,8 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - state_delta_for_room: Dict[str, DeltaState] = {}, - new_forward_extremeties: Dict[str, List[str]] = {}, + state_delta_for_room: Optional[Dict[str, DeltaState]] = None, + new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): """Insert some number of room events into the necessary database tables. @@ -342,6 +342,9 @@ class PersistEventsStore: extremities. """ + state_delta_for_room = state_delta_for_room or {} + new_forward_extremeties = new_forward_extremeties or {} + all_events_and_contexts = events_and_contexts min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 8f462dfc31..bd7826f4e9 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore): user_id: str, membership: str, is_admin: bool = False, - content: JsonDict = {}, + content: Optional[JsonDict] = None, local_attestation: Optional[dict] = None, remote_attestation: Optional[dict] = None, is_publicised: bool = False, @@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore): is_publicised: Whether this should be publicised. """ + content = content or {} + def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a7f371732f..93431efe00 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # FIXME: how should this be cached? async def get_filtered_current_state_ids( - self, room_id: str, state_filter: StateFilter = StateFilter.all() + self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result @@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Map from type/state_key to event ID. """ - where_clause, where_args = state_filter.make_sql_filter_clause() + where_clause, where_args = ( + state_filter or StateFilter.all() + ).make_sql_filter_clause() if not where_clause: # We delegate to the cached version diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 1fd333b707..75c09b3687 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() + self, txn, groups, state_filter: Optional[StateFilter] = None ): + state_filter = state_filter or StateFilter.all() + results = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 97ec65f757..dfcf89d91c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore @@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types async def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: Dict of state group to state map. """ + state_filter = state_filter or StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 2e277a21c4..c1c147c62a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -449,7 +449,7 @@ class StateGroupStorage: return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -465,7 +465,7 @@ class StateGroupStorage: groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter + groups, state_filter or StateFilter.all() ) state_event_map = await self.stores.main.get_events( @@ -485,7 +485,7 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -502,7 +502,7 @@ class StateGroupStorage: groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter + groups, state_filter or StateFilter.all() ) event_to_state = { @@ -513,7 +513,7 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_for_event( - self, event_id: str, state_filter: StateFilter = StateFilter.all() + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -525,11 +525,13 @@ class StateGroupStorage: Returns: A dict from (type, state_key) -> state_event """ - state_map = await self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) return state_map[event_id] async def get_state_ids_for_event( - self, event_id: str, state_filter: StateFilter = StateFilter.all() + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -541,11 +543,13 @@ class StateGroupStorage: Returns: A dict from (type, state_key) -> state_event """ - state_map = await self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) return state_map[event_id] def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -558,7 +562,9 @@ class StateGroupStorage: Returns: Dict of state group to state map. """ - return self.stores.state._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) async def store_state_group( self, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d4643c4fdf..32d6cc16b9 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -17,7 +17,7 @@ import logging import threading from collections import OrderedDict from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import attr @@ -91,7 +91,14 @@ class StreamIdGenerator: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[], step=1): + def __init__( + self, + db_conn, + table, + column, + extra_tables: Iterable[Tuple[str, str]] = (), + step=1, + ): assert step != 0 self._lock = threading.Lock() self._step = step diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 60bb6ff642..20c8e2d9f5 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -57,12 +57,14 @@ def enumerate_leaves(node, depth): class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks=set()): + def __init__( + self, prev_node, next_node, key, value, callbacks: Optional[set] = None + ): self.prev_node = prev_node self.next_node = next_node self.key = key self.value = value - self.callbacks = callbacks + self.callbacks = callbacks or set() class LruCache(Generic[KT, VT]): @@ -176,10 +178,10 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) - def add_node(key, value, callbacks=set()): + def add_node(key, value, callbacks: Optional[set] = None): prev_node = list_root next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value, callbacks) + node = _Node(prev_node, next_node, key, value, callbacks or set()) prev_node.next_node = node next_node.prev_node = node cache[key] = node @@ -237,7 +239,7 @@ class LruCache(Generic[KT, VT]): def cache_get( key: KT, default: Optional[T] = None, - callbacks: Iterable[Callable[[], None]] = [], + callbacks: Iterable[Callable[[], None]] = (), update_metrics: bool = True, ): node = cache.get(key, None) @@ -253,7 +255,7 @@ class LruCache(Generic[KT, VT]): return default @synchronized - def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4c56253da5..73e12ea6c3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from mock import Mock @@ -180,7 +181,11 @@ class MatrixFederationAgentTests(unittest.TestCase): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={} + self, + client_factory, + expected_sni, + content, + response_headers: Optional[dict] = None, ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -202,10 +207,12 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual( request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] ) - self._send_well_known_response(request, content, headers=response_headers) + self._send_well_known_response(request, content, headers=response_headers or {}) return well_known_server - def _send_well_known_response(self, request, content, headers={}): + def _send_well_known_response( + self, request, content, headers: Optional[dict] = None + ): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -213,7 +220,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(request.path, b"/.well-known/matrix/server") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response - for k, v in headers.items(): + for k, v in (headers or {}).items(): request.setHeader(k, v) request.write(content) request.finish() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 1d4a592862..aff19d9fb3 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return resource def make_worker_hs( - self, worker_app: str, extra_config: dict = {}, **kwargs + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. @@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config = self._get_worker_hs_config() config["worker_app"] = worker_app - config.update(extra_config) + config.update(extra_config or {}) worker_hs = self.setup_test_homeserver( homeserver_to_use=GenericWorkerServer, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 0ceb0f935c..333374b183 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Iterable, Optional from canonicaljson import encode_canonical_json @@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): room_id=ROOM_ID, type="m.room.message", key=None, - internal={}, + internal: Optional[dict] = None, depth=None, - prev_events=[], - auth_events=[], - prev_state=[], + prev_events: Optional[list] = None, + auth_events: Optional[list] = None, + prev_state: Optional[list] = None, redacts=None, - push_actions=[], + push_actions: Iterable = frozenset(), **content ): + prev_events = prev_events or [] + auth_events = auth_events or [] + prev_state = prev_state or [] if depth is None: depth = self.event_id @@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if redacts is not None: event_dict["redacts"] = redacts - event = make_event_from_dict(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {}) self.event_id += 1 state_handler = self.hs.get_state_handler() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ed65f645fc..715414a310 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,6 +19,7 @@ """Tests REST events for /rooms paths.""" import json +from typing import Iterable from urllib import parse as urlparse from mock import Mock @@ -207,7 +208,9 @@ class RoomPermissionsTestCase(RoomBase): ) self.assertEquals(403, channel.code, msg=channel.result["body"]) - def _test_get_membership(self, room=None, members=[], expect_code=None): + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 946740aa5d..8a4dddae2b 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -132,7 +132,7 @@ class RestHelper: src: str, targ: str, membership: str, - extra_data: dict = {}, + extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, ) -> None: @@ -156,7 +156,7 @@ class RestHelper: path = path + "?access_token=%s" % tok data = {"membership": membership} - data.update(extra_data) + data.update(extra_data or {}) channel = make_request( self.hs.get_reactor(), @@ -187,7 +187,13 @@ class RestHelper: ) def send_event( - self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + self, + room_id, + type, + content: Optional[dict] = None, + txn_id=None, + tok=None, + expect_code=200, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -201,7 +207,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content).encode("utf8"), + json.dumps(content or {}).encode("utf8"), ) assert ( diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index e7bb5583fc..21ee436b91 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -16,6 +16,7 @@ import itertools import json import urllib +from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): relation_type, event_type, key=None, - content={}, + content: Optional[dict] = None, access_token=None, parent_id=None, ): @@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), - json.dumps(content).encode("utf-8"), + json.dumps(content or {}).encode("utf-8"), access_token=access_token, ) return channel diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index aad6bc907e..6c389fe9ac 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + from synapse.storage.database import DatabasePool from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) @@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], positive=False, ) @@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ("foobar2", "instance_name", "stream_id"), ], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 2622207639..2d2f58903c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from canonicaljson import json @@ -47,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.depth = 1 def inject_room_member( - self, room, user, membership, replaces_state=None, extra_content={} + self, + room, + user, + membership, + replaces_state=None, + extra_content: Optional[dict] = None, ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { diff --git a/tests/test_state.py b/tests/test_state.py index 6227a3ba95..1d2019699d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional from mock import Mock @@ -37,7 +38,7 @@ def create_event( state_key=None, depth=2, event_id=None, - prev_events=[], + prev_events: Optional[List[str]] = None, **kwargs ): global _next_event_id @@ -58,7 +59,7 @@ def create_event( "sender": "@user_id:example.com", "room_id": "!room_id:example.com", "depth": depth, - "prev_events": prev_events, + "prev_events": prev_events or [], } if state_key is not None: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 510b630114..1b4dd47a82 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from mock import Mock @@ -147,9 +148,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): return event @defer.inlineCallbacks - def inject_room_member(self, user_id, membership="join", extra_content={}): + def inject_room_member( + self, user_id, membership="join", extra_content: Optional[dict] = None + ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 4d1aee91d5..3fed55090a 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.config.homeserver import HomeServerConfig from synapse.util.ratelimitutils import FederationRateLimiter @@ -89,9 +91,9 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings={}): +def build_rc_config(settings: Optional[dict] = None): config_dict = default_config("test") - config_dict.update(settings) + config_dict.update(settings or {}) config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") return config.rc_federation -- cgit 1.5.1 From 1d5f0e3529ec5acd889037c8ebcca2820ad003d5 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Tue, 13 Apr 2021 10:41:34 +0100 Subject: Bump black configuration to target py36 (#9781) Signed-off-by: Dan Callahan --- changelog.d/9781.misc | 1 + pyproject.toml | 2 +- synapse/config/tls.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/http/site.py | 2 +- synapse/storage/database.py | 8 ++++---- tests/replication/slave/storage/test_events.py | 2 +- tests/test_state.py | 2 +- tests/test_utils/event_injection.py | 6 +++--- tests/utils.py | 2 +- 11 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/9781.misc (limited to 'synapse/config') diff --git a/changelog.d/9781.misc b/changelog.d/9781.misc new file mode 100644 index 0000000000..d1c73fc741 --- /dev/null +++ b/changelog.d/9781.misc @@ -0,0 +1 @@ +Update Black configuration to target Python 3.6. diff --git a/pyproject.toml b/pyproject.toml index cd880d4e39..8bca1fa4ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py35'] +target-version = ['py36'] exclude = ''' ( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ad37b93c02..85b5db4c40 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -270,7 +270,7 @@ class TlsConfig(Config): tls_certificate_path, tls_private_key_path, acme_domain, - **kwargs + **kwargs, ): """If the acme_domain is specified acme will be enabled. If the TLS paths are not specified the default will be certs in the diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c817f2952d..0047907cd9 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1071,7 +1071,7 @@ class PresenceEventSource: room_ids=None, include_offline=True, explicit_room_id=None, - **kwargs + **kwargs, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 5f01ebd3d4..ab47dec8f2 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -272,7 +272,7 @@ class MatrixFederationHttpClient: self, request: MatrixFederationRequest, try_trailing_slash_on_400: bool = False, - **send_request_args + **send_request_args, ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a diff --git a/synapse/http/site.py b/synapse/http/site.py index c0c873ce32..32b5e19c09 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -497,7 +497,7 @@ class SynapseSite(Site): resource, server_version_string, *args, - **kwargs + **kwargs, ): Site.__init__(self, resource, *args, **kwargs) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index fa15b0ce5b..77ef29ec71 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -488,7 +488,7 @@ class DatabasePool: exception_callbacks: List[_CallbackListEntry], func: "Callable[..., R]", *args: Any, - **kwargs: Any + **kwargs: Any, ) -> R: """Start a new database transaction with the given connection. @@ -622,7 +622,7 @@ class DatabasePool: func: "Callable[..., R]", *args: Any, db_autocommit: bool = False, - **kwargs: Any + **kwargs: Any, ) -> R: """Starts a transaction on the database and runs a given function @@ -682,7 +682,7 @@ class DatabasePool: func: "Callable[..., R]", *args: Any, db_autocommit: bool = False, - **kwargs: Any + **kwargs: Any, ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -775,7 +775,7 @@ class DatabasePool: desc: str, decoder: Optional[Callable[[Cursor], R]], query: str, - *args: Any + *args: Any, ) -> R: """Runs a single query for a result set. diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 333374b183..db80a0bdbd 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -340,7 +340,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): prev_state: Optional[list] = None, redacts=None, push_actions: Iterable = frozenset(), - **content + **content, ): prev_events = prev_events or [] auth_events = auth_events or [] diff --git a/tests/test_state.py b/tests/test_state.py index 83383d8872..0d626f49f6 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -38,7 +38,7 @@ def create_event( depth=2, event_id=None, prev_events: Optional[List[str]] = None, - **kwargs + **kwargs, ): global _next_event_id diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c3c4a93e1f..3dfbf8f8a9 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -33,7 +33,7 @@ async def inject_member_event( membership: str, target: Optional[str] = None, extra_content: Optional[dict] = None, - **kwargs + **kwargs, ) -> EventBase: """Inject a membership event into a room.""" if target is None: @@ -58,7 +58,7 @@ async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs + **kwargs, ) -> EventBase: """Inject a generic event into a room @@ -83,7 +83,7 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs + **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) diff --git a/tests/utils.py b/tests/utils.py index 2e34fad11c..c78d3e5ba7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -190,7 +190,7 @@ def setup_test_homeserver( config=None, reactor=None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs + **kwargs, ): """ Setup a homeserver suitable for running tests against. Keyword arguments -- cgit 1.5.1