From c7a5e49664ab0bd18a57336e282fa6c3b9a17ca0 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier
Date: Tue, 26 Oct 2021 15:17:36 +0200
Subject: Implement an `on_new_event` callback (#11126)

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
---
 synapse/handlers/federation_event.py | 2 +-
 synapse/handlers/message.py          | 9 ++++++---
 2 files changed, 7 insertions(+), 4 deletions(-)

(limited to 'synapse/handlers')

diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9584d5bd46..bd1fa08cef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1916,7 +1916,7 @@ class FederationEventHandler:
         event_pos = PersistedEventPosition(
             self._instance_name, event.internal_metadata.stream_ordering
         )
-        self._notifier.on_new_room_event(
+        await self._notifier.on_new_room_event(
             event, event_pos, max_stream_token, extra_users=extra_users
         )

diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2e024b551f..4a0fccfcc6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1537,13 +1537,16 @@ class EventCreationHandler:
         # If there's an expiry timestamp on the event, schedule its expiry.
         self._message_handler.maybe_schedule_expiry(event)

-        def _notify() -> None:
+        async def _notify() -> None:
             try:
-                self.notifier.on_new_room_event(
+                await self.notifier.on_new_room_event(
                     event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
-                logger.exception("Error notifying about new room event")
+                logger.exception(
+                    "Error notifying about new room event %s",
+                    event.event_id,
+                )

         run_in_background(_notify)

--
cgit 1.4.1


From a930da3291b5b1d2375c3bd7c4a34f1588704292 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 27 Oct 2021 10:19:19 -0400
Subject: Include the stable identifier for MSC3288. (#11187)

Includes both the stable and unstable identifiers in store-invite calls
to the identity server. In the future we should remove the unstable
identifier.
---
 changelog.d/11187.feature    | 1 +
 synapse/handlers/identity.py | 2 ++
 2 files changed, 3 insertions(+)
 create mode 100644 changelog.d/11187.feature

(limited to 'synapse/handlers')

diff --git a/changelog.d/11187.feature b/changelog.d/11187.feature
new file mode 100644
index 0000000000..dd28109030
--- /dev/null
+++ b/changelog.d/11187.feature
@@ -0,0 +1 @@
+Support the stable room type field for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288).
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 7ef8698a5e..6a315117ba 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -879,6 +879,8 @@ class IdentityHandler:
         }

         if room_type is not None:
+            invite_config["room_type"] = room_type
+            # TODO The unstable field is deprecated and should be removed in the future.
             invite_config["org.matrix.msc3288.room_type"] = room_type

         # If a custom web client location is available, include it in the request.
--
cgit 1.4.1


From 19d5dc69316a28035caf6a6519ad8a116023de81 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 27 Oct 2021 11:26:30 -0400
Subject: Refactor `Filter` to handle fields according to data being filtered. (#11194)

This avoids filtering against fields which cannot exist on an event
source. E.g. presence updates don't have a room.
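As a rough illustration of the new approach, here is a minimal standalone
sketch (with simplified stand-ins, not the real `Filter` class): each field
name maps to a predicate, and a matching `not_`-prefixed attribute holds the
disallowed values.

    from typing import Callable, Dict, List, Optional

    class MiniFilter:
        """Simplified stand-in for synapse.api.filtering.Filter."""

        def __init__(self, filter_json: dict):
            self.senders: Optional[List[str]] = filter_json.get("senders")
            self.not_senders: List[str] = filter_json.get("not_senders", [])

        def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
            for name, match_func in field_matchers.items():
                # Reject if any disallowed value matches.
                if any(map(match_func, getattr(self, "not_%s" % name))):
                    return False
                # Reject if an allow-list exists and nothing in it matches.
                allowed = getattr(self, name)
                if allowed is not None and not any(map(match_func, allowed)):
                    return False
            return True

        def check_presence(self, user_id: str) -> bool:
            # Presence updates have no room, so only the sender is checked here.
            return self._check_fields({"senders": lambda v: user_id == v})

    f = MiniFilter({"not_senders": ["@spammer:example.com"]})
    assert f.check_presence("@alice:example.com")
    assert not f.check_presence("@spammer:example.com")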
--- changelog.d/11194.misc | 1 + synapse/api/filtering.py | 139 +++++++++++++++++++++++------------------ synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/search.py | 12 ++-- 5 files changed, 87 insertions(+), 69 deletions(-) create mode 100644 changelog.d/11194.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11194.misc b/changelog.d/11194.misc new file mode 100644 index 0000000000..fc1d06ba89 --- /dev/null +++ b/changelog.d/11194.misc @@ -0,0 +1 @@ +Refactor `Filter` to check different fields depending on the data type. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bc550ae646..4b0a9b2974 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -18,7 +18,8 @@ import json from typing import ( TYPE_CHECKING, Awaitable, - Container, + Callable, + Dict, Iterable, List, Optional, @@ -217,19 +218,19 @@ class FilterCollection: return self._filter_json def timeline_limit(self) -> int: - return self._room_timeline_filter.limit() + return self._room_timeline_filter.limit def presence_limit(self) -> int: - return self._presence_filter.limit() + return self._presence_filter.limit def ephemeral_limit(self) -> int: - return self._room_ephemeral_filter.limit() + return self._room_ephemeral_filter.limit def lazy_load_members(self) -> bool: - return self._room_state_filter.lazy_load_members() + return self._room_state_filter.lazy_load_members def include_redundant_members(self) -> bool: - return self._room_state_filter.include_redundant_members() + return self._room_state_filter.include_redundant_members def filter_presence( self, events: Iterable[UserPresenceState] @@ -276,19 +277,25 @@ class Filter: def __init__(self, filter_json: JsonDict): self.filter_json = filter_json - self.types = self.filter_json.get("types", None) - self.not_types = self.filter_json.get("not_types", []) + self.limit = filter_json.get("limit", 10) + self.lazy_load_members = filter_json.get("lazy_load_members", False) + self.include_redundant_members = filter_json.get( + "include_redundant_members", False + ) + + self.types = filter_json.get("types", None) + self.not_types = filter_json.get("not_types", []) - self.rooms = self.filter_json.get("rooms", None) - self.not_rooms = self.filter_json.get("not_rooms", []) + self.rooms = filter_json.get("rooms", None) + self.not_rooms = filter_json.get("not_rooms", []) - self.senders = self.filter_json.get("senders", None) - self.not_senders = self.filter_json.get("not_senders", []) + self.senders = filter_json.get("senders", None) + self.not_senders = filter_json.get("not_senders", []) - self.contains_url = self.filter_json.get("contains_url", None) + self.contains_url = filter_json.get("contains_url", None) - self.labels = self.filter_json.get("org.matrix.labels", None) - self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + self.labels = filter_json.get("org.matrix.labels", None) + self.not_labels = filter_json.get("org.matrix.not_labels", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -302,76 +309,95 @@ class Filter: def check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. + Args: + event: The event, account data, or presence to check against this + filter. + Returns: - True if the event matches + True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, # except for presence which actually gets passed around as its own # namedtuple type. 
        if isinstance(event, UserPresenceState):
-            sender: Optional[str] = event.user_id
-            room_id = None
-            ev_type = "m.presence"
-            contains_url = False
-            labels: List[str] = []
+            user_id = event.user_id
+            field_matchers = {
+                "senders": lambda v: user_id == v,
+                "types": lambda v: "m.presence" == v,
+            }
+            return self._check_fields(field_matchers)
         else:
+            content = event.get("content")
+            # Content is assumed to be a dict below, so ensure it is. This should
+            # always be true for events, but account_data has been allowed to
+            # have non-dict content.
+            if not isinstance(content, dict):
+                content = {}
+
             sender = event.get("sender", None)
             if not sender:
                 # Presence events had their 'sender' in content.user_id, but are
                 # now handled above. We don't know if anything else uses this
                 # form. TODO: Check this and probably remove it.
-                content = event.get("content")
-                # account_data has been allowed to have non-dict content, so
-                # check type first
-                if isinstance(content, dict):
-                    sender = content.get("user_id")
+                sender = content.get("user_id")

             room_id = event.get("room_id", None)
             ev_type = event.get("type", None)

-            content = event.get("content") or {}
             # check if there is a string url field in the content for filtering purposes
-            contains_url = isinstance(content.get("url"), str)
             labels = content.get(EventContentFields.LABELS, [])

-        return self.check_fields(room_id, sender, ev_type, labels, contains_url)
+            field_matchers = {
+                "rooms": lambda v: room_id == v,
+                "senders": lambda v: sender == v,
+                "types": lambda v: _matches_wildcard(ev_type, v),
+                "labels": lambda v: v in labels,
+            }
+
+            result = self._check_fields(field_matchers)
+            if not result:
+                return result
+
+            contains_url_filter = self.contains_url
+            if contains_url_filter is not None:
+                contains_url = isinstance(content.get("url"), str)
+                if contains_url_filter != contains_url:
+                    return False
+
+            return True

-    def check_fields(
-        self,
-        room_id: Optional[str],
-        sender: Optional[str],
-        event_type: Optional[str],
-        labels: Container[str],
-        contains_url: bool,
-    ) -> bool:
+    def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
         """Checks whether the filter matches the given event fields.

+        Args:
+            field_matchers: A map of attribute name to callable to use for checking
+                particular fields.
+
+                The attribute name and an inverse (not_<attribute name>) must
+                exist on the Filter.
+
+                The callable should return true if the event's value matches the
+                filter's value.
+
         Returns:
             True if the event fields match
         """
-        literal_keys = {
-            "rooms": lambda v: room_id == v,
-            "senders": lambda v: sender == v,
-            "types": lambda v: _matches_wildcard(event_type, v),
-            "labels": lambda v: v in labels,
-        }
-
-        for name, match_func in literal_keys.items():
+
+        for name, match_func in field_matchers.items():
+            # If the event matches one of the disallowed values, reject it.
             not_name = "not_%s" % (name,)
             disallowed_values = getattr(self, not_name)
             if any(map(match_func, disallowed_values)):
                 return False

+            # Otherwise, if the event does not match at least one of the allowed values,
+            # reject it.
             allowed_values = getattr(self, name)
             if allowed_values is not None:
                 if not any(map(match_func, allowed_values)):
                     return False

-        contains_url_filter = self.filter_json.get("contains_url")
-        if contains_url_filter is not None:
-            if contains_url_filter != contains_url:
-                return False
-
+        # Otherwise, accept it.
return True def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]: @@ -385,10 +411,10 @@ class Filter: """ room_ids = set(room_ids) - disallowed_rooms = set(self.filter_json.get("not_rooms", [])) + disallowed_rooms = set(self.not_rooms) room_ids -= disallowed_rooms - allowed_rooms = self.filter_json.get("rooms", None) + allowed_rooms = self.rooms if allowed_rooms is not None: room_ids &= set(allowed_rooms) @@ -397,15 +423,6 @@ class Filter: def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return list(filter(self.check, events)) - def limit(self) -> int: - return self.filter_json.get("limit", 10) - - def lazy_load_members(self) -> bool: - return self.filter_json.get("lazy_load_members", False) - - def include_redundant_members(self) -> bool: - return self.filter_json.get("include_redundant_members", False) - def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 60ff896386..abfe7be0e3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -438,7 +438,7 @@ class PaginationHandler: } state = None - if event_filter and event_filter.lazy_load_members() and len(events) > 0: + if event_filter and event_filter.lazy_load_members and len(events) > 0: # TODO: remove redundant members # FIXME: we also care about invite targets etc. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cf01d58ea1..99e9b37344 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1173,7 +1173,7 @@ class RoomContextHandler: else: last_event_id = event_id - if event_filter and event_filter.lazy_load_members(): + if event_filter and event_filter.lazy_load_members: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index a3ffa26be8..6e4dff8056 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -249,7 +249,7 @@ class SearchHandler: ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit()] + allowed_events = events[: search_filter.limit] for e in allowed_events: rm = room_groups.setdefault( @@ -271,13 +271,13 @@ class SearchHandler: # We keep looping and we keep filtering until we reach the limit # or we run out of things. # But only go around 5 times since otherwise synapse will be sad. 
- while len(room_events) < search_filter.limit() and i < 5: + while len(room_events) < search_filter.limit and i < 5: i += 1 search_result = await self.store.search_rooms( room_ids, search_term, keys, - search_filter.limit() * 2, + search_filter.limit * 2, pagination_token=pagination_token, ) @@ -299,9 +299,9 @@ class SearchHandler: ) room_events.extend(events) - room_events = room_events[: search_filter.limit()] + room_events = room_events[: search_filter.limit] - if len(results) < search_filter.limit() * 2: + if len(results) < search_filter.limit * 2: pagination_token = None break else: @@ -311,7 +311,7 @@ class SearchHandler: group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) - if room_events and len(room_events) >= search_filter.limit(): + if room_events and len(room_events) >= search_filter.limit: last_event_id = room_events[-1].event_id pagination_token = results_map[last_event_id]["pagination_token"] -- cgit 1.4.1 From 75ca0a6168f92dab3255839cf85fb0df3a0076c3 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 27 Oct 2021 17:27:23 +0100 Subject: Annotate `log_function` decorator (#10943) Co-authored-by: Patrick Cloke --- changelog.d/10943.misc | 1 + synapse/federation/federation_client.py | 17 +++++++++++++++-- synapse/federation/federation_server.py | 10 ++++++---- synapse/federation/sender/transaction_manager.py | 1 - synapse/federation/transport/client.py | 22 ++++++++++++++++++---- synapse/handlers/directory.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/presence.py | 2 ++ synapse/handlers/profile.py | 4 ++++ synapse/logging/utils.py | 8 ++++++-- synapse/state/__init__.py | 5 +++-- synapse/storage/databases/main/profile.py | 2 +- 12 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 changelog.d/10943.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10943.misc b/changelog.d/10943.misc new file mode 100644 index 0000000000..3ce28d1a67 --- /dev/null +++ b/changelog.d/10943.misc @@ -0,0 +1 @@ +Add type annotations for the `log_function` decorator. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2ab4dec88f..670186f548 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -227,7 +227,7 @@ class FederationClient(FederationBase): ) async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> Optional[List[EventBase]]: """Requests some more historic PDUs for the given room from the given destination server. @@ -237,6 +237,8 @@ class FederationClient(FederationBase): room_id: The room_id to backfill. limit: The maximum number of events to return. extremities: our current backwards extremities, to backfill from + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) """ logger.debug("backfill extrem=%s", extremities) @@ -250,11 +252,22 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) + if not isinstance(transaction_data, dict): + # TODO we probably want an exception type specific to federation + # client validation. + raise TypeError("Backfill transaction_data is not a dict.") + + transaction_data_pdus = transaction_data.get("pdus") + if not isinstance(transaction_data_pdus, list): + # TODO we probably want an exception type specific to federation + # client validation. 
+ raise TypeError("transaction_data.pdus is not a list.") + room_version = await self.store.get_room_version(room_id) pdus = [ event_from_pdu_json(p, room_version, outlier=False) - for p in transaction_data["pdus"] + for p in transaction_data_pdus ] # Check signatures and hash of pdus, removing any from the list that fail checks diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 0d66034f44..32a75993d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -295,14 +295,16 @@ class FederationServer(FederationBase): Returns: HTTP response code and body """ - response = await self.transaction_actions.have_responded(origin, transaction) + existing_response = await self.transaction_actions.have_responded( + origin, transaction + ) - if response: + if existing_response: logger.debug( "[%s] We've already responded to this request", transaction.transaction_id, ) - return response + return existing_response logger.debug("[%s] Transaction is new", transaction.transaction_id) @@ -632,7 +634,7 @@ class FederationServer(FederationBase): async def on_make_knock_request( self, origin: str, room_id: str, user_id: str, supported_versions: List[str] - ) -> Dict[str, Union[EventBase, str]]: + ) -> JsonDict: """We've received a /make_knock/ request, so we create a partial knock event for the room and hand that back, along with the room version, to the knocking homeserver. We do *not* persist or process this event until the other server has diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index dc555cca0b..ab935e5a7e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -149,7 +149,6 @@ class TransactionManager: ) except HttpResponseException as e: code = e.code - response = e.response set_tag(tags.ERROR, True) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8b247fe206..d963178838 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,19 @@ import logging import urllib -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import attr import ijson @@ -100,7 +112,7 @@ class TransportLayerClient: @log_function async def backfill( - self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int + self, destination: str, room_id: str, event_tuples: Collection[str], limit: int ) -> Optional[JsonDict]: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -108,7 +120,9 @@ class TransportLayerClient: Args: destination room_id - event_tuples + event_tuples: + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) 
limit Returns: @@ -786,7 +800,7 @@ class TransportLayerClient: @log_function def join_group( self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: + ) -> Awaitable[JsonDict]: """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8567cb0e00..8ca5f60b1c 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -245,7 +245,7 @@ class DirectoryHandler: servers = result.servers else: try: - fed_result = await self.federation.make_query( + fed_result: Optional[JsonDict] = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index bd1fa08cef..e617db4c0d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -477,7 +477,7 @@ class FederationEventHandler: @log_function async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: """Trigger a backfill request to `dest` for the given `room_id` diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fdab50da37..3df872c578 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,6 +52,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function @@ -1551,6 +1552,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, + service: Optional[ApplicationService] = None, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. 
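The decorator-typing pattern this patch applies in `synapse/logging/utils.py`
(shown below) can be reduced to a self-contained sketch: a `TypeVar` bound to
`Callable`, with the wrapper `cast` back to the original function's type so
mypy keeps seeing the wrapped signature. This is a minimal sketch of the
general idiom rather than the exact Synapse code (e.g. it uses
`functools.wraps` for brevity).

    import logging
    from functools import wraps
    from typing import Any, Callable, TypeVar, cast

    logger = logging.getLogger(__name__)

    F = TypeVar("F", bound=Callable)

    def log_function(f: F) -> F:
        """Log every call to the decorated function."""

        @wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            logger.debug("Invoked '%s' with args=%r kwargs=%r", f.__name__, args, kwargs)
            return f(*args, **kwargs)

        # cast() is needed because mypy cannot prove that the wrapper has
        # the same signature as the wrapped function.
        return cast(F, wrapped)

    @log_function
    def add(a: int, b: int) -> int:
        return a + b

    assert add(2, 3) == 5  # mypy still sees add as (int, int) -> int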
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e6c3cf585b..6b5a6ded8b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -456,7 +456,11 @@ class ProfileHandler: continue new_name = profile.get("displayname") + if not isinstance(new_name, str): + new_name = None new_avatar = profile.get("avatar_url") + if not isinstance(new_avatar, str): + new_avatar = None # We always hit update to update the last_check timestamp await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 08895e72ee..4a01b902c2 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -16,6 +16,7 @@ import logging from functools import wraps from inspect import getcallargs +from typing import Callable, TypeVar, cast _TIME_FUNC_ID = 0 @@ -41,7 +42,10 @@ def _log_debug_as_f(f, msg, msg_args): logger.handle(record) -def log_function(f): +F = TypeVar("F", bound=Callable) + + +def log_function(f: F) -> F: """Function decorator that logs every call to that function.""" func_name = f.__name__ @@ -69,4 +73,4 @@ def log_function(f): return f(*args, **kwargs) wrapped.__name__ = func_name - return wrapped + return cast(F, wrapped) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5cf2e12575..98a0239759 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -26,6 +26,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -519,7 +520,7 @@ class StateResolutionHandler: self, room_id: str, room_version: str, - state_groups_ids: Dict[int, StateMap[str]], + state_groups_ids: Mapping[int, StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> _StateCacheEntry: @@ -703,7 +704,7 @@ class StateResolutionHandler: def _make_state_cache_entry( - new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] + new_state: StateMap[str], state_groups_ids: Mapping[int, StateMap[str]] ) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index ba7075caa5..dd8e27e226 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def update_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str + self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] ) -> int: return await self.db_pool.simple_update( table="remote_profile_cache", -- cgit 1.4.1 From 0e16b418f6835c7a2a9aae4637b0a9f2ca47f518 Mon Sep 17 00:00:00 2001 From: Rafael Gonçalves <8217676+RafaelGoncalves8@users.noreply.github.com> Date: Thu, 28 Oct 2021 14:54:38 -0300 Subject: Add knock information in admin exported data (#11171) Signed-off-by: Rafael Goncalves --- changelog.d/11171.misc | 1 + synapse/app/admin_cmd.py | 14 ++++++++++++++ synapse/handlers/admin.py | 22 ++++++++++++++++++++++ tests/handlers/test_admin.py | 35 +++++++++++++++++++++++++++++++++-- tests/rest/client/utils.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11171.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11171.misc b/changelog.d/11171.misc new file mode 100644 index 0000000000..b6a41a96da 
--- /dev/null
+++ b/changelog.d/11171.misc
@@ -0,0 +1 @@
+Add knock information in admin export. Contributed by Rafael Gonçalves.
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 2fc848596d..ad20b1d6aa 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -145,6 +145,20 @@ class FileExfiltrationWriter(ExfiltrationWriter):
         for event in state.values():
             print(json.dumps(event), file=f)

+    def write_knock(self, room_id, event, state):
+        self.write_events(room_id, [event])
+
+        # We write the knock state somewhere else as they aren't full events
+        # and are only a subset of the state at the event.
+        room_directory = os.path.join(self.base_directory, "rooms", room_id)
+        os.makedirs(room_directory, exist_ok=True)
+
+        knock_state = os.path.join(room_directory, "knock_state")
+
+        with open(knock_state, "a") as f:
+            for event in state.values():
+                print(json.dumps(event), file=f)
+
     def finished(self):
         return self.base_directory

diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a53cd62d3c..be3203ac80 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -90,6 +90,7 @@ class AdminHandler:
                 Membership.LEAVE,
                 Membership.BAN,
                 Membership.INVITE,
+                Membership.KNOCK,
             ),
         )

@@ -122,6 +123,13 @@ class AdminHandler:
                     invited_state = invite.unsigned["invite_room_state"]
                     writer.write_invite(room_id, invite, invited_state)

+            if room.membership == Membership.KNOCK:
+                event_id = room.event_id
+                knock = await self.store.get_event(event_id, allow_none=True)
+                if knock:
+                    knock_state = knock.unsigned["knock_room_state"]
+                    writer.write_knock(room_id, knock, knock_state)
+
                 continue

             # We only want to bother fetching events up to the last time they
@@ -238,6 +246,20 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()

+    @abc.abstractmethod
+    def write_knock(
+        self, room_id: str, event: EventBase, state: StateMap[dict]
+    ) -> None:
+        """Write a knock for the room, with associated knock state.
+
+        Args:
+            room_id: The room ID the knock is for.
+            event: The knock event.
+            state: A subset of the state at the knock, with a subset of the
+                event keys (type, state_key, content and sender).
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
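For illustration, a minimal in-memory writer satisfying the new `write_knock`
method might look like the following sketch. Only the knock-related methods
are shown, and the class and field names are hypothetical; the real
`FileExfiltrationWriter` above writes the knock state to disk instead.

    from typing import Any, Dict, List, Tuple

    class InMemoryExfiltrationWriter:
        """Hypothetical writer that collects exported data in memory."""

        def __init__(self) -> None:
            self.events: List[Any] = []
            self.knocks: List[Tuple[str, Any, Dict[Any, Any]]] = []

        def write_events(self, room_id: str, events: List[Any]) -> None:
            self.events.extend(events)

        def write_knock(self, room_id: str, event: Any, state: Dict[Any, Any]) -> None:
            # A knock comes with only a stripped-down subset of room state,
            # so keep that partial state alongside the knock event itself.
            self.write_events(room_id, [event])
            self.knocks.append((room_id, event, state))

        def finished(self) -> List[Any]:
            return self.events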
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 59de1142b1..abf2a0fe0d 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -17,8 +17,9 @@ from unittest.mock import Mock

 import synapse.rest.admin
 import synapse.storage
-from synapse.api.constants import EventTypes
-from synapse.rest.client import login, room
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.room_versions import RoomVersions
+from synapse.rest.client import knock, login, room

 from tests import unittest

@@ -28,6 +29,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
         room.register_servlets,
+        knock.register_servlets,
     ]

     def prepare(self, reactor, clock, hs):
@@ -201,3 +203,32 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(args[0], room_id)
         self.assertEqual(args[1].content["membership"], "invite")
         self.assertTrue(args[2])  # Assert there is at least one bit of state
+
+    def test_knock(self):
+        """Tests that knocks get handled correctly."""
+        # create a knockable v7 room
+        room_id = self.helper.create_room_as(
+            self.user1, room_version=RoomVersions.V7.identifier, tok=self.token1
+        )
+        self.helper.send_state(
+            room_id,
+            EventTypes.JoinRules,
+            {"join_rule": JoinRules.KNOCK},
+            tok=self.token1,
+        )
+
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.knock(room_id, self.user2, tok=self.token2)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_not_called()
+        writer.write_state.assert_not_called()
+        writer.write_knock.assert_called_once()
+
+        args = writer.write_knock.call_args[0]
+        self.assertEqual(args[0], room_id)
+        self.assertEqual(args[1].content["membership"], "knock")
+        self.assertTrue(args[2])  # Assert there is at least one bit of state
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 71fa87ce92..ec0979850b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -120,6 +120,35 @@ class RestHelper:
             expect_code=expect_code,
         )

+    def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
+        temp_id = self.auth_user_id
+        self.auth_user_id = user
+        path = "/knock/%s" % room
+        if tok:
+            path = path + "?access_token=%s" % tok
+
+        data = {}
+        if reason:
+            data["reason"] = reason
+
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            path,
+            json.dumps(data).encode("utf8"),
+        )
+
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
+        )
+
+        self.auth_user_id = temp_id
+
     def leave(self, room=None, user=None, expect_code=200, tok=None):
         self.change_membership(
             room=room,
--
cgit 1.4.1


From c9c3aea9b189cb606d7ec2905dad2c87acc039ef Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Tue, 2 Nov 2021 10:39:02 +0000
Subject: Fix providing a `RoomStreamToken` instance to `_notify_app_services_ephemeral` (#11137)

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
---
 changelog.d/11137.misc                     |  1 +
 synapse/handlers/appservice.py             | 22 +++++++++++++----
 synapse/notifier.py                        | 38 +++++++-----------------
 synapse/storage/databases/main/devices.py  |  4 ++--
 synapse/storage/databases/main/presence.py |  2 +-
 5 files changed, 30 insertions(+), 37 deletions(-)
 create mode
100644 changelog.d/11137.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11137.misc b/changelog.d/11137.misc new file mode 100644 index 0000000000..f0d6476f48 --- /dev/null +++ b/changelog.d/11137.misc @@ -0,0 +1 @@ +Remove and document unnecessary `RoomStreamToken` checks in application service ephemeral event code. \ No newline at end of file diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 36c206dae6..67f8ffcaff 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -182,7 +182,7 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, stream_key: str, - new_token: Optional[int], + new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, ) -> None: """ @@ -203,7 +203,7 @@ class ApplicationServicesHandler: Appservices will only receive ephemeral events that fall within their registered user and room namespaces. - new_token: The latest stream token. + new_token: The stream token of the event. users: The users that should be informed of the new event, if any. """ if not self.notify_appservices: @@ -212,6 +212,19 @@ class ApplicationServicesHandler: if stream_key not in ("typing_key", "receipt_key", "presence_key"): return + # Assert that new_token is an integer (and not a RoomStreamToken). + # All of the supported streams that this function handles use an + # integer to track progress (rather than a RoomStreamToken - a + # vector clock implementation) as they don't support multiple + # stream writers. + # + # As a result, we simply assert that new_token is an integer. + # If we do end up needing to pass a RoomStreamToken down here + # in the future, using RoomStreamToken.stream (the minimum stream + # position) to convert to an ascending integer value should work. + # Additional context: https://github.com/matrix-org/synapse/pull/11137 + assert isinstance(new_token, int) + services = [ service for service in self.store.get_app_services() @@ -231,14 +244,13 @@ class ApplicationServicesHandler: self, services: List[ApplicationService], stream_key: str, - new_token: Optional[int], + new_token: int, users: Collection[Union[str, UserID]], ) -> None: logger.debug("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: - # Only handle typing if we have the latest token - if stream_key == "typing_key" and new_token is not None: + if stream_key == "typing_key": # Note that we don't persist the token (via set_type_stream_id_for_appservice) # for typing_key due to performance reasons and due to their highly # ephemeral nature. diff --git a/synapse/notifier.py b/synapse/notifier.py index 1882fffd2a..60e5409895 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -383,29 +383,6 @@ class Notifier: except Exception: logger.exception("Error notifying application services of event") - def _notify_app_services_ephemeral( - self, - stream_key: str, - new_token: Union[int, RoomStreamToken], - users: Optional[Collection[Union[str, UserID]]] = None, - ) -> None: - """Notify application services of ephemeral event activity. - - Args: - stream_key: The stream the event came from. - new_token: The value of the new stream token. - users: The users that should be informed of the new event, if any. 
- """ - try: - stream_token = None - if isinstance(new_token, int): - stream_token = new_token - self.appservice_handler.notify_interested_services_ephemeral( - stream_key, stream_token, users or [] - ) - except Exception: - logger.exception("Error notifying application services of event") - def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: self._pusher_pool.on_new_notifications(max_room_stream_token) @@ -467,12 +444,15 @@ class Notifier: self.notify_replication() - # Notify appservices - self._notify_app_services_ephemeral( - stream_key, - new_token, - users, - ) + # Notify appservices. + try: + self.appservice_handler.notify_interested_services_ephemeral( + stream_key, + new_token, + users, + ) + except Exception: + logger.exception("Error notifying application services of event") def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b15cd030e0..9ccc66e589 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore): user_ids: the users who were signed Returns: - THe new stream ID. + The new stream ID. """ async with self._device_list_id_gen.get_next() as stream_id: @@ -1322,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): async def add_device_change_to_streams( self, user_id: str, device_ids: Collection[str], hosts: List[str] - ): + ) -> int: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 12cf6995eb..cc0eebdb46 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): prefilled_cache=presence_cache_prefill, ) - async def update_presence(self, presence_states): + async def update_presence(self, presence_states) -> Tuple[int, int]: assert self._can_persist_presence stream_ordering_manager = self._presence_id_gen.get_next_mult( -- cgit 1.4.1 From c01bc5f43d1c7d0a25f397b542ced57894395519 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 2 Nov 2021 09:55:52 -0400 Subject: Add remaining type hints to `synapse.events`. (#11098) --- changelog.d/11098.misc | 1 + mypy.ini | 8 +- synapse/events/__init__.py | 227 +++++++++++++++++---------- synapse/events/validator.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 14 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_batch.py | 2 +- synapse/handlers/room_member.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 10 +- synapse/rest/client/room_batch.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/events.py | 7 +- synapse/storage/databases/main/roommember.py | 8 +- 15 files changed, 185 insertions(+), 110 deletions(-) create mode 100644 changelog.d/11098.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11098.misc b/changelog.d/11098.misc new file mode 100644 index 0000000000..1e337bee54 --- /dev/null +++ b/changelog.d/11098.misc @@ -0,0 +1 @@ +Add type hints to `synapse.events`. 
diff --git a/mypy.ini b/mypy.ini index 119a7d8c91..600402a5d3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -22,13 +22,7 @@ files = synapse/config, synapse/crypto, synapse/event_auth.py, - synapse/events/builder.py, - synapse/events/presence_router.py, - synapse/events/snapshot.py, - synapse/events/spamcheck.py, - synapse/events/third_party_rules.py, - synapse/events/utils.py, - synapse/events/validator.py, + synapse/events, synapse/federation, synapse/groups, synapse/handlers, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 157669ea88..38f3cf4d33 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -16,8 +16,23 @@ import abc import os -from typing import Dict, Optional, Tuple, Type - +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import Literal from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions @@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze from synapse.util.stringutils import strtobool +if TYPE_CHECKING: + from synapse.events.builder import EventBuilder + # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a # dict to frozen_dicts is expensive. @@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) -class DictProperty: +T = TypeVar("T") + + +# DictProperty (and DefaultDictProperty) require the classes they're used with to +# have a _dict property to pull properties from. +# +# TODO _DictPropertyInstance should not include EventBuilder but due to +# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and +# DefaultDictProperty get applied to EventBuilder when it is in a Union with +# EventBase. This is the least invasive hack to get mypy to comply. +# +# Note that DictProperty/DefaultDictProperty cannot actually be used with +# EventBuilder as it lacks a _dict property. +_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"] + + +class DictProperty(Generic[T]): """An object property which delegates to the `_dict` within its parent object.""" __slots__ = ["key"] @@ -45,12 +79,33 @@ class DictProperty: def __init__(self, key: str): self.key = key - def __get__(self, instance, owner=None): + @overload + def __get__( + self, + instance: Literal[None], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> "DictProperty": + ... + + @overload + def __get__( + self, + instance: _DictPropertyInstance, + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> T: + ... 
+ + def __get__( + self, + instance: Optional[_DictPropertyInstance], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> Union[T, "DictProperty"]: # if the property is accessed as a class property rather than an instance # property, return the property itself rather than the value if instance is None: return self try: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) return instance._dict[self.key] except KeyError as e1: # We want this to look like a regular attribute error (mostly so that @@ -65,10 +120,12 @@ class DictProperty: "'%s' has no '%s' property" % (type(instance), self.key) ) from e1.__context__ - def __set__(self, instance, v): + def __set__(self, instance: _DictPropertyInstance, v: T) -> None: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) instance._dict[self.key] = v - def __delete__(self, instance): + def __delete__(self, instance: _DictPropertyInstance) -> None: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) try: del instance._dict[self.key] except KeyError as e1: @@ -77,7 +134,7 @@ class DictProperty: ) from e1.__context__ -class DefaultDictProperty(DictProperty): +class DefaultDictProperty(DictProperty, Generic[T]): """An extension of DictProperty which provides a default if the property is not present in the parent's _dict. @@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty): __slots__ = ["default"] - def __init__(self, key, default): + def __init__(self, key: str, default: T): super().__init__(key) self.default = default - def __get__(self, instance, owner=None): + @overload + def __get__( + self, + instance: Literal[None], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> "DefaultDictProperty": + ... + + @overload + def __get__( + self, + instance: _DictPropertyInstance, + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> T: + ... + + def __get__( + self, + instance: Optional[_DictPropertyInstance], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> Union[T, "DefaultDictProperty"]: if instance is None: return self + assert isinstance(instance, (EventBase, _EventInternalMetadata)) return instance._dict.get(self.key, self.default) @@ -111,22 +189,22 @@ class _EventInternalMetadata: # in the DAG) self.outlier = False - out_of_band_membership: bool = DictProperty("out_of_band_membership") - send_on_behalf_of: str = DictProperty("send_on_behalf_of") - recheck_redaction: bool = DictProperty("recheck_redaction") - soft_failed: bool = DictProperty("soft_failed") - proactively_send: bool = DictProperty("proactively_send") - redacted: bool = DictProperty("redacted") - txn_id: str = DictProperty("txn_id") - token_id: int = DictProperty("token_id") - historical: bool = DictProperty("historical") + out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership") + send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of") + recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction") + soft_failed: DictProperty[bool] = DictProperty("soft_failed") + proactively_send: DictProperty[bool] = DictProperty("proactively_send") + redacted: DictProperty[bool] = DictProperty("redacted") + txn_id: DictProperty[str] = DictProperty("txn_id") + token_id: DictProperty[int] = DictProperty("token_id") + historical: DictProperty[bool] = DictProperty("historical") # XXX: These are set by StreamWorkerStore._set_before_and_after. 
# I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before: RoomStreamToken = DictProperty("before") - after: RoomStreamToken = DictProperty("after") - order: Tuple[int, int] = DictProperty("order") + before: DictProperty[RoomStreamToken] = DictProperty("before") + after: DictProperty[RoomStreamToken] = DictProperty("after") + order: DictProperty[Tuple[int, int]] = DictProperty("order") def get_dict(self) -> JsonDict: return dict(self._dict) @@ -162,9 +240,6 @@ class _EventInternalMetadata: If the sender of the redaction event is allowed to redact any event due to auth rules, then this will always return false. - - Returns: - bool """ return self._dict.get("recheck_redaction", False) @@ -176,32 +251,23 @@ class _EventInternalMetadata: sent to clients. 2. They should not be added to the forward extremities (and therefore not to current state). - - Returns: - bool """ return self._dict.get("soft_failed", False) - def should_proactively_send(self): + def should_proactively_send(self) -> bool: """Whether the event, if ours, should be sent to other clients and servers. This is used for sending dummy events internally. Servers and clients can still explicitly fetch the event. - - Returns: - bool """ return self._dict.get("proactively_send", True) - def is_redacted(self): + def is_redacted(self) -> bool: """Whether the event has been redacted. This is used for efficiently checking whether an event has been marked as redacted without needing to make another database call. - - Returns: - bool """ return self._dict.get("redacted", False) @@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta): self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) - auth_events = DictProperty("auth_events") - depth = DictProperty("depth") - content = DictProperty("content") - hashes = DictProperty("hashes") - origin = DictProperty("origin") - origin_server_ts = DictProperty("origin_server_ts") - prev_events = DictProperty("prev_events") - redacts = DefaultDictProperty("redacts", None) - room_id = DictProperty("room_id") - sender = DictProperty("sender") - state_key = DictProperty("state_key") - type = DictProperty("type") - user_id = DictProperty("sender") + depth: DictProperty[int] = DictProperty("depth") + content: DictProperty[JsonDict] = DictProperty("content") + hashes: DictProperty[Dict[str, str]] = DictProperty("hashes") + origin: DictProperty[str] = DictProperty("origin") + origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") + redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None) + room_id: DictProperty[str] = DictProperty("room_id") + sender: DictProperty[str] = DictProperty("sender") + # TODO state_key should be Optional[str], this is generally asserted in Synapse + # by calling is_state() first (which ensures this), but it is hard (not possible?) + # to properly annotate that calling is_state() asserts that state_key exists + # and is non-None. 
+ state_key: DictProperty[str] = DictProperty("state_key") + type: DictProperty[str] = DictProperty("type") + user_id: DictProperty[str] = DictProperty("sender") @property def event_id(self) -> str: raise NotImplementedError() @property - def membership(self): + def membership(self) -> str: return self.content["membership"] - def is_state(self): + def is_state(self) -> bool: return hasattr(self, "state_key") and self.state_key is not None def get_dict(self) -> JsonDict: @@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta): return d - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Any: return self._dict.get(key, default) - def get_internal_metadata_dict(self): + def get_internal_metadata_dict(self) -> JsonDict: return self.internal_metadata.get_dict() - def get_pdu_json(self, time_now=None) -> JsonDict: + def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict: pdu_json = self.get_dict() if time_now is not None and "age_ts" in pdu_json["unsigned"]: @@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta): return template_json - def __set__(self, instance, value): - raise AttributeError("Unrecognized attribute %s" % (instance,)) - - def __getitem__(self, field): + def __getitem__(self, field: str) -> Optional[Any]: return self._dict[field] - def __contains__(self, field): + def __contains__(self, field: str) -> bool: return field in self._dict - def items(self): + def items(self) -> List[Tuple[str, Optional[Any]]]: return list(self._dict.items()) - def keys(self): + def keys(self) -> Iterable[str]: return self._dict.keys() - def prev_event_ids(self): + def prev_event_ids(self) -> Sequence[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's prev_events + The list of event IDs of this event's prev_events """ - return [e for e, _ in self.prev_events] + return [e for e, _ in self._dict["prev_events"]] - def auth_event_ids(self): + def auth_event_ids(self) -> Sequence[str]: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's auth_events + The list of event IDs of this event's auth_events """ - return [e for e, _ in self.auth_events] + return [e for e, _ in self._dict["auth_events"]] - def freeze(self): + def freeze(self) -> None: """'Freeze' the event dict, so it cannot be modified by accident""" # this will be a no-op if the event dict is already frozen. self._dict = freeze(self._dict) - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def __repr__(self): + def __repr__(self) -> str: rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" return ( @@ -443,7 +508,7 @@ class FrozenEventV2(EventBase): else: frozen_dict = event_dict - self._event_id = None + self._event_id: Optional[str] = None super().__init__( frozen_dict, @@ -455,7 +520,7 @@ class FrozenEventV2(EventBase): ) @property - def event_id(self): + def event_id(self) -> str: # We have to import this here as otherwise we get an import loop which # is hard to break. 
from synapse.crypto.event_signing import compute_event_reference_hash @@ -465,23 +530,23 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id - def prev_event_ids(self): + def prev_event_ids(self) -> Sequence[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's prev_events + The list of event IDs of this event's prev_events """ - return self.prev_events + return self._dict["prev_events"] - def auth_event_ids(self): + def auth_event_ids(self) -> Sequence[str]: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's auth_events + The list of event IDs of this event's auth_events """ - return self.auth_events + return self._dict["auth_events"] class FrozenEventV3(FrozenEventV2): @@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2): format_version = EventFormatVersions.V3 # All events of this type are V3 @property - def event_id(self): + def event_id(self) -> str: # We have to import this here as otherwise we get an import loop which # is hard to break. from synapse.crypto.event_signing import compute_event_reference_hash @@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2): return self._event_id -def _event_type_from_format_version(format_version: int) -> Type[EventBase]: +def _event_type_from_format_version( + format_version: int, +) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]: """Returns the python type to use to construct an Event object for the given event format version. Args: - format_version (int): The event format version + format_version: The event format version Returns: type: A type that can be initialized as per the initializer of diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 4d459c17f1..cf86934968 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -55,7 +55,7 @@ class EventValidator: ] for k in required: - if not hasattr(event, k): + if k not in event: raise SynapseError(400, "Event does not have key %s" % (k,)) # Check that the following keys have string values diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e617db4c0d..1a1cd93b1a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1643,7 +1643,7 @@ class FederationEventHandler: event: the event whose auth_events we want Returns: - all of the events in `event.auth_events`, after deduplication + all of the events listed in `event.auth_events_ids`, after deduplication Raises: AuthError if we were unable to fetch the auth_events for any reason. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4a0fccfcc6..b7bc187169 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1318,6 +1318,8 @@ class EventCreationHandler: # user is actually admin or not). 
is_admin_redaction = False if event.type == EventTypes.Redaction: + assert event.redacts is not None + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -1413,6 +1415,8 @@ class EventCreationHandler: ) if event.type == EventTypes.Redaction: + assert event.redacts is not None + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -1500,11 +1504,13 @@ class EventCreationHandler: next_batch_id = event.content.get( EventContentFields.MSC2716_NEXT_BATCH_ID ) - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( - event.room_id, next_batch_id + conflicting_insertion_event_id = None + if next_batch_id: + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + event.room_id, next_batch_id + ) ) - ) if conflicting_insertion_event_id is not None: # The current insertion event that we're processing is invalid # because an insertion event already exists in the room with the diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 99e9b37344..969eb3b9b0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -525,7 +525,7 @@ class RoomCreationHandler: ): await self.room_member_handler.update_membership( requester, - UserID.from_string(old_event["state_key"]), + UserID.from_string(old_event.state_key), new_room_id, "ban", ratelimit=False, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 2f5a3e4d19..0723286383 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -355,7 +355,7 @@ class RoomBatchHandler: for (event, context) in reversed(events_to_persist): await self.event_creation_handler.handle_new_client_event( await self.create_requester_for_user_id_from_app_service( - event["sender"], app_service_requester.app_service + event.sender, app_service_requester.app_service ), event=event, context=context, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 74e6c7eca6..08244b690d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler): # # the prev_events consist solely of the previous membership event. prev_event_ids = [previous_membership_event.event_id] - auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids + auth_event_ids = ( + list(previous_membership_event.auth_event_ids()) + prev_event_ids + ) event, context = await self.event_creation_handler.create_event( requester, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 0622a37ae8..009d8e77b0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -232,6 +232,8 @@ class BulkPushRuleEvaluator: # that user, as they might not be already joined. 
if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) + if not isinstance(display_name, str): + display_name = None if count_as_unread: # Add an element for the current user if the event needs to be marked as @@ -268,7 +270,7 @@ def _condition_checker( evaluator: PushRuleEvaluatorForEvent, conditions: List[dict], uid: str, - display_name: str, + display_name: Optional[str], cache: Dict[str, bool], ) -> bool: for cond in conditions: diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 7a8dc63976..7f68092ec5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -18,7 +18,7 @@ import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from synapse.events import EventBase -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache @@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent: self._value_cache = _flatten_dict(event) def matches( - self, condition: Dict[str, Any], user_id: str, display_name: str + self, condition: Dict[str, Any], user_id: str, display_name: Optional[str] ) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) @@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent: return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name: str) -> bool: + def _contains_display_name(self, display_name: Optional[str]) -> bool: if not display_name: return False @@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, dict], + d: Union[EventBase, JsonDict], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -233,7 +233,7 @@ def _flatten_dict( for key, value in d.items(): if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif hasattr(value, "items"): + elif isinstance(value, dict): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 99f8156ad0..ab9a743bba 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -191,7 +191,7 @@ class RoomBatchSendEventRestServlet(RestServlet): depth=inherited_depth, ) - batch_id_to_connect_to = base_insertion_event["content"][ + batch_id_to_connect_to = base_insertion_event.content[ EventContentFields.MSC2716_NEXT_BATCH_ID ] diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 98a0239759..1605411b00 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -247,7 +247,7 @@ class StateHandler: return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: List[str] + self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: """Get the hosts that were in a room at the given event ids diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 8d9086ecf0..596275c23c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -24,6 +24,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, ) @@ -494,7 +495,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: 
Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, List[str]], + event_to_auth_chain: Dict[str, Sequence[str]], ) -> None: """Calculate the chain cover index for the given events. @@ -786,7 +787,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, List[str]], + event_to_auth_chain: Dict[str, Sequence[str]], events_to_calc_chain_id_for: Set[str], chain_map: Dict[str, Tuple[int, int]], ) -> Dict[str, Tuple[int, int]]: @@ -1794,7 +1795,7 @@ class PersistEventsStore: ) # Insert an edge for every prev_event connection - for prev_event_id in event.prev_events: + for prev_event_id in event.prev_event_ids(): self.db_pool.simple_insert_txn( txn, table="insertion_event_edges", diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4b288bb2e7..033a9831d6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_joined_users_from_context( self, event: EventBase, context: EventContext - ): + ) -> Dict[str, ProfileInfo]: state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): event.room_id, state_group, current_state_ids, event=event, context=context ) - async def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state( + self, room_id, state_entry + ) -> Dict[str, ProfileInfo]: state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cache_context, event=None, context=None, - ): + ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. -- cgit 1.4.1 From af54167516c7211937efa5b800853f3088ef5178 Mon Sep 17 00:00:00 2001 From: Nick Barrett Date: Wed, 3 Nov 2021 14:25:47 +0000 Subject: Enable passing typing stream writers as a list. (#11237) This makes the typing stream writer config match the other stream writers that only currently support a single worker. --- changelog.d/11237.misc | 1 + synapse/config/workers.py | 18 +++++++++++++++--- synapse/federation/federation_server.py | 4 ---- synapse/handlers/typing.py | 6 +++--- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/_base.py | 3 +-- synapse/rest/client/room.py | 2 +- synapse/server.py | 4 ++-- 8 files changed, 24 insertions(+), 16 deletions(-) create mode 100644 changelog.d/11237.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11237.misc b/changelog.d/11237.misc new file mode 100644 index 0000000000..b90efc6535 --- /dev/null +++ b/changelog.d/11237.misc @@ -0,0 +1 @@ +Allow `stream_writers.typing` config to be a list of one worker. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 462630201d..4507992031 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -63,7 +63,8 @@ class WriterLocations: Attributes: events: The instances that write to the event and backfill streams. - typing: The instance that writes to the typing stream. 
+ typing: The instances that write to the typing stream. Currently + can only be a single instance. to_device: The instances that write to the to_device stream. Currently can only be a single instance. account_data: The instances that write to the account data streams. Currently @@ -75,9 +76,15 @@ class WriterLocations: """ events = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter + default=["master"], + type=List[str], + converter=_instance_to_list_converter, + ) + typing = attr.ib( + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) - typing = attr.ib(default="master", type=str) to_device = attr.ib( default=["master"], type=List[str], @@ -217,6 +224,11 @@ class WorkerConfig(Config): % (instance, stream) ) + if len(self.writers.typing) != 1: + raise ConfigError( + "Must only specify one instance to handle `typing` messages." + ) + if len(self.writers.to_device) != 1: raise ConfigError( "Must only specify one instance to handle `to_device` messages." diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 32a75993d9..42e3acecb4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1232,10 +1232,6 @@ class FederationHandlerRegistry: self.query_handlers[query_type] = handler - def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None: - """Register that the EDU handler is on a different instance than master.""" - self._edu_type_to_instance[edu_type] = [instance_name] - def register_instances_for_edu( self, edu_type: str, instance_names: List[str] ) -> None: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c411d69924..22c6174821 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -62,8 +62,8 @@ class FollowerTypingHandler: if hs.should_send_federation(): self.federation = hs.get_federation_sender() - if hs.config.worker.writers.typing != hs.get_instance_name(): - hs.get_federation_registry().register_instance_for_edu( + if hs.get_instance_name() not in hs.config.worker.writers.typing: + hs.get_federation_registry().register_instances_for_edu( "m.typing", hs.config.worker.writers.typing, ) @@ -205,7 +205,7 @@ class TypingWriterHandler(FollowerTypingHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - assert hs.config.worker.writers.typing == hs.get_instance_name() + assert hs.get_instance_name() in hs.config.worker.writers.typing self.auth = hs.get_auth() self.notifier = hs.get_notifier() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 06fd06fdf3..21293038ef 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -138,7 +138,7 @@ class ReplicationCommandHandler: if isinstance(stream, TypingStream): # Only add TypingStream as a source on the instance in charge of # typing. 
- if hs.config.worker.writers.typing == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.typing: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c8b188ae4e..743a01da08 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -328,8 +328,7 @@ class TypingStream(Stream): ROW_TYPE = TypingStreamRow def __init__(self, hs: "HomeServer"): - writer_instance = hs.config.worker.writers.typing - if writer_instance == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.typing: # On the writer, query the typing handler typing_writer_handler = hs.get_typing_writer_handler() update_function: Callable[ diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index ed95189b6d..6a876cfa2f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -914,7 +914,7 @@ class RoomTypingRestServlet(RestServlet): # If we're not on the typing writer instance we should scream if we get # requests. self._is_typing_writer = ( - hs.config.worker.writers.typing == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.typing ) async def on_PUT( diff --git a/synapse/server.py b/synapse/server.py index 0fbf36ba99..013a7bacaa 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -463,7 +463,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_typing_writer_handler(self) -> TypingWriterHandler: - if self.config.worker.writers.typing == self.get_instance_name(): + if self.get_instance_name() in self.config.worker.writers.typing: return TypingWriterHandler(self) else: raise Exception("Workers cannot write typing") @@ -474,7 +474,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_typing_handler(self) -> FollowerTypingHandler: - if self.config.worker.writers.typing == self.get_instance_name(): + if self.get_instance_name() in self.config.worker.writers.typing: # Use get_typing_writer_handler to ensure that we use the same # cached version. return self.get_typing_writer_handler() -- cgit 1.4.1 From a271e233e9f846193c22b6d74f33ae7d7f2c1167 Mon Sep 17 00:00:00 2001 From: Nick Barrett Date: Wed, 3 Nov 2021 16:51:00 +0000 Subject: Add a linearizer on (appservice, stream) when handling ephemeral events. (#11207) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/11207.bugfix | 1 + synapse/handlers/appservice.py | 69 +++++++++++++++++++++++++++++---------- tests/handlers/test_appservice.py | 51 +++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 18 deletions(-) create mode 100644 changelog.d/11207.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11207.bugfix b/changelog.d/11207.bugfix new file mode 100644 index 0000000000..7e98d565a1 --- /dev/null +++ b/changelog.d/11207.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could result in serialization errors and potentially duplicate transaction data when sending ephemeral events to application services. Contributed by @Fizzadar at Beeper. 
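The shape of the fix is easiest to see in isolation: all reads and updates of an appservice's stored stream position are serialised per (service ID, stream key), so two concurrent notifications can no longer both read the old token and then race to push overlapping transactions. The following self-contained sketch approximates the idea with one plain asyncio.Lock per key in place of Synapse's Linearizer; the in-memory STORE dict and the handle_stream function are illustrative stand-ins, not Synapse APIs.

import asyncio
from collections import defaultdict
from typing import Dict, Tuple

# In-memory stand-in for get/set_type_stream_id_for_appservice.
STORE: Dict[Tuple[str, str], int] = {}

# One lock per (service_id, stream_key), mirroring the linearizer key.
_locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)

async def handle_stream(service_id: str, stream_key: str, new_token: int) -> None:
    # All reads and writes of this pair's stream position happen under the
    # same lock, so the read of the stored token and the write of the new
    # one cannot interleave with another call for the same key.
    async with _locks[(service_id, stream_key)]:
        stored = STORE.get((service_id, stream_key), 0)
        if new_token <= stored:
            return  # out-of-order or duplicate notification; drop it
        # ... fetch events between `stored` and `new_token` and submit
        # them to the appservice scheduler ...
        STORE[(service_id, stream_key)] = new_token

async def main() -> None:
    # Two concurrent notifications for the same key run strictly in turn;
    # whichever holds the stale token is ignored.
    await asyncio.gather(
        handle_stream("as1", "receipt_key", 580),
        handle_stream("as1", "receipt_key", 579),
    )
    print(STORE)  # {('as1', 'receipt_key'): 580}

asyncio.run(main())

The token-comparison guard inside the lock is the same check the diff below adds to _handle_receipts and _handle_presence.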
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 67f8ffcaff..ddc9105ee9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -34,6 +34,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -58,6 +59,10 @@ class ApplicationServicesHandler: self.current_max = 0 self.is_processing = False + self._ephemeral_events_linearizer = Linearizer( + name="appservice_ephemeral_events" + ) + def notify_interested_services(self, max_token: RoomStreamToken) -> None: """Notifies (pushes) all application services interested in this event. @@ -260,26 +265,37 @@ class ApplicationServicesHandler: events = await self._handle_typing(service, new_token) if events: self.scheduler.submit_ephemeral_events_for_as(service, events) + continue - elif stream_key == "receipt_key": - events = await self._handle_receipts(service) - if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) - - # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( - service, "read_receipt", new_token + # Since we read/update the stream position for this AS/stream + with ( + await self._ephemeral_events_linearizer.queue( + (service.id, stream_key) ) + ): + if stream_key == "receipt_key": + events = await self._handle_receipts(service, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as( + service, events + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_type_stream_id_for_appservice( + service, "read_receipt", new_token + ) - elif stream_key == "presence_key": - events = await self._handle_presence(service, users) - if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) + elif stream_key == "presence_key": + events = await self._handle_presence(service, users, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as( + service, events + ) - # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( - service, "presence", new_token - ) + # Persist the latest handled stream token for this appservice + await self.store.set_type_stream_id_for_appservice( + service, "presence", new_token + ) async def _handle_typing( self, service: ApplicationService, new_token: int @@ -316,7 +332,9 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: + async def _handle_receipts( + self, service: ApplicationService, new_token: Optional[int] + ) -> List[JsonDict]: """ Return the latest read receipts that the given application service should receive. 
@@ -335,6 +353,12 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) + if new_token is not None and new_token <= from_key: + logger.debug( + "Rejecting token lower than or equal to stored: %s" % (new_token,) + ) + return [] + receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( service=service, from_key=from_key @@ -342,7 +366,10 @@ class ApplicationServicesHandler: return receipts async def _handle_presence( - self, service: ApplicationService, users: Collection[Union[str, UserID]] + self, + service: ApplicationService, + users: Collection[Union[str, UserID]], + new_token: Optional[int], ) -> List[JsonDict]: """ Return the latest presence updates that the given application service should receive. @@ -365,6 +392,12 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" ) + if new_token is not None and new_token <= from_key: + logger.debug( + "Rejecting token lower than or equal to stored: %s" % (new_token,) + ) + return [] + for user in users: if isinstance(user, str): user = UserID.from_string(user) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 43998020b2..1f6a924452 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -40,6 +40,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) + self.event_source = hs.get_event_sources() def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) @@ -252,6 +253,56 @@ class AppServiceHandlerTestCase(unittest.TestCase): }, ) + def test_notify_interested_services_ephemeral(self): + """ + Test sending ephemeral events to the appservice handler are scheduled + to be pushed out to interested appservices, and that the stream ID is + updated accordingly. + """ + interested_service = self._mkservice(is_interested=True) + services = [interested_service] + + self.mock_store.get_app_services.return_value = services + self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( + 579 + ) + + event = Mock(event_id="event_1") + self.event_source.sources.receipt.get_new_events_as.return_value = ( + make_awaitable(([event], None)) + ) + + self.handler.notify_interested_services_ephemeral("receipt_key", 580) + self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( + interested_service, [event] + ) + self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with( + interested_service, + "read_receipt", + 580, + ) + + def test_notify_interested_services_ephemeral_out_of_order(self): + """ + Test sending out of order ephemeral events to the appservice handler + are ignored. 
+ """ + interested_service = self._mkservice(is_interested=True) + services = [interested_service] + + self.mock_store.get_app_services.return_value = services + self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( + 580 + ) + + event = Mock(event_id="event_1") + self.event_source.sources.receipt.get_new_events_as.return_value = ( + make_awaitable(([event], None)) + ) + + self.handler.notify_interested_services_ephemeral("receipt_key", 579) + self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) -- cgit 1.4.1 From 499c44d69685c1c1e347ff252ad08f5dfe089a83 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 4 Nov 2021 17:10:11 +0000 Subject: Make minor correction to type of auth_checkers callbacks (#11253) --- changelog.d/11253.misc | 1 + docs/modules/password_auth_provider_callbacks.md | 2 +- synapse/handlers/auth.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11253.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11253.misc b/changelog.d/11253.misc new file mode 100644 index 0000000000..71c55a2751 --- /dev/null +++ b/changelog.d/11253.misc @@ -0,0 +1 @@ +Make minor correction to the type of `auth_checkers` callbacks. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 0de60b128a..e53abf6409 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -11,7 +11,7 @@ registered by using the Module API's `register_password_auth_provider_callbacks` _First introduced in Synapse v1.46.0_ ```python - auth_checkers: Dict[Tuple[str,Tuple], Callable] +auth_checkers: Dict[Tuple[str, Tuple[str, ...]], Callable] ``` A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d508d7d32a..60e59d11a0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1989,7 +1989,9 @@ class PasswordAuthProvider: self, check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, - auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + auth_checkers: Optional[ + Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: -- cgit 1.4.1 From 86a497efaa60cf0e456103724c369e5172ea5485 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 8 Nov 2021 14:13:10 +0000 Subject: Default value for `public_baseurl` (#11210) We might as well use a default value for `public_baseurl` based on `server_name` - in many cases, it will be correct. 
--- changelog.d/11210.feature | 1 + docs/sample_config.yaml | 13 +++++------ synapse/api/urls.py | 3 --- synapse/config/account_validity.py | 4 ---- synapse/config/cas.py | 10 +++------ synapse/config/emailconfig.py | 8 ------- synapse/config/oidc.py | 2 -- synapse/config/registration.py | 15 +------------ synapse/config/saml2.py | 5 +---- synapse/config/server.py | 45 ++++++++++++++++++++++++++++++++++---- synapse/config/sso.py | 18 ++++++--------- synapse/handlers/identity.py | 4 ---- synapse/rest/well_known.py | 3 +-- tests/push/test_email.py | 2 +- tests/rest/client/test_consent.py | 1 - tests/rest/client/test_register.py | 1 - 16 files changed, 62 insertions(+), 73 deletions(-) create mode 100644 changelog.d/11210.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11210.feature b/changelog.d/11210.feature new file mode 100644 index 0000000000..8f8e386415 --- /dev/null +++ b/changelog.d/11210.feature @@ -0,0 +1 @@ +Calculate a default value for `public_baseurl` based on `server_name`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c3a4148f74..d48c08f1d9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -91,6 +91,8 @@ pid_file: DATADIR/homeserver.pid # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # +# Defaults to 'https://<server_name>/'. + # #public_baseurl: https://example.com/ # Uncomment the following to tell other servers to send federation traffic on @@ -1265,7 +1267,7 @@ oembed: # in on this server. # # (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# This setting is ignored unless public_baseurl is also explicitly set.) # #default_identity_server: https://matrix.org @@ -1290,8 +1292,6 @@ oembed: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1981,11 +1981,10 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is whitelisted in addition to any URLs in this list. # - # By default, this list is empty. + # By default, this list contains only the login fallback page.
# #client_whitelist: # - https://riot.im/develop diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6e84b1524f..4486b3bc7d 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -38,9 +38,6 @@ class ConsentURIBuilder: def __init__(self, hs_config: HomeServerConfig): if hs_config.key.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.server.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") - self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._public_baseurl = hs_config.server.public_baseurl diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index b56c2a24df..c533452cab 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -75,10 +75,6 @@ class AccountValidityConfig(Config): self.account_validity_period * 10.0 / 100.0 ) - if self.account_validity_renew_by_email_enabled: - if not self.root.server.public_baseurl: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - # Load account validity templates. account_validity_template_dir = account_validity_config.get("template_dir") if account_validity_template_dir is not None: diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 9b58ecf3d8..3f81814043 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -16,7 +16,7 @@ from typing import Any, List from synapse.config.sso import SsoAttributeRequirement -from ._base import Config, ConfigError +from ._base import Config from ._util import validate_config @@ -35,14 +35,10 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - # The public baseurl is required because it is used by the redirect - # template. - public_baseurl = self.root.server.public_baseurl - if not public_baseurl: - raise ConfigError("cas_config requires a public_baseurl to be set") - # TODO Update this to a _synapse URL. 
+ public_baseurl = self.root.server.public_baseurl self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_displayname_attribute = cas_config.get("displayname_attribute") required_attributes = cas_config.get("required_attributes") or {} self.cas_required_attributes = _parsed_required_attributes_def( diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 8ff59aa2f8..afd65fecd3 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -186,11 +186,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -296,9 +291,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 10f5796330..42f113cd24 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -59,8 +59,6 @@ class OIDCConfig(Config): ) public_baseurl = self.root.server.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" @property diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a3d2a38c4c..5379e80715 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -45,17 +45,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if ( - self.account_threepid_delegate_msisdn - and not self.root.server.public_baseurl - ): - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens to. Please " - "configure `public_baseurl`." - ) - self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -240,7 +229,7 @@ class RegistrationConfig(Config): # in on this server. # # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # This setting is ignored unless public_baseurl is also explicitly set.) # #default_identity_server: https://matrix.org @@ -265,8 +254,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. 
- # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 9c51b6a25a..ba2b0905ff 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -199,14 +199,11 @@ class SAML2Config(Config): """ import saml2 - public_baseurl = self.root.server.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") - if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes + public_baseurl = self.root.server.public_baseurl metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml" response_url = public_baseurl + "_synapse/client/saml2/authn_response" return { diff --git a/synapse/config/server.py b/synapse/config/server.py index a387fd9310..7bc0030a9e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -16,6 +16,7 @@ import itertools import logging import os.path import re +import urllib.parse from textwrap import indent from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -264,10 +265,44 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.serve_server_wellknown = config.get("serve_server_wellknown", False) - self.public_baseurl = config.get("public_baseurl") - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" + # Whether we should serve a "client well-known": + # (a) at .well-known/matrix/client on our client HTTP listener + # (b) in the response to /login + # + # ... which together help ensure that clients use our public_baseurl instead of + # whatever they were told by the user. + # + # For the sake of backwards compatibility with existing installations, this is + # True if public_baseurl is specified explicitly, and otherwise False. (The + # reasoning here is that we have no way of knowing that the default + # public_baseurl is actually correct for existing installations - many things + # will not work correctly, but that's (probably?) better than sending clients + # to a completely broken URL. + self.serve_client_wellknown = False + + public_baseurl = config.get("public_baseurl") + if public_baseurl is None: + public_baseurl = f"https://{self.server_name}/" + logger.info("Using default public_baseurl %s", public_baseurl) + else: + self.serve_client_wellknown = True + if public_baseurl[-1] != "/": + public_baseurl += "/" + self.public_baseurl = public_baseurl + + # check that public_baseurl is valid + try: + splits = urllib.parse.urlsplit(self.public_baseurl) + except Exception as e: + raise ConfigError(f"Unable to parse URL: {e}", ("public_baseurl",)) + if splits.scheme not in ("https", "http"): + raise ConfigError( + f"Invalid scheme '{splits.scheme}': only https and http are supported" ) + if splits.query or splits.fragment: + raise ConfigError( + "public_baseurl cannot contain query parameters or a #-fragment" + ) # Whether to enable user presence. presence_config = config.get("presence") or {} @@ -773,6 +808,8 @@ class ServerConfig(Config): # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # + # Defaults to 'https://<server_name>/'.
+ # #public_baseurl: https://example.com/ # Uncomment the following to tell other servers to send federation traffic on diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 11a9b76aa0..60aacb13ea 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -101,13 +101,10 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.root.server.public_baseurl: - login_fallback_url = ( - self.root.server.public_baseurl + "_matrix/static/client/login" - ) - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = ( + self.root.server.public_baseurl + "_matrix/static/client/login" + ) + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -128,11 +125,10 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is whitelisted in addition to any URLs in this list. # - # By default, this list is empty. + # By default, this list contains only the login fallback page. # #client_whitelist: # - https://riot.im/develop diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 6a315117ba..3dbe611f95 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -537,10 +537,6 @@ class IdentityHandler: except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - # It is already checked that public_baseurl is configured since this code - # should only be used if account_threepid_delegate_msisdn is true. - assert self.hs.config.server.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index edbf5ce5d0..04b035a1b1 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,8 +34,7 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self) -> Optional[JsonDict]: - # if we don't have a public_baseurl, we can't help much here. 
- if self._config.server.public_baseurl is None: + if not self._config.server.serve_client_wellknown: return None result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} diff --git a/tests/push/test_email.py b/tests/push/test_email.py index fa8018e5a7..90f800e564 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -65,7 +65,7 @@ class EmailPusherTests(HomeserverTestCase): "notif_from": "test@example.com", "riot_base_url": None, } - config["public_baseurl"] = "aaa" + config["public_baseurl"] = "http://aaa" config["start_pushers"] = True hs = self.setup_test_homeserver(config=config) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 84d092ca82..fcdc565814 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -35,7 +35,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["public_baseurl"] = "aaaa" config["form_secret"] = "123abc" # Make some temporary templates... diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 66dcfc9f88..6e7c0f11df 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -891,7 +891,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "smtp_pass": None, "notif_from": "test@example.com", } - config["public_baseurl"] = "aaa" self.hs = self.setup_test_homeserver(config=config) -- cgit 1.4.1 From 84f235aea47e2d2f9875f7334d8497660320f55e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 8 Nov 2021 21:21:10 -0600 Subject: Rename to more clear `get_insertion_event_id_by_batch_id` (MSC2716) (#11244) `get_insertion_event_by_batch_id` -> `get_insertion_event_id_by_batch_id` Split out from https://github.com/matrix-org/synapse/pull/11114 --- changelog.d/11244.misc | 1 + synapse/handlers/message.py | 2 +- synapse/rest/client/room_batch.py | 2 +- synapse/storage/databases/main/room_batch.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11244.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11244.misc b/changelog.d/11244.misc new file mode 100644 index 0000000000..c6e65df97f --- /dev/null +++ b/changelog.d/11244.misc @@ -0,0 +1 @@ +Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical messages backfilling in random order on remote homeservers. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b7bc187169..d4c2a6ab7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1507,7 +1507,7 @@ class EventCreationHandler: conflicting_insertion_event_id = None if next_batch_id: conflicting_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( + await self.store.get_insertion_event_id_by_batch_id( event.room_id, next_batch_id ) ) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 46f033eee2..e4c9451ae0 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -112,7 +112,7 @@ class RoomBatchSendEventRestServlet(RestServlet): # and have the batch connected. 
if batch_id_from_query: corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( + await self.store.get_insertion_event_id_by_batch_id( room_id, batch_id_from_query ) ) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index dcbce8fdcf..97b2618437 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -18,7 +18,7 @@ from synapse.storage._base import SQLBaseStore class RoomBatchStore(SQLBaseStore): - async def get_insertion_event_by_batch_id( + async def get_insertion_event_id_by_batch_id( self, room_id: str, batch_id: str ) -> Optional[str]: """Retrieve a insertion event ID. -- cgit 1.4.1 From af784644c3380d0a2ea885abbe748fbe69d3a990 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Nov 2021 11:45:36 +0000 Subject: Include cross-signing signatures when syncing remote devices for the first time (#11234) When fetching remote devices for the first time, we did not correctly include the cross signing keys in the returned results. c.f. #11159 --- changelog.d/11234.bugfix | 1 + synapse/handlers/e2e_keys.py | 211 ++++++++++++++++++++++++---------------- tests/handlers/test_e2e_keys.py | 151 ++++++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 86 deletions(-) create mode 100644 changelog.d/11234.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11234.bugfix b/changelog.d/11234.bugfix new file mode 100644 index 0000000000..c0c02a58f6 --- /dev/null +++ b/changelog.d/11234.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where cross signing keys were not included in the response to `/r0/keys/query` the first time a remote user was queried. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d0fb2fc7dc..60c11e3d21 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -201,95 +201,19 @@ class E2eKeysHandler: r[user_id] = remote_queries[user_id] # Now fetch any devices that we don't have in our cache - @trace - async def do_remote_query(destination: str) -> None: - """This is called when we are querying the device list of a user on - a remote homeserver and their device list is not in the device list - cache. If we share a room with this user and we're not querying for - specific user we will update the cache with their device list. - """ - - destination_query = remote_queries_not_in_cache[destination] - - # We first consider whether we wish to update the device list cache with - # the users device list. We want to track a user's devices when the - # authenticated user shares a room with the queried user and the query - # has not specified a particular device. - # If we update the cache for the queried user we remove them from further - # queries. We use the more efficient batched query_client_keys for all - # remaining users - user_ids_updated = [] - for (user_id, device_list) in destination_query.items(): - if user_id in user_ids_updated: - continue - - if device_list: - continue - - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - continue - - # We've decided we're sharing a room with this user and should - # probably be tracking their device lists. However, we haven't - # done an initial sync on the device list so we do it now. 
- try: - if self._is_master: - user_devices = await self.device_handler.device_list_updater.user_device_resync( - user_id - ) - else: - user_devices = await self._user_device_resync_client( - user_id=user_id - ) - - user_devices = user_devices["devices"] - user_results = results.setdefault(user_id, {}) - for device in user_devices: - user_results[device["device_id"]] = device["keys"] - user_ids_updated.append(user_id) - except Exception as e: - failures[destination] = _exception_to_failure(e) - - if len(destination_query) == len(user_ids_updated): - # We've updated all the users in the query and we do not need to - # make any further remote calls. - return - - # Remove all the users from the query which we have updated - for user_id in user_ids_updated: - destination_query.pop(user_id) - - try: - remote_result = await self.federation.query_client_keys( - destination, {"device_keys": destination_query}, timeout=timeout - ) - - for user_id, keys in remote_result["device_keys"].items(): - if user_id in destination_query: - results[user_id] = keys - - if "master_keys" in remote_result: - for user_id, key in remote_result["master_keys"].items(): - if user_id in destination_query: - cross_signing_keys["master_keys"][user_id] = key - - if "self_signing_keys" in remote_result: - for user_id, key in remote_result["self_signing_keys"].items(): - if user_id in destination_query: - cross_signing_keys["self_signing_keys"][user_id] = key - - except Exception as e: - failure = _exception_to_failure(e) - failures[destination] = failure - set_tag("error", True) - set_tag("reason", failure) - await make_deferred_yieldable( defer.gatherResults( [ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache + run_in_background( + self._query_devices_for_destination, + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, + ) + for destination, queries in remote_queries_not_in_cache.items() ], consumeErrors=True, ).addErrback(unwrapFirstError) @@ -301,6 +225,121 @@ class E2eKeysHandler: return ret + @trace + async def _query_devices_for_destination( + self, + results: JsonDict, + cross_signing_keys: JsonDict, + failures: Dict[str, JsonDict], + destination: str, + destination_query: Dict[str, Iterable[str]], + timeout: int, + ) -> None: + """This is called when we are querying the device list of a user on + a remote homeserver and their device list is not in the device list + cache. If we share a room with this user and we're not querying for + specific user we will update the cache with their device list. + + Args: + results: A map from user ID to their device keys, which gets + updated with the newly fetched keys. + cross_signing_keys: Map from user ID to their cross signing keys, + which gets updated with the newly fetched keys. + failures: Map of destinations to failures that have occurred while + attempting to fetch keys. + destination: The remote server to query + destination_query: The query dict of devices to query the remote + server for. + timeout: The timeout for remote HTTP requests. + """ + + # We first consider whether we wish to update the device list cache with + # the users device list. We want to track a user's devices when the + # authenticated user shares a room with the queried user and the query + # has not specified a particular device. + # If we update the cache for the queried user we remove them from further + # queries. 
We use the more efficient batched query_client_keys for all + # remaining users + user_ids_updated = [] + for (user_id, device_list) in destination_query.items(): + if user_id in user_ids_updated: + continue + + if device_list: + continue + + room_ids = await self.store.get_rooms_for_user(user_id) + if not room_ids: + continue + + # We've decided we're sharing a room with this user and should + # probably be tracking their device lists. However, we haven't + # done an initial sync on the device list so we do it now. + try: + if self._is_master: + resync_results = await self.device_handler.device_list_updater.user_device_resync( + user_id + ) + else: + resync_results = await self._user_device_resync_client( + user_id=user_id + ) + + # Add the device keys to the results. + user_devices = resync_results["devices"] + user_results = results.setdefault(user_id, {}) + for device in user_devices: + user_results[device["device_id"]] = device["keys"] + user_ids_updated.append(user_id) + + # Add any cross signing keys to the results. + master_key = resync_results.get("master_key") + self_signing_key = resync_results.get("self_signing_key") + + if master_key: + cross_signing_keys["master_keys"][user_id] = master_key + + if self_signing_key: + cross_signing_keys["self_signing_keys"][user_id] = self_signing_key + except Exception as e: + failures[destination] = _exception_to_failure(e) + + if len(destination_query) == len(user_ids_updated): + # We've updated all the users in the query and we do not need to + # make any further remote calls. + return + + # Remove all the users from the query which we have updated + for user_id in user_ids_updated: + destination_query.pop(user_id) + + try: + remote_result = await self.federation.query_client_keys( + destination, {"device_keys": destination_query}, timeout=timeout + ) + + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key + + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key + + except Exception as e: + failure = _exception_to_failure(e) + failures[destination] = failure + set_tag("error", True) + set_tag("reason", failure) + + return + async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 39e7b1ab25..0c3b86fda9 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -17,6 +17,8 @@ from unittest import mock from signedjson import key as key, sign as sign +from twisted.internet import defer + from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError @@ -630,3 +632,152 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) + + def test_query_devices_remote_no_sync(self): + """Tests that querying keys for a remote user that we don't share a room + with returns the cross signing keys correctly. 
+ """ + + remote_user_id = "@test:other" + local_user_id = "@test:test" + + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + self.hs.get_federation_client().query_client_keys = mock.Mock( + return_value=defer.succeed( + { + "device_keys": {remote_user_id: {}}, + "master_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + }, + "self_signing_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + } + }, + } + ) + ) + + e2e_handler = self.hs.get_e2e_keys_handler() + + query_result = self.get_success( + e2e_handler.query_devices( + { + "device_keys": {remote_user_id: []}, + }, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + + self.assertEqual(query_result["failures"], {}) + self.assertEqual( + query_result["master_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + }, + ) + self.assertEqual( + query_result["self_signing_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + } + }, + ) + + def test_query_devices_remote_sync(self): + """Tests that querying keys for a remote user that we share a room with, + but haven't yet fetched the keys for, returns the cross signing keys + correctly. + """ + + remote_user_id = "@test:other" + local_user_id = "@test:test" + + self.store.get_rooms_for_user = mock.Mock( + return_value=defer.succeed({"some_room_id"}) + ) + + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + self.hs.get_federation_client().query_user_devices = mock.Mock( + return_value=defer.succeed( + { + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + ) + + e2e_handler = self.hs.get_e2e_keys_handler() + + query_result = self.get_success( + e2e_handler.query_devices( + { + "device_keys": {remote_user_id: []}, + }, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + + self.assertEqual(query_result["failures"], {}) + self.assertEqual( + query_result["master_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + } + }, + ) + self.assertEqual( + query_result["self_signing_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + } + }, + ) -- cgit 1.4.1