From b449af0379db871945f32a572883d47b5c9018a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 17 Mar 2021 07:14:39 -0400 Subject: Add type hints to the room member handler. (#9631) --- synapse/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/server.py') diff --git a/synapse/server.py b/synapse/server.py index 48ac87a124..dd4ee7dd3c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -96,7 +96,7 @@ from synapse.handlers.room import ( RoomShutdownHandler, ) from synapse.handlers.room_list import RoomListHandler -from synapse.handlers.room_member import RoomMemberMasterHandler +from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler @@ -630,7 +630,7 @@ class HomeServer(metaclass=abc.ABCMeta): return ThirdPartyEventRules(self) @cache_in_self - def get_room_member_handler(self): + def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) -- cgit 1.5.1 From cc324d53fe531d002aca28a9d8e5b85768cdef23 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 17 Mar 2021 11:30:21 -0400 Subject: Fix up types for the typing handler. (#9638) By splitting this to two separate methods the callers know what methods they can expect on the handler. --- changelog.d/9638.misc | 1 + synapse/replication/tcp/streams/_base.py | 17 ++++++++++------- synapse/rest/client/v1/room.py | 15 +++++++++------ synapse/server.py | 11 ++++++++++- 4 files changed, 30 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9638.misc (limited to 'synapse/server.py') diff --git a/changelog.d/9638.misc b/changelog.d/9638.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9638.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f45e7a8c89..7e8e64d61c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,7 +33,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -299,20 +299,23 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - def __init__(self, hs): - typing_handler = hs.get_typing_handler() - + def __init__(self, hs: "HomeServer"): writer_instance = hs.config.worker.writers.typing if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler - update_function = typing_handler.get_all_typing_updates + typing_writer_handler = hs.get_typing_writer_handler() + update_function = ( + typing_writer_handler.get_all_typing_updates + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) + current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(typing_handler.get_current_token), + current_token_without_instance(current_token_function), update_function, ) @@ -509,7 +512,7 @@ class AccountDataStream(Stream): NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5884daea6d..e7a8207eb1 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -49,7 +49,7 @@ from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -846,10 +846,10 @@ class RoomTypingRestServlet(RestServlet): "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() # If we're not on the typing writer instance we should scream if we get @@ -874,16 +874,19 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + try: if content["typing"]: - await self.typing_handler.started_typing( + await typing_handler.started_typing( target_user=target_user, requester=requester, room_id=room_id, timeout=timeout, ) else: - await self.typing_handler.stopped_typing( + await typing_handler.stopped_typing( target_user=target_user, requester=requester, room_id=room_id ) except ShadowBanError: @@ -901,7 +904,7 @@ class RoomAliasListServlet(RestServlet): ), ] - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() diff --git a/synapse/server.py b/synapse/server.py index dd4ee7dd3c..d11d08c573 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -417,9 +417,18 @@ class HomeServer(metaclass=abc.ABCMeta): return PresenceHandler(self) @cache_in_self - def get_typing_handler(self): + def get_typing_writer_handler(self) -> TypingWriterHandler: if self.config.worker.writers.typing == self.get_instance_name(): return TypingWriterHandler(self) + else: + raise Exception("Workers cannot write typing") + + @cache_in_self + def get_typing_handler(self) -> FollowerTypingHandler: + if self.config.worker.writers.typing == self.get_instance_name(): + # Use get_typing_writer_handler to ensure that we use the same + # cached version. + return self.get_typing_writer_handler() else: return FollowerTypingHandler(self) -- cgit 1.5.1 From 004234f03abad58b2de28abff406391a6d182e48 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 18 Mar 2021 18:24:16 +0000 Subject: Initial spaces summary API (#9643) This is very bare-bones for now: federation will come soon, while pagination is descoped for now but will come later. --- changelog.d/9643.feature | 1 + synapse/api/constants.py | 6 ++ synapse/config/experimental.py | 3 + synapse/handlers/space_summary.py | 199 ++++++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/room.py | 66 ++++++++++++- synapse/server.py | 5 + 6 files changed, 277 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9643.feature create mode 100644 synapse/handlers/space_summary.py (limited to 'synapse/server.py') diff --git a/changelog.d/9643.feature b/changelog.d/9643.feature new file mode 100644 index 0000000000..2f7ccedcfb --- /dev/null +++ b/changelog.d/9643.feature @@ -0,0 +1 @@ +Add initial experimental support for a "space summary" API. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 691f8f9adf..ed050c8104 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -100,6 +100,9 @@ class EventTypes: Dummy = "org.matrix.dummy_event" + MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child" + MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent" + class EduTypes: Presence = "m.presence" @@ -160,6 +163,9 @@ class EventContentFields: # cf https://github.com/matrix-org/matrix-doc/pull/2228 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + # cf https://github.com/matrix-org/matrix-doc/pull/1772 + MSC1772_ROOM_TYPE = "org.matrix.msc1772.type" + class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b1c1c51e4d..5554bcea8a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -27,3 +27,6 @@ class ExperimentalConfig(Config): # MSC2858 (multiple SSO identity providers) self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool + + # Spaces (MSC1772, MSC2946, etc) + self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py new file mode 100644 index 0000000000..513dc0c71a --- /dev/null +++ b/synapse/handlers/space_summary.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import deque +from typing import TYPE_CHECKING, Iterable, List, Optional, Set + +from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility +from synapse.api.errors import AuthError +from synapse.events import EventBase +from synapse.events.utils import format_event_for_client_v2 +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# number of rooms to return. We'll stop once we hit this limit. +# TODO: allow clients to reduce this with a request param. +MAX_ROOMS = 50 + +# max number of events to return per room. +MAX_ROOMS_PER_SPACE = 50 + + +class SpaceSummaryHandler: + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._auth = hs.get_auth() + self._room_list_handler = hs.get_room_list_handler() + self._state_handler = hs.get_state_handler() + self._store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() + + async def get_space_summary( + self, + requester: str, + room_id: str, + suggested_only: bool = False, + max_rooms_per_space: Optional[int] = None, + ) -> JsonDict: + """ + Implementation of the space summary API + + Args: + requester: user id of the user making this request + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. This does not apply to the root room (ie, room_id), and + is overridden by ROOMS_PER_SPACE_LIMIT. + + Returns: + summary dict to return + """ + # first of all, check that the user is in the room in question (or it's + # world-readable) + await self._auth.check_user_in_room_or_world_readable(room_id, requester) + + # the queue of rooms to process + room_queue = deque((room_id,)) + + processed_rooms = set() # type: Set[str] + + rooms_result = [] # type: List[JsonDict] + events_result = [] # type: List[JsonDict] + + now = self._clock.time_msec() + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + logger.debug("Processing room %s", room_id) + processed_rooms.add(room_id) + + try: + await self._auth.check_user_in_room_or_world_readable( + room_id, requester + ) + except AuthError: + logger.info( + "user %s cannot view room %s, omitting from summary", + requester, + room_id, + ) + continue + + room_entry = await self._build_room_entry(room_id) + rooms_result.append(room_entry) + + # look for child rooms/spaces. + child_events = await self._get_child_events(room_id) + + if suggested_only: + # we only care about suggested children + child_events = filter(_is_suggested_child_event, child_events) + + # The client-specified max_rooms_per_space limit doesn't apply to the + # room_id specified in the request, so we ignore it if this is the + # first room we are processing. Otherwise, apply any client-specified + # limit, capping to our built-in limit. + if max_rooms_per_space is not None and len(processed_rooms) > 1: + max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space) + else: + max_rooms = MAX_ROOMS_PER_SPACE + + for edge_event in itertools.islice(child_events, max_rooms): + edge_room_id = edge_event.state_key + + events_result.append( + await self._event_serializer.serialize_event( + edge_event, + time_now=now, + event_format=format_event_for_client_v2, + ) + ) + + # if we haven't yet visited the target of this link, add it to the queue + if edge_room_id not in processed_rooms: + room_queue.append(edge_room_id) + + return {"rooms": rooms_result, "events": events_result} + + async def _build_room_entry(self, room_id: str) -> JsonDict: + """Generate en entry suitable for the 'rooms' list in the summary response""" + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # check_user_in_room_or_world_readable on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + # TODO: update once MSC1772 lands + room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "room_type": room_type, + } + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + # TODO: update once MSC1772 lands + if key[0] == EventTypes.MSC1772_SPACE_CHILD + ] + ) + + # filter out any events without a "via" (which implies it has been redacted) + return (e for e in events if e.content.get("via")) + + +def _is_suggested_child_event(edge_event: EventBase) -> bool: + suggested = edge_event.content.get("suggested") + if isinstance(suggested, bool) and suggested: + return True + logger.debug("Ignorning not-suggested child %s", edge_event.state_key) + return False diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e7a8207eb1..f78d792264 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -18,9 +18,11 @@ import logging import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from urllib import parse as urlparse +from twisted.web.server import Request + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -35,6 +37,7 @@ from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_integer, parse_json_object_from_request, parse_string, @@ -44,7 +47,14 @@ from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types import ( + JsonDict, + RoomAlias, + RoomID, + StreamToken, + ThirdPartyInstanceID, + UserID, +) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -987,7 +997,54 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server, is_worker=False): +class RoomSpaceSummaryRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P[^/]*)/spaces$" + ), + ) + + def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._space_summary_handler = hs.get_space_summary_handler() + + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + ) + + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_rooms_per_space, + ) + + +def register_servlets(hs: "synapse.server.HomeServer", http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1001,6 +1058,9 @@ def register_servlets(hs, http_server, is_worker=False): RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + if hs.config.experimental.spaces_enabled: + RoomSpaceSummaryRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. if not is_worker: RoomCreateRestServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index d11d08c573..98822d8e2f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -100,6 +100,7 @@ from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHand from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -732,6 +733,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_space_summary_handler(self) -> SpaceSummaryHandler: + return SpaceSummaryHandler(self) + @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) -- cgit 1.5.1 From 7e8dc9934e29ebd5f30f42b4b6e6b4491569373a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 24 Mar 2021 06:48:46 -0400 Subject: Add a type hints for service notices to the HomeServer object. (#9675) --- changelog.d/9675.misc | 1 + synapse/handlers/sync.py | 6 ++++-- synapse/rest/client/v2_alpha/sync.py | 11 ++++++++--- synapse/server.py | 4 ++-- synapse/server_notices/consent_server_notices.py | 18 ++++++++++-------- .../server_notices/resource_limits_server_notices.py | 11 +++++------ synapse/server_notices/server_notices_manager.py | 2 +- synapse/server_notices/server_notices_sender.py | 18 ++++++++++-------- synapse/server_notices/worker_server_notices_sender.py | 11 ++++++----- synapse/storage/databases/main/deviceinbox.py | 6 +++--- synapse/storage/databases/main/monthly_active_users.py | 4 ++-- 11 files changed, 52 insertions(+), 40 deletions(-) create mode 100644 changelog.d/9675.misc (limited to 'synapse/server.py') diff --git a/changelog.d/9675.misc b/changelog.d/9675.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9675.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7b723ead58..ee607e6e65 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -80,7 +80,7 @@ class SyncConfig: filter_collection = attr.ib(type=FilterCollection) is_guest = attr.ib(type=bool) request_key = attr.ib(type=Tuple[Any, ...]) - device_id = attr.ib(type=str) + device_id = attr.ib(type=Optional[str]) @attr.s(slots=True, frozen=True) @@ -723,7 +723,9 @@ class SyncHandler: return summary - def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache: + def get_lazy_loaded_members_cache( + self, cache_key: Tuple[str, Optional[str]] + ) -> LruCache: cache = self.lazy_loaded_members_cache.get(cache_key) if cache is None: logger.debug("creating LruCache for %r", cache_key) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8e52e4cca4..a0db0a054b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -15,6 +15,7 @@ import itertools import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError @@ -26,11 +27,15 @@ from synapse.events.utils import ( from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string -from synapse.types import StreamToken +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -73,7 +78,7 @@ class SyncRestServlet(RestServlet): PATTERNS = client_patterns("/sync$") ALLOWED_PRESENCE = {"online", "offline", "unavailable"} - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -85,7 +90,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. diff --git a/synapse/server.py b/synapse/server.py index 98822d8e2f..5e787e2281 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -650,13 +650,13 @@ class HomeServer(metaclass=abc.ABCMeta): return FederationHandlerRegistry(self) @cache_in_self - def get_server_notices_manager(self): + def get_server_notices_manager(self) -> ServerNoticesManager: if self.config.worker_app: raise Exception("Workers cannot send server notices") return ServerNoticesManager(self) @cache_in_self - def get_server_notices_sender(self): + def get_server_notices_sender(self) -> WorkerServerNoticesSender: if self.config.worker_app: return WorkerServerNoticesSender(self) return ServerNoticesSender(self) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 9137c4edb1..a9349bf9a1 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Set from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError from synapse.types import get_localpart_from_id +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -28,16 +31,11 @@ class ConsentServerNotices: privacy policy consent, and sends one if we do. """ - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() - self._users_in_progress = set() + self._users_in_progress = set() # type: Set[str] self._current_consent_version = hs.config.user_consent_version self._server_notice_content = hs.config.user_consent_server_notice_content @@ -73,6 +71,10 @@ class ConsentServerNotices: try: u = await self._store.get_user_by_id(user_id) + # The user doesn't exist. + if u is None: + return + if u["is_guest"] and not self._send_to_guests: # don't send to guests return diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 6652451346..a18a2e76c9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Tuple from synapse.api.constants import ( EventTypes, @@ -24,6 +24,9 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, ResourceLimitError, SynapseError from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,11 +35,7 @@ class ResourceLimitsServerNotices: ensures that the client is kept up to date. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastore() self._auth = hs.get_auth() diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index c46b2f047d..144e1da78e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -58,7 +58,7 @@ class ServerNoticesManager: user_id: str, event_content: dict, type: str = EventTypes.Message, - state_key: Optional[bool] = None, + state_key: Optional[str] = None, ) -> EventBase: """Send a notice to the given user diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 6870b67ca0..965c645889 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,25 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Union +from typing import TYPE_CHECKING, Iterable, Union from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.server_notices.worker_server_notices_sender import ( + WorkerServerNoticesSender, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer -class ServerNoticesSender: +class ServerNoticesSender(WorkerServerNoticesSender): """A centralised place which sends server notices automatically when Certain Events take place """ - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self._server_notices = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index 9273e61895..c76bd57460 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from synapse.server import HomeServer class WorkerServerNoticesSender: """Stub impl of ServerNoticesSender which does nothing""" - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): + pass async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 45ca6620a8..691080ce74 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.replication.tcp.streams import ToDeviceStream @@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): async def get_new_messages_for_device( self, user_id: str, - device_id: str, + device_id: Optional[str], last_stream_id: int, current_stream_id: int, limit: int = 100, @@ -163,7 +163,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def delete_messages_for_device( - self, user_id: str, device_id: str, up_to_stream_id: int + self, user_id: str, device_id: Optional[str], up_to_stream_id: int ) -> int: """ Args: diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index d788dc0fc6..757da3d55d 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List +from typing import Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -109,7 +109,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - async def user_last_seen_monthly_active(self, user_id: str) -> int: + async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: """ Checks if a given user is part of the monthly active user group -- cgit 1.5.1 From da75d2ea1f2784791399dbeba16be401e2bb37d2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 29 Mar 2021 11:43:20 -0400 Subject: Add type hints for the federation sender. (#9681) Includes an abstract base class which both the FederationSender and the FederationRemoteSendQueue must implement. --- changelog.d/9681.misc | 1 + synapse/app/generic_worker.py | 7 -- synapse/federation/send_queue.py | 88 ++++++++++++------- synapse/federation/sender/__init__.py | 116 +++++++++++++++++++++++--- synapse/replication/tcp/commands.py | 6 +- synapse/replication/tcp/streams/federation.py | 14 +++- synapse/server.py | 4 +- 7 files changed, 177 insertions(+), 59 deletions(-) create mode 100644 changelog.d/9681.misc (limited to 'synapse/server.py') diff --git a/changelog.d/9681.misc b/changelog.d/9681.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9681.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 6139881dbb..3df2aa5c2b 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -787,13 +787,6 @@ class FederationSenderHandler: self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - def on_start(self): - # There may be some events that are persisted but haven't been sent, - # so send them now. - self.federation_sender.notify_new_events( - self.store.get_room_max_stream_ordering() - ) - def wake_destination(self, server: str): self.federation_sender.wake_destination(server) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 3e993b428b..0c18c49abb 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,25 +31,39 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from typing import Dict, List, Tuple, Type +from typing import ( + TYPE_CHECKING, + Dict, + Hashable, + Iterable, + List, + Optional, + Sized, + Tuple, + Type, +) from sortedcontainers import SortedDict -from twisted.internet import defer - from synapse.api.presence import UserPresenceState +from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.metrics import LaterGauge +from synapse.replication.tcp.streams.federation import FederationStream +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure from .units import Edu +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class FederationRemoteSendQueue: +class FederationRemoteSendQueue(AbstractFederationSender): """A drop in replacement for FederationSender""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -58,7 +72,7 @@ class FederationRemoteSendQueue: # We may have multiple federation sender instances, so we need to track # their positions separately. self._sender_instances = hs.config.worker.federation_shard_config.instances - self._sender_positions = {} + self._sender_positions = {} # type: Dict[str, int] # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -71,7 +85,7 @@ class FederationRemoteSendQueue: # Stream position -> (user_id, destinations) self.presence_destinations = ( SortedDict() - ) # type: SortedDict[int, Tuple[str, List[str]]] + ) # type: SortedDict[int, Tuple[str, Iterable[str]]] # (destination, key) -> EDU self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] @@ -94,7 +108,7 @@ class FederationRemoteSendQueue: # we make a new function, so we need to make a new function so the inner # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. - def register(name, queue): + def register(name: str, queue: Sized) -> None: LaterGauge( "synapse_federation_send_queue_%s_size" % (queue_name,), "", @@ -115,13 +129,13 @@ class FederationRemoteSendQueue: self.clock.looping_call(self._clear_queue, 30 * 1000) - def _next_pos(self): + def _next_pos(self) -> int: pos = self.pos self.pos += 1 self.pos_time[self.clock.time_msec()] = pos return pos - def _clear_queue(self): + def _clear_queue(self) -> None: """Clear the queues for anything older than N minutes""" FIVE_MINUTES_AGO = 5 * 60 * 1000 @@ -138,7 +152,7 @@ class FederationRemoteSendQueue: self._clear_queue_before_pos(position_to_delete) - def _clear_queue_before_pos(self, position_to_delete): + def _clear_queue_before_pos(self, position_to_delete: int) -> None: """Clear all the queues from before a given position""" with Measure(self.clock, "send_queue._clear"): # Delete things out of presence maps @@ -188,13 +202,18 @@ class FederationRemoteSendQueue: for key in keys[:i]: del self.edus[key] - def notify_new_events(self, max_token): + def notify_new_events(self, max_token: RoomStreamToken) -> None: """As per FederationSender""" - # We don't need to replicate this as it gets sent down a different - # stream. - pass + # This should never get called. + raise NotImplementedError() - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: JsonDict, + key: Optional[Hashable] = None, + ) -> None: """As per FederationSender""" if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -218,38 +237,39 @@ class FederationRemoteSendQueue: self.notifier.on_new_replication_data() - def send_read_receipt(self, receipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """As per FederationSender Args: - receipt (synapse.types.ReadReceipt): + receipt: """ # nothing to do here: the replication listener will handle it. - return defer.succeed(None) - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]) -> None: """As per FederationSender Args: - states (list(UserPresenceState)) + states """ pos = self._next_pos() # We only want to send presence for our own users, so lets always just # filter here just in case. - local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states)) + local_states = [s for s in states if self.is_mine_id(s.user_id)] self.presence_map.update({state.user_id: state for state in local_states}) self.presence_changed[pos] = [state.user_id for state in local_states] self.notifier.on_new_replication_data() - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: Iterable[UserPresenceState], destinations: Iterable[str] + ) -> None: """As per FederationSender Args: - states (list[UserPresenceState]) - destinations (list[str]) + states + destinations """ for state in states: pos = self._next_pos() @@ -258,15 +278,18 @@ class FederationRemoteSendQueue: self.notifier.on_new_replication_data() - def send_device_messages(self, destination): + def send_device_messages(self, destination: str) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. - def get_current_token(self): + def wake_destination(self, server: str) -> None: + pass + + def get_current_token(self) -> int: return self.pos - 1 - def federation_ack(self, instance_name, token): + def federation_ack(self, instance_name: str, token: int) -> None: if self._sender_instances: # If we have configured multiple federation sender instances we need # to track their positions separately, and only clear the queue up @@ -504,13 +527,16 @@ ParsedFederationStreamData = namedtuple( ) -def process_rows_for_federation(transaction_queue, rows): +def process_rows_for_federation( + transaction_queue: FederationSender, + rows: List[FederationStream.FederationStreamRow], +) -> None: """Parse a list of rows from the federation stream and put them in the transaction queue ready for sending to the relevant homeservers. Args: - transaction_queue (FederationSender) - rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) + transaction_queue + rows """ # The federation stream contains a bunch of different types of diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 24ebc4b803..8babb1ebbe 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging -from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter from twisted.internet import defer -import synapse import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase @@ -40,9 +40,12 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) sent_pdus_destination_dist_count = Counter( @@ -65,8 +68,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15 CATCH_UP_STARTUP_INTERVAL_SEC = 5 -class FederationSender: - def __init__(self, hs: "synapse.server.HomeServer"): +class AbstractFederationSender(metaclass=abc.ABCMeta): + @abc.abstractmethod + def notify_new_events(self, max_token: RoomStreamToken) -> None: + """This gets called when we have some new events we might want to + send out to other servers. + """ + raise NotImplementedError() + + @abc.abstractmethod + async def send_read_receipt(self, receipt: ReadReceipt) -> None: + """Send a RR to any other servers in the room + + Args: + receipt: receipt to be sent + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_presence(self, states: List[UserPresenceState]) -> None: + """Send the new presence states to the appropriate destinations. + + This actually queues up the presence states ready for sending and + triggers a background task to process them and send out the transactions. + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_presence_to_destinations( + self, states: Iterable[UserPresenceState], destinations: Iterable[str] + ) -> None: + """Send the given presence states to the given destinations. + + Args: + destinations: + """ + raise NotImplementedError() + + @abc.abstractmethod + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: JsonDict, + key: Optional[Hashable] = None, + ) -> None: + """Construct an Edu object, and queue it for sending + + Args: + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_device_messages(self, destination: str) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def wake_destination(self, destination: str) -> None: + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token(self) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def federation_ack(self, instance_name: str, token: int) -> None: + raise NotImplementedError() + + @abc.abstractmethod + async def get_replication_rows( + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: + raise NotImplementedError() + + +class FederationSender(AbstractFederationSender): + def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -432,7 +518,7 @@ class FederationSender: queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - async def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]) -> None: """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and @@ -494,7 +580,7 @@ class FederationSender: self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - async def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]) -> None: """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ @@ -516,9 +602,9 @@ class FederationSender: self, destination: str, edu_type: str, - content: dict, + content: JsonDict, key: Optional[Hashable] = None, - ): + ) -> None: """Construct an Edu object, and queue it for sending Args: @@ -545,7 +631,7 @@ class FederationSender: self.send_edu(edu, key) - def send_edu(self, edu: Edu, key: Optional[Hashable]): + def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None: """Queue an EDU for sending Args: @@ -563,7 +649,7 @@ class FederationSender: else: queue.send_edu(edu) - def send_device_messages(self, destination: str): + def send_device_messages(self, destination: str) -> None: if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -575,7 +661,7 @@ class FederationSender: self._get_per_destination_queue(destination).attempt_new_transaction() - def wake_destination(self, destination: str): + def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. This is mainly useful if the remote server has been down and we think it @@ -599,6 +685,10 @@ class FederationSender: # to a worker. return 0 + def federation_ack(self, instance_name: str, token: int) -> None: + # It is not expected that this gets called on FederationSender. + raise NotImplementedError() + @staticmethod async def get_replication_rows( instance_name: str, from_token: int, to_token: int, target_row_count: int @@ -607,7 +697,7 @@ class FederationSender: # to a worker. return [], 0, False - async def _wake_destinations_needing_catchup(self): + async def _wake_destinations_needing_catchup(self) -> None: """ Wakes up destinations that need catch-up and are not currently being backed off from. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index bb447f75b4..8abed1f52d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -312,16 +312,16 @@ class FederationAckCommand(Command): NAME = "FEDERATION_ACK" - def __init__(self, instance_name, token): + def __init__(self, instance_name: str, token: int): self.instance_name = instance_name self.token = token @classmethod - def from_line(cls, line): + def from_line(cls, line: str) -> "FederationAckCommand": instance_name, token = line.split(" ") return cls(instance_name, int(token)) - def to_line(self): + def to_line(self) -> str: return "%s %s" % (self.instance_name, self.token) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 9bcd13b009..9bb8e9e177 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple from synapse.replication.tcp.streams._base import ( Stream, @@ -21,6 +22,9 @@ from synapse.replication.tcp.streams._base import ( make_http_update_function, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation @@ -38,7 +42,7 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): if hs.config.worker_app is None: # master process: get updates from the FederationRemoteSendQueue. # (if the master is configured to send federation itself, federation_sender @@ -48,7 +52,9 @@ class FederationStream(Stream): current_token = current_token_without_instance( federation_sender.get_current_token ) - update_function = federation_sender.get_replication_rows + update_function = ( + federation_sender.get_replication_rows + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] elif hs.should_send_federation(): # federation sender: Query master process @@ -69,5 +75,7 @@ class FederationStream(Stream): return 0 @staticmethod - async def _stub_update_function(instance_name, from_token, upto_token, limit): + async def _stub_update_function( + instance_name: str, from_token: int, upto_token: int, limit: int + ) -> Tuple[list, int, bool]: return [], upto_token, False diff --git a/synapse/server.py b/synapse/server.py index 5e787e2281..e85b9391fa 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -60,7 +60,7 @@ from synapse.federation.federation_server import ( FederationServer, ) from synapse.federation.send_queue import FederationRemoteSendQueue -from synapse.federation.sender import FederationSender +from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler @@ -571,7 +571,7 @@ class HomeServer(metaclass=abc.ABCMeta): return TransportLayerClient(self) @cache_in_self - def get_federation_sender(self): + def get_federation_sender(self) -> AbstractFederationSender: if self.should_send_federation(): return FederationSender(self) elif not self.config.worker_app: -- cgit 1.5.1 From 963f4309fe29206f3ba92b493e922280feea30ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 12:06:09 +0100 Subject: Make RateLimiter class check for ratelimit overrides (#9711) This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited. We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits. Fixes #9663 --- changelog.d/9711.bugfix | 1 + synapse/api/ratelimiting.py | 100 +++++++++--------- synapse/federation/federation_server.py | 5 +- synapse/handlers/_base.py | 14 +-- synapse/handlers/auth.py | 24 +++-- synapse/handlers/devicemessage.py | 5 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 12 ++- synapse/handlers/register.py | 6 +- synapse/handlers/room_member.py | 23 +++-- synapse/replication/http/register.py | 2 +- synapse/rest/client/v1/login.py | 14 ++- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/rest/client/v2_alpha/register.py | 8 +- synapse/server.py | 1 + tests/api/test_ratelimiting.py | 168 ++++++++++++++++++++----------- 16 files changed, 241 insertions(+), 154 deletions(-) create mode 100644 changelog.d/9711.bugfix (limited to 'synapse/server.py') diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644 index 0000000000..4ca3438d46 --- /dev/null +++ b/changelog.d/9711.bugfix @@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index c3f07bc1a3..2244b8a340 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type @@ -46,45 +50,10 @@ class Ratelimiter: OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s - ) - - def can_do_action( + async def can_do_action( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -92,9 +61,16 @@ class Ratelimiter: ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -109,6 +85,30 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz @@ -175,9 +175,10 @@ class Ratelimiter: else: del self.actions[key] - def ratelimit( + async def ratelimit( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -185,8 +186,16 @@ class Ratelimiter: ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: An arbitrary key used to classify an action + requester: The requester that is doing the action, if any. Used to check for + if the user has ratelimits disabled. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -201,7 +210,8 @@ class Ratelimiter: """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d84e362070..71cb120ef7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -870,6 +870,7 @@ class FederationHandlerRegistry: # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, @@ -930,7 +931,9 @@ class FederationHandlerRegistry: # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index aade2c4a3a..fb899aa90d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -49,7 +49,7 @@ class BaseHandler: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ class BaseHandler: # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, @@ -91,11 +92,6 @@ class BaseHandler: if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count @@ -113,11 +109,11 @@ class BaseHandler: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( - user_id, + await self.request_ratelimiter.ratelimit( + requester, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d537ea8137..08e413bc98 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,7 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ class AuthHandler(BaseHandler): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -352,7 +354,7 @@ class AuthHandler(BaseHandler): requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ class AuthHandler(BaseHandler): ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, + ) raise # find the completed login type @@ -982,8 +986,8 @@ class AuthHandler(BaseHandler): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler): # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler): # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler): # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb547743be..5ee48be6ff 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -81,6 +81,7 @@ class DeviceMessageHandler: ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, @@ -191,8 +192,8 @@ class DeviceMessageHandler: if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74c..3ebee38ebe 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5f346f6d6d..d89fa5fb30 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler): address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d5..9701b76d0f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler): Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler): if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4d20ed8357..1cf12f3255 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, @@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d005f38767..73d7477854 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e4c352f572..3151e72d4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, @@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet): # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c2ba790bab..411fb57c47 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8f68d8dfc8..c212da0cb2 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index e85b9391fa..e42f7b1a18 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483418192c..fa96ba07a5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,38 +5,25 @@ from synapse.types import create_requester from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) -- cgit 1.5.1 From 04819239bae2b39ee42bfdb6f9b83c6d9fe34169 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 6 Apr 2021 14:38:30 +0100 Subject: Add a Synapse Module for configuring presence update routing (#9491) At the moment, if you'd like to share presence between local or remote users, those users must be sharing a room together. This isn't always the most convenient or useful situation though. This PR adds a module to Synapse that will allow deployments to set up extra logic on where presence updates should be routed. The module must implement two methods, `get_users_for_states` and `get_interested_users`. These methods are given presence updates or user IDs and must return information that Synapse will use to grant passing presence updates around. A method is additionally added to `ModuleApi` which allows triggering a set of users to receive the current, online presence information for all users they are considered interested in. This is the equivalent of that user receiving presence information during an initial sync. The goal of this module is to be fairly generic and useful for a variety of applications, with hard requirements being: * Sending state for a specific set or all known users to a defined set of local and remote users. * The ability to trigger an initial sync for specific users, so they receive all current state. --- README.rst | 7 +- changelog.d/9491.feature | 1 + docs/presence_router_module.md | 235 +++++++++++++++++++++ docs/sample_config.yaml | 23 +- synapse/app/generic_worker.py | 3 +- synapse/config/server.py | 39 +++- synapse/events/presence_router.py | 104 +++++++++ synapse/federation/sender/__init__.py | 19 +- synapse/handlers/presence.py | 278 ++++++++++++++++++++---- synapse/module_api/__init__.py | 50 +++++ synapse/server.py | 5 + tests/events/test_presence_router.py | 386 ++++++++++++++++++++++++++++++++++ tests/handlers/test_sync.py | 21 +- tests/module_api/test_api.py | 175 ++++++++++++++- 14 files changed, 1282 insertions(+), 64 deletions(-) create mode 100644 changelog.d/9491.feature create mode 100644 docs/presence_router_module.md create mode 100644 synapse/events/presence_router.py create mode 100644 tests/events/test_presence_router.py (limited to 'synapse/server.py') diff --git a/README.rst b/README.rst index 655a2bf3be..1a5503572e 100644 --- a/README.rst +++ b/README.rst @@ -393,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion indicate that your server is also issuing far more outgoing federation requests than can be accounted for by your users' activity, this is a likely cause. The misbehavior can be worked around by setting -``use_presence: false`` in the Synapse config file. +the following in the Synapse config file: + +.. code-block:: yaml + + presence: + enabled: false People can't accept room invitations from me -------------------------------------------- diff --git a/changelog.d/9491.feature b/changelog.d/9491.feature new file mode 100644 index 0000000000..8b56a95a44 --- /dev/null +++ b/changelog.d/9491.feature @@ -0,0 +1 @@ +Add a Synapse module for routing presence updates between users. diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md new file mode 100644 index 0000000000..d6566d978d --- /dev/null +++ b/docs/presence_router_module.md @@ -0,0 +1,235 @@ +# Presence Router Module + +Synapse supports configuring a module that can specify additional users +(local or remote) to should receive certain presence updates from local +users. + +Note that routing presence via Application Service transactions is not +currently supported. + +The presence routing module is implemented as a Python class, which will +be imported by the running Synapse. + +## Python Presence Router Class + +The Python class is instantiated with two objects: + +* A configuration object of some type (see below). +* An instance of `synapse.module_api.ModuleApi`. + +It then implements methods related to presence routing. + +Note that one method of `ModuleApi` that may be useful is: + +```python +async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None +``` + +which can be given a list of local or remote MXIDs to broadcast known, online user +presence to (for those users that the receiving user is considered interested in). +It does not include state for users who are currently offline, and it can only be +called on workers that support sending federation. + +### Module structure + +Below is a list of possible methods that can be implemented, and whether they are +required. + +#### `parse_config` + +```python +def parse_config(config_dict: dict) -> Any +``` + +**Required.** A static method that is passed a dictionary of config options, and + should return a validated config object. This method is described further in + [Configuration](#configuration). + +#### `get_users_for_states` + +```python +async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], +) -> Dict[str, Set[UserPresenceState]]: +``` + +**Required.** An asynchronous method that is passed an iterable of user presence +state. This method can determine whether a given presence update should be sent to certain +users. It does this by returning a dictionary with keys representing local or remote +Matrix User IDs, and values being a python set +of `synapse.handlers.presence.UserPresenceState` instances. + +Synapse will then attempt to send the specified presence updates to each user when +possible. + +#### `get_interested_users` + +```python +async def get_interested_users(self, user_id: str) -> Union[Set[str], str] +``` + +**Required.** An asynchronous method that is passed a single Matrix User ID. This +method is expected to return the users that the passed in user may be interested in the +presence of. Returned users may be local or remote. The presence routed as a result of +what this method returns is sent in addition to the updates already sent between users +that share a room together. Presence updates are deduplicated. + +This method should return a python set of Matrix User IDs, or the object +`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed +user should receive presence information for *all* known users. + +For clarity, if the user `@alice:example.org` is passed to this method, and the Set +`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice +should receive presence updates sent by Bob and Charlie, regardless of whether these +users share a room. + +### Example + +Below is an example implementation of a presence router class. + +```python +from typing import Dict, Iterable, Set, Union +from synapse.events.presence_router import PresenceRouter +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi + +class PresenceRouterConfig: + def __init__(self): + # Config options with their defaults + # A list of users to always send all user presence updates to + self.always_send_to_users = [] # type: List[str] + + # A list of users to ignore presence updates for. Does not affect + # shared-room presence relationships + self.blacklisted_users = [] # type: List[str] + +class ExamplePresenceRouter: + """An example implementation of synapse.presence_router.PresenceRouter. + Supports routing all presence to a configured set of users, or a subset + of presence from certain users to members of certain rooms. + + Args: + config: A configuration object. + module_api: An instance of Synapse's ModuleApi. + """ + def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterConfig() + always_send_to_users = config_dict.get("always_send_to_users") + blacklisted_users = config_dict.get("blacklisted_users") + + # Do some validation of config options... otherwise raise a + # synapse.config.ConfigError. + config.always_send_to_users = always_send_to_users + config.blacklisted_users = blacklisted_users + + return config + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """Given an iterable of user presence updates, determine where each one + needs to go. Returned results will not affect presence updates that are + sent between users who share a room. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState that the user should + receive. + """ + destination_users = {} # type: Dict[str, Set[UserPresenceState] + + # Ignore any updates for blacklisted users + desired_updates = set() + for update in state_updates: + if update.state_key not in self._config.blacklisted_users: + desired_updates.add(update) + + # Send all presence updates to specific users + for user_id in self._config.always_send_to_users: + destination_users[user_id] = desired_updates + + return destination_users + + async def get_interested_users( + self, + user_id: str, + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return additional presence updates for, or + PresenceRouter.ALL_USERS to return presence updates for all other users. + """ + if user_id in self._config.always_send_to_users: + return PresenceRouter.ALL_USERS + + return set() +``` + +#### A note on `get_users_for_states` and `get_interested_users` + +Both of these methods are effectively two different sides of the same coin. The logic +regarding which users should receive updates for other users should be the same +between them. + +`get_users_for_states` is called when presence updates come in from either federation +or local users, and is used to either direct local presence to remote users, or to +wake up the sync streams of local users to collect remote presence. + +In contrast, `get_interested_users` is used to determine the users that presence should +be fetched for when a local user is syncing. This presence is then retrieved, before +being fed through `get_users_for_states` once again, with only the syncing user's +routing information pulled from the resulting dictionary. + +Their routing logic should thus line up, else you may run into unintended behaviour. + +## Configuration + +Once you've crafted your module and installed it into the same Python environment as +Synapse, amend your homeserver config file with the following. + +```yaml +presence: + routing_module: + module: my_module.ExamplePresenceRouter + config: + # Any configuration options for your module. The below is an example. + # of setting options for ExamplePresenceRouter. + always_send_to_users: ["@presence_gobbler:example.org"] + blacklisted_users: + - "@alice:example.com" + - "@bob:example.com" + ... +``` + +The contents of `config` will be passed as a Python dictionary to the static +`parse_config` method of your class. The object returned by this method will +then be passed to the `__init__` method of your module as `config`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b0bf987740..9182dcd987 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid # #soft_file_limit: 0 -# Set to false to disable presence tracking on this homeserver. +# Presence tracking allows users to see the state (e.g online/offline) +# of other local and remote users. # -#use_presence: false +presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3df2aa5c2b..d1c2079233 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler): self.hs = hs self.is_mine_id = hs.is_mine_id + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence # The number of ongoing syncs on this process, by user id. @@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler): return _user_syncing() async def notify_from_replication(self, states, stream_id): - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( diff --git a/synapse/config/server.py b/synapse/config/server.py index 5f8910b6e1..8decc9d10d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -27,6 +27,7 @@ import yaml from netaddr import AddrFormatError, IPNetwork, IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError @@ -238,7 +239,20 @@ class ServerConfig(Config): self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. - self.use_presence = config.get("use_presence", True) + presence_config = config.get("presence") or {} + self.use_presence = presence_config.get("enabled") + if self.use_presence is None: + self.use_presence = config.get("use_presence", True) + + # Custom presence router module + self.presence_router_module_class = None + self.presence_router_config = None + presence_router_config = presence_config.get("presence_router") + if presence_router_config: + ( + self.presence_router_module_class, + self.presence_router_config, + ) = load_module(presence_router_config, ("presence", "presence_router")) # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker @@ -834,9 +848,28 @@ class ServerConfig(Config): # #soft_file_limit: 0 - # Set to false to disable presence tracking on this homeserver. + # Presence tracking allows users to see the state (e.g online/offline) + # of other local and remote users. # - #use_presence: false + presence: + # Uncomment to disable presence tracking on this homeserver. This option + # replaces the previous top-level 'use_presence' option. + # + #enabled: false + + # Presence routers are third-party modules that can specify additional logic + # to where presence updates from users are routed. + # + presence_router: + # The custom module's class. Uncomment to use a custom presence router module. + # + #module: "my_custom_router.PresenceRouter" + + # Configuration options of the custom module. Refer to your module's + # documentation for available options. + # + #config: + # example_option: 'something' # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py new file mode 100644 index 0000000000..24cd389d80 --- /dev/null +++ b/synapse/events/presence_router.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Iterable, Set, Union + +from synapse.api.presence import UserPresenceState + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class PresenceRouter: + """ + A module that the homeserver will call upon to help route user presence updates to + additional destinations. If a custom presence router is configured, calls will be + passed to that instead. + """ + + ALL_USERS = "ALL" + + def __init__(self, hs: "HomeServer"): + self.custom_presence_router = None + + # Check whether a custom presence router module has been configured + if hs.config.presence_router_module_class: + # Initialise the module + self.custom_presence_router = hs.config.presence_router_module_class( + config=hs.config.presence_router_config, module_api=hs.get_module_api() + ) + + # Ensure the module has implemented the required methods + required_methods = ["get_users_for_states", "get_interested_users"] + for method_name in required_methods: + if not hasattr(self.custom_presence_router, method_name): + raise Exception( + "PresenceRouter module '%s' must implement all required methods: %s" + % ( + hs.config.presence_router_module_class.__name__, + ", ".join(required_methods), + ) + ) + + async def get_users_for_states( + self, + state_updates: Iterable[UserPresenceState], + ) -> Dict[str, Set[UserPresenceState]]: + """ + Given an iterable of user presence updates, determine where each one + needs to go. + + Args: + state_updates: An iterable of user presence state updates. + + Returns: + A dictionary of user_id -> set of UserPresenceState, indicating which + presence updates each user should receive. + """ + if self.custom_presence_router is not None: + # Ask the custom module + return await self.custom_presence_router.get_users_for_states( + state_updates=state_updates + ) + + # Don't include any extra destinations for presence updates + return {} + + async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: + """ + Retrieve a list of users that `user_id` is interested in receiving the + presence of. This will be in addition to those they share a room with. + Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate + that this user should receive all incoming local and remote presence updates. + + Note that this method will only be called for local users, but can return users + that are local or remote. + + Args: + user_id: A user requesting presence updates. + + Returns: + A set of user IDs to return presence updates for, or ALL_USERS to return all + known updates. + """ + if self.custom_presence_router is not None: + # Ask the custom module for interested users + return await self.custom_presence_router.get_interested_users( + user_id=user_id + ) + + # A custom presence router is not defined. + # Don't report any additional interested users + return set() diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8babb1ebbe..98bfce22ff 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func if TYPE_CHECKING: + from synapse.events.presence_router import PresenceRouter from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender): self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id + self._presence_router = None # type: Optional[PresenceRouter] self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() @@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = await get_interested_remotes(self.store, states, self.state) + # We pull the presence router here instead of __init__ + # to prevent a dependency cycle: + # + # AuthHandler -> Notifier -> FederationSender + # -> PresenceRouter -> ModuleApi -> AuthHandler + if self._presence_router is None: + self._presence_router = self.hs.get_presence_router() + + assert self._presence_router is not None + + hosts_and_states = await get_interested_remotes( + self.store, + self._presence_router, + states, + self.state, + ) for destinations, states in hosts_and_states: for destination in destinations: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index da92feacc9..c817f2952d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,17 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from prometheus_client import Counter from typing_extensions import ContextManager @@ -34,6 +44,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge @@ -42,7 +53,7 @@ from synapse.state import StateHandler from synapse.storage.databases.main import DataStore from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() + self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() @@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler): """ stream_id, max_token = await self.store.update_presence(states) - parties = await get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -1041,7 +1053,12 @@ class PresenceEventSource: # # Presence -> Notifier -> PresenceEventSource -> Presence # + # Same with get_module_api, get_presence_router + # + # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler self.get_presence_handler = hs.get_presence_handler + self.get_module_api = hs.get_module_api + self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -1055,7 +1072,7 @@ class PresenceEventSource: include_offline=True, explicit_room_id=None, **kwargs - ): + ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1068,7 +1085,17 @@ class PresenceEventSource: # We don't try and limit the presence updates by the current token, as # sending down the rare duplicate is not a concern. + user_id = user.to_string() + stream_change_cache = self.store.presence_stream_cache + with Measure(self.clock, "presence.get_new_events"): + if user_id in self.get_module_api()._send_full_presence_to_local_users: + # This user has been specified by a module to receive all current, online + # user presence. Removing from_key and setting include_offline to false + # will do effectively this. + from_key = None + include_offline = False + if from_key is not None: from_key = int(from_key) @@ -1091,59 +1118,209 @@ class PresenceEventSource: # doesn't return. C.f. #5503. return [], max_token - presence = self.get_presence_handler() - stream_change_cache = self.store.presence_stream_cache - + # Figure out which other users this user should receive updates for users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() # type: Collection[str] - changed = None - if from_key: - changed = stream_change_cache.get_all_entities_changed(from_key) + # We have a set of users that we're interested in the presence of. We want to + # cross-reference that with the users that have actually changed their presence. - if changed is not None and len(changed) < 500: - assert isinstance(user_ids_changed, set) + # Check whether this user should see all user updates - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list - get_updates_counter.labels("stream").inc() - for other_user_id in changed: - if other_user_id in users_interested_in: - user_ids_changed.add(other_user_id) - else: - # Too many possible updates. Find all users we can see and check - # if any of them have changed. - get_updates_counter.labels("full").inc() + if users_interested_in == PresenceRouter.ALL_USERS: + # Provide presence state for all users + presence_updates = await self._filter_all_presence_updates_for_user( + user_id, include_offline, from_key + ) - if from_key: - user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove( + user_id ) + + return presence_updates, max_token + + # Make mypy happy. users_interested_in should now be a set + assert not isinstance(users_interested_in, str) + + # The set of users that we're interested in and that have had a presence update. + # We'll actually pull the presence updates for these users at the end. + interested_and_updated_users = ( + set() + ) # type: Union[Set[str], FrozenSet[str]] + + if from_key: + # First get all users that have had a presence update + updated_users = stream_change_cache.get_all_entities_changed(from_key) + + # Cross-reference users we're interested in with those that have had updates. + # Use a slightly-optimised method for processing smaller sets of updates. + if updated_users is not None and len(updated_users) < 500: + # For small deltas, it's quicker to get all changes and then + # cross-reference with the users we're interested in + get_updates_counter.labels("stream").inc() + for other_user_id in updated_users: + if other_user_id in users_interested_in: + # mypy thinks this variable could be a FrozenSet as it's possibly set + # to one in the `get_entities_changed` call below, and `add()` is not + # method on a FrozenSet. That doesn't affect us here though, as + # `interested_and_updated_users` is clearly a set() above. + interested_and_updated_users.add(other_user_id) # type: ignore else: - user_ids_changed = users_interested_in + # Too many possible updates. Find all users we can see and check + # if any of them have changed. + get_updates_counter.labels("full").inc() - updates = await presence.current_state_for_users(user_ids_changed) + interested_and_updated_users = ( + stream_change_cache.get_entities_changed( + users_interested_in, from_key + ) + ) + else: + # No from_key has been specified. Return the presence for all users + # this user is interested in + interested_and_updated_users = users_interested_in + + # Retrieve the current presence state for each user + users_to_state = await self.get_presence_handler().current_state_for_users( + interested_and_updated_users + ) + presence_updates = list(users_to_state.values()) - if include_offline: - return (list(updates.values()), max_token) + # Remove the user from the list of users to receive all presence + if user_id in self.get_module_api()._send_full_presence_to_local_users: + self.get_module_api()._send_full_presence_to_local_users.remove(user_id) + + if not include_offline: + # Filter out offline presence states + presence_updates = self._filter_offline_presence_state(presence_updates) + + return presence_updates, max_token + + async def _filter_all_presence_updates_for_user( + self, + user_id: str, + include_offline: bool, + from_key: Optional[int] = None, + ) -> List[UserPresenceState]: + """ + Computes the presence updates a user should receive. + + First pulls presence updates from the database. Then consults PresenceRouter + for whether any updates should be excluded by user ID. + + Args: + user_id: The User ID of the user to compute presence updates for. + include_offline: Whether to include offline presence states from the results. + from_key: The minimum stream ID of updates to pull from the database + before filtering. + + Returns: + A list of presence states for the given user to receive. + """ + if from_key: + # Only return updates since the last sync + updated_users = self.store.presence_stream_cache.get_all_entities_changed( + from_key + ) + if not updated_users: + updated_users = [] + + # Get the actual presence update for each change + users_to_state = await self.get_presence_handler().current_state_for_users( + updated_users + ) + presence_updates = list(users_to_state.values()) + + if not include_offline: + # Filter out offline states + presence_updates = self._filter_offline_presence_state(presence_updates) else: - return ( - [s for s in updates.values() if s.state != PresenceState.OFFLINE], - max_token, + users_to_state = await self.store.get_presence_for_all_users( + include_offline=include_offline ) + presence_updates = list(users_to_state.values()) + + # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the + # module for information on a number of users when we then only take the info + # for a single user + + # Filter through the presence router + users_to_state_set = await self.get_presence_router().get_users_for_states( + presence_updates + ) + + # We only want the mapping for the syncing user + presence_updates = list(users_to_state_set[user_id]) + + # Return presence information for all users + return presence_updates + + def _filter_offline_presence_state( + self, presence_updates: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: + """Given an iterable containing user presence updates, return a list with any offline + presence states removed. + + Args: + presence_updates: Presence states to filter + + Returns: + A new list with any offline presence states removed. + """ + return [ + update + for update in presence_updates + if update.state != PresenceState.OFFLINE + ] + def get_current_key(self): return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) - async def _get_interested_in(self, user, explicit_room_id, cache_context): + async def _get_interested_in( + self, + user: UserID, + explicit_room_id: Optional[str] = None, + cache_context: Optional[_CacheContext] = None, + ) -> Union[Set[str], str]: """Returns the set of users that the given user should see presence - updates for + updates for. + + Args: + user: The user to retrieve presence updates for. + explicit_room_id: The users that are in the room will be returned. + + Returns: + A set of user IDs to return presence updates for, or "ALL" to return all + known updates. """ user_id = user.to_string() users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence + # cache_context isn't likely to ever be None due to the @cached decorator, + # but we can't have a non-optional argument after the optional argument + # explicit_room_id either. Assert cache_context is not None so we can use it + # without mypy complaining. + assert cache_context + + # Check with the presence router whether we should poll additional users for + # their presence information + additional_users = await self.get_presence_router().get_interested_users( + user.to_string() + ) + if additional_users == PresenceRouter.ALL_USERS: + # If the module requested that this user see the presence updates of *all* + # users, then simply return that instead of calculating what rooms this + # user shares + return PresenceRouter.ALL_USERS + + # Add the additional users from the router + users_interested_in.update(additional_users) + + # Find the users who share a room with this user users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) @@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): async def get_interested_parties( - store: DataStore, states: List[UserPresenceState] + store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - store - states + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. Returns: A 2-tuple of `(room_ids_to_states, users_to_states)`, @@ -1337,11 +1515,22 @@ async def get_interested_parties( # Always notify self users_to_states.setdefault(state.user_id, []).append(state) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) + + # Update the dictionaries with additional destinations and state to send + for user_id, user_states in router_users_to_states.items(): + users_to_states.setdefault(user_id, []).extend(user_states) + return room_ids_to_states, users_to_states async def get_interested_remotes( - store: DataStore, states: List[UserPresenceState], state_handler: StateHandler + store: DataStore, + presence_router: PresenceRouter, + states: List[UserPresenceState], + state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1349,9 +1538,10 @@ async def get_interested_remotes( All the presence states should be for local users only. Args: - store - states - state_handler + store: The homeserver's data store. + presence_router: A module for augmenting the destinations for presence updates. + states: A list of incoming user presence updates. + state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1363,7 +1553,9 @@ async def get_interested_remotes( # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties( + store, presence_router, states + ) for room_id, states in room_ids_to_states.items(): hosts = await state_handler.get_current_hosts_in_room(room_id) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 781e02fbbb..3ecd46c038 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -50,11 +50,20 @@ class ModuleApi: self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname + self._presence_stream = hs.get_event_sources().sources["presence"] # We expose these as properties below in order to attach a helpful docstring. self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._public_room_list_manager = PublicRoomListManager(hs) + # The next time these users sync, they will receive the current presence + # state of all local users. Users are added by send_local_online_presence_to, + # and removed after a successful sync. + # + # We make this a private variable to deter modules from accessing it directly, + # though other classes in Synapse will still do so. + self._send_full_presence_to_local_users = set() + @property def http_client(self): """Allows making outbound HTTP requests to remote resources. @@ -385,6 +394,47 @@ class ModuleApi: return event + async def send_local_online_presence_to(self, users: Iterable[str]) -> None: + """ + Forces the equivalent of a presence initial_sync for a set of local or remote + users. The users will receive presence for all currently online users that they + are considered interested in. + + Updates to remote users will be sent immediately, whereas local users will receive + them on their next sync attempt. + + Note that this method can only be run on the main or federation_sender worker + processes. + """ + if not self._hs.should_send_federation(): + raise Exception( + "send_local_online_presence_to can only be run " + "on processes that send federation", + ) + + for user in users: + if self._hs.is_mine_id(user): + # Modify SyncHandler._generate_sync_entry_for_presence to call + # presence_source.get_new_events with an empty `from_key` if + # that user's ID were in a list modified by ModuleApi somewhere. + # That user would then get all presence state on next incremental sync. + + # Force a presence initial_sync for this user next time + self._send_full_presence_to_local_users.add(user) + else: + # Retrieve presence state for currently online users that this user + # is considered interested in + presence_events, _ = await self._presence_stream.get_new_events( + UserID.from_string(user), from_key=None, include_offline=False + ) + + # Send to remote destinations + await make_deferred_yieldable( + # We pull the federation sender here as we can only do so on workers + # that support sending presence + self._hs.get_federation_sender().send_presence(presence_events) + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/server.py b/synapse/server.py index e42f7b1a18..cfb55c230d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -51,6 +51,7 @@ from synapse.crypto import context_factory from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.events.presence_router import PresenceRouter from synapse.events.spamcheck import SpamChecker from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.events.utils import EventClientSerializer @@ -425,6 +426,10 @@ class HomeServer(metaclass=abc.ABCMeta): else: raise Exception("Workers cannot write typing") + @cache_in_self + def get_presence_router(self) -> PresenceRouter: + return PresenceRouter(self) + @cache_in_self def get_typing_handler(self) -> FollowerTypingHandler: if self.config.worker.writers.typing == self.get_instance_name(): diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py new file mode 100644 index 0000000000..c6e547f11c --- /dev/null +++ b/tests/events/test_presence_router.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +from mock import Mock + +import attr + +from synapse.api.constants import EduTypes +from synapse.events.presence_router import PresenceRouter +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, presence, room +from synapse.types import JsonDict, StreamToken, create_requester + +from tests.handlers.test_sync import generate_sync_config +from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config + + +@attr.s +class PresenceRouterTestConfig: + users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) + + +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + +class PresenceRouterTestCase(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + presence.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, homeserver): + self.sync_handler = self.hs.get_sync_handler() + self.module_api = homeserver.get_module_api() + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_receiving_all_presence(self): + """Test that a user that does not share a room with another other can receive + presence for them, due to presence routing. + """ + # Create a user who should receive all presence of others + self.presence_receiving_user_id = self.register_user( + "presence_gobbler", "monkey" + ) + self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey") + + # And two users who should not have any special routing + self.other_user_one_id = self.register_user("other_user_one", "monkey") + self.other_user_one_tok = self.login("other_user_one", "monkey") + self.other_user_two_id = self.register_user("other_user_two", "monkey") + self.other_user_two_tok = self.login("other_user_two", "monkey") + + # Put the other two users in a room with each other + room_id = self.helper.create_room_as( + self.other_user_one_id, tok=self.other_user_one_tok + ) + + self.helper.invite( + room_id, + self.other_user_one_id, + self.other_user_two_id, + tok=self.other_user_one_tok, + ) + self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok) + # User one sends some presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "boop", + ) + + # Check that the presence receiving user gets user one's presence when syncing + presence_updates, sync_token = sync_presence( + self, self.presence_receiving_user_id + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_one_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "boop") + + # Have all three users send presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "user_one", + ) + send_presence_update( + self, + self.other_user_two_id, + self.other_user_two_tok, + "online", + "user_two", + ) + send_presence_update( + self, + self.presence_receiving_user_id, + self.presence_receiving_user_tok, + "online", + "presence_gobbler", + ) + + # Check that the presence receiving user gets everyone's presence + presence_updates, _ = sync_presence( + self, self.presence_receiving_user_id, sync_token + ) + self.assertEqual(len(presence_updates), 3) + + # But that User One only get itself and User Two's presence + presence_updates, _ = sync_presence(self, self.other_user_one_id) + self.assertEqual(len(presence_updates), 2) + + found = False + for update in presence_updates: + if update.user_id == self.other_user_two_id: + self.assertEqual(update.state, "online") + self.assertEqual(update.status_msg, "user_two") + found = True + + self.assertTrue(found) + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_send_local_online_presence_to_with_module(self): + """Tests that send_local_presence_to_users sends local online presence to a set + of specified local and remote users, with a custom PresenceRouter module enabled. + """ + # Create a user who will send presence updates + self.other_user_id = self.register_user("other_user", "monkey") + self.other_user_tok = self.login("other_user", "monkey") + + # And another two users that will also send out presence updates, as well as receive + # theirs and everyone else's + self.presence_receiving_user_one_id = self.register_user( + "presence_gobbler1", "monkey" + ) + self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey") + self.presence_receiving_user_two_id = self.register_user( + "presence_gobbler2", "monkey" + ) + self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey") + + # Have all three users send some presence updates + send_presence_update( + self, + self.other_user_id, + self.other_user_tok, + "online", + "I'm online!", + ) + send_presence_update( + self, + self.presence_receiving_user_one_id, + self.presence_receiving_user_one_tok, + "online", + "I'm also online!", + ) + send_presence_update( + self, + self.presence_receiving_user_two_id, + self.presence_receiving_user_two_tok, + "unavailable", + "I'm in a meeting!", + ) + + # Mark each presence-receiving user for receiving all user presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + ) + ) + + # Perform a sync for each user + + # The other user should only receive their own presence + presence_updates, _ = sync_presence(self, self.other_user_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "I'm online!") + + # Whereas both presence receiving users should receive everyone's presence updates + presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id) + self.assertEqual(len(presence_updates), 3) + presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) + self.assertEqual(len(presence_updates), 3) + + # Test that sending to a remote user works + remote_user_id = "@far_away_person:island" + + # Note that due to the remote user being in our module's + # users_who_should_receive_all_presence config, they would have + # received user presence updates already. + # + # Thus we reset the mock, and try sending all online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that the expected presence updates were sent + expected_users = [ + self.other_user_id, + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + # Check for presence updates that contain the user IDs we're after + expected_users.remove(presence_update["user_id"]) + + # Ensure that no offline states are being sent out + self.assertNotEqual(presence_update["presence"], "offline") + + self.assertEqual(len(expected_users), 0) + + +def send_presence_update( + testcase: TestCase, + user_id: str, + access_token: str, + presence_state: str, + status_message: Optional[str] = None, +) -> JsonDict: + # Build the presence body + body = {"presence": presence_state} + if status_message: + body["status_msg"] = status_message + + # Update the user's presence state + channel = testcase.make_request( + "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token + ) + testcase.assertEqual(channel.code, 200) + + return channel.json_body + + +def sync_presence( + testcase: TestCase, + user_id: str, + since_token: Optional[StreamToken] = None, +) -> Tuple[List[UserPresenceState], StreamToken]: + """Perform a sync request for the given user and return the user presence updates + they've received, as well as the next_batch token. + + This method assumes testcase.sync_handler points to the homeserver's sync handler. + + Args: + testcase: The testcase that is currently being run. + user_id: The ID of the user to generate a sync response for. + since_token: An optional token to indicate from at what point to sync from. + + Returns: + A tuple containing a list of presence updates, and the sync response's + next_batch token. + """ + requester = create_requester(user_id) + sync_config = generate_sync_config(requester.user.to_string()) + sync_result = testcase.get_success( + testcase.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token + ) + ) + + return sync_result.presence, sync_result.next_batch diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e62586142e..8e950f25c5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = self._generate_sync_config(user_id1) + sync_config = generate_sync_config(user_id1) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = self._generate_sync_config(user_id2) + sync_config = generate_sync_config(user_id2) requester = create_requester(user_id2) e = self.get_failure( @@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def _generate_sync_config(self, user_id): - return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), - filter_collection=DEFAULT_FILTER_COLLECTION, - is_guest=False, - request_key="request_key", - device_id="device_id", - ) + +def generate_sync_config(user_id: str) -> SyncConfig: + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + filter_collection=DEFAULT_FILTER_COLLECTION, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index edacd1b566..1d1fceeecf 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -14,25 +14,37 @@ # limitations under the License. from mock import Mock +from synapse.api.constants import EduTypes from synapse.events import EventBase +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client.v1 import login, presence, room from synapse.types import create_requester -from tests.unittest import HomeserverTestCase +from tests.events.test_presence_router import send_presence_update, sync_presence +from tests.test_utils.event_injection import inject_member_event +from tests.unittest import FederatingHomeserverTestCase, override_config -class ModuleApiTestCase(HomeserverTestCase): +class ModuleApiTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, room.register_servlets, + presence.register_servlets, ] def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) def test_can_register_user(self): """Tests that an external module can register a user""" @@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase): ) ) self.assertFalse(is_in_public_rooms) + + # The ability to send federation is required by send_local_online_presence_to. + @override_config({"send_federation": True}) + def test_send_local_online_presence_to(self): + """Tests that send_local_presence_to_users sends local online presence to local users.""" + # Create a user who will send presence updates + self.presence_receiver_id = self.register_user("presence_receiver", "monkey") + self.presence_receiver_tok = self.login("presence_receiver", "monkey") + + # And another user that will send presence updates out + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # Put them in a room together so they will receive each other's presence updates + room_id = self.helper.create_room_as( + self.presence_receiver_id, + tok=self.presence_receiver_tok, + ) + self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok) + + # Presence sender comes online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Presence receiver should have received it + presence_updates, sync_token = sync_presence(self, self.presence_receiver_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Syncing again should result in no presence updates + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should have received online presence again + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Presence sender goes offline + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "offline", + "I slink back into the darkness.", + ) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should *not* have received offline state + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + @override_config({"send_federation": True}) + def test_send_local_online_presence_to_federation(self): + """Tests that send_local_presence_to_users sends local online presence to remote users.""" + # Create a user who will send presence updates + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # And a room they're a part of + room_id = self.helper.create_room_as( + self.presence_sender_id, + tok=self.presence_sender_tok, + ) + + # Mark them as online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Make up a remote user to send presence to + remote_user_id = "@far_away_person:island" + + # Create a join membership event for the remote user into the room. + # This allows presence information to flow from one user to the other. + self.get_success( + inject_member_event( + self.hs, + room_id, + sender=remote_user_id, + target=remote_user_id, + membership="join", + ) + ) + + # The remote user would have received the existing room members' presence + # when they joined the room. + # + # Thus we reset the mock, and try sending online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that a presence update was sent as part of a federation transaction + found_update = False + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + federation_transaction = call.args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + if presence_update["user_id"] == self.presence_sender_id: + found_update = True + + self.assertTrue(found_update) -- cgit 1.5.1