diff options
author | Dan Callahan <danc@element.io> | 2021-10-27 20:04:00 +0100 |
---|---|---|
committer | Dan Callahan <danc@element.io> | 2021-10-27 20:04:00 +0100 |
commit | 0dffa9d0e096e5ff04768b2e06ce4acf92120486 (patch) | |
tree | aa501e65702a3e3fb179c3a0c37dc15543637f72 /synapse | |
parent | Changelog (diff) | |
parent | Annotate `log_function` decorator (#10943) (diff) | |
download | synapse-0dffa9d0e096e5ff04768b2e06ce4acf92120486.tar.xz |
Merge remote-tracking branch 'origin/develop' into shellcheck
Fixes a merge conflict with debian/changelog Signed-off-by: Dan Callahan <danc@element.io>
Diffstat (limited to 'synapse')
34 files changed, 864 insertions, 267 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 2687d932ea..355b36fc63 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.45.1" +__version__ = "1.46.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index 9de913db88..8e66a38421 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -20,7 +20,12 @@ from typing import List import attr -from synapse.config._base import RootConfig, find_config_files, read_config_files +from synapse.config._base import ( + Config, + RootConfig, + find_config_files, + read_config_files, +) from synapse.config.database import DatabaseConfig from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.engines import create_engine @@ -126,7 +131,7 @@ def main(): config_dict, ) - since_ms = time.time() * 1000 - config.parse_duration(config_args.since) + since_ms = time.time() * 1000 - Config.parse_duration(config_args.since) exclude_users_with_email = config_args.exclude_emails include_context = not config_args.only_users diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bc550ae646..4b0a9b2974 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -18,7 +18,8 @@ import json from typing import ( TYPE_CHECKING, Awaitable, - Container, + Callable, + Dict, Iterable, List, Optional, @@ -217,19 +218,19 @@ class FilterCollection: return self._filter_json def timeline_limit(self) -> int: - return self._room_timeline_filter.limit() + return self._room_timeline_filter.limit def presence_limit(self) -> int: - return self._presence_filter.limit() + return self._presence_filter.limit def ephemeral_limit(self) -> int: - return self._room_ephemeral_filter.limit() + return self._room_ephemeral_filter.limit def lazy_load_members(self) -> bool: - return self._room_state_filter.lazy_load_members() + return self._room_state_filter.lazy_load_members def include_redundant_members(self) -> bool: - return self._room_state_filter.include_redundant_members() + return self._room_state_filter.include_redundant_members def filter_presence( self, events: Iterable[UserPresenceState] @@ -276,19 +277,25 @@ class Filter: def __init__(self, filter_json: JsonDict): self.filter_json = filter_json - self.types = self.filter_json.get("types", None) - self.not_types = self.filter_json.get("not_types", []) + self.limit = filter_json.get("limit", 10) + self.lazy_load_members = filter_json.get("lazy_load_members", False) + self.include_redundant_members = filter_json.get( + "include_redundant_members", False + ) + + self.types = filter_json.get("types", None) + self.not_types = filter_json.get("not_types", []) - self.rooms = self.filter_json.get("rooms", None) - self.not_rooms = self.filter_json.get("not_rooms", []) + self.rooms = filter_json.get("rooms", None) + self.not_rooms = filter_json.get("not_rooms", []) - self.senders = self.filter_json.get("senders", None) - self.not_senders = self.filter_json.get("not_senders", []) + self.senders = filter_json.get("senders", None) + self.not_senders = filter_json.get("not_senders", []) - self.contains_url = self.filter_json.get("contains_url", None) + self.contains_url = filter_json.get("contains_url", None) - self.labels = self.filter_json.get("org.matrix.labels", None) - self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + self.labels = filter_json.get("org.matrix.labels", None) + self.not_labels = filter_json.get("org.matrix.not_labels", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -302,76 +309,95 @@ class Filter: def check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. + Args: + event: The event, account data, or presence to check against this + filter. + Returns: - True if the event matches + True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, # except for presence which actually gets passed around as its own # namedtuple type. if isinstance(event, UserPresenceState): - sender: Optional[str] = event.user_id - room_id = None - ev_type = "m.presence" - contains_url = False - labels: List[str] = [] + user_id = event.user_id + field_matchers = { + "senders": lambda v: user_id == v, + "types": lambda v: "m.presence" == v, + } + return self._check_fields(field_matchers) else: + content = event.get("content") + # Content is assumed to be a dict below, so ensure it is. This should + # always be true for events, but account_data has been allowed to + # have non-dict content. + if not isinstance(content, dict): + content = {} + sender = event.get("sender", None) if not sender: # Presence events had their 'sender' in content.user_id, but are # now handled above. We don't know if anything else uses this # form. TODO: Check this and probably remove it. - content = event.get("content") - # account_data has been allowed to have non-dict content, so - # check type first - if isinstance(content, dict): - sender = content.get("user_id") + sender = content.get("user_id") room_id = event.get("room_id", None) ev_type = event.get("type", None) - content = event.get("content") or {} # check if there is a string url field in the content for filtering purposes - contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) - return self.check_fields(room_id, sender, ev_type, labels, contains_url) + field_matchers = { + "rooms": lambda v: room_id == v, + "senders": lambda v: sender == v, + "types": lambda v: _matches_wildcard(ev_type, v), + "labels": lambda v: v in labels, + } + + result = self._check_fields(field_matchers) + if not result: + return result + + contains_url_filter = self.contains_url + if contains_url_filter is not None: + contains_url = isinstance(content.get("url"), str) + if contains_url_filter != contains_url: + return False + + return True - def check_fields( - self, - room_id: Optional[str], - sender: Optional[str], - event_type: Optional[str], - labels: Container[str], - contains_url: bool, - ) -> bool: + def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool: """Checks whether the filter matches the given event fields. + Args: + field_matchers: A map of attribute name to callable to use for checking + particular fields. + + The attribute name and an inverse (not_<attribute name>) must + exist on the Filter. + + The callable should return true if the event's value matches the + filter's value. + Returns: True if the event fields match """ - literal_keys = { - "rooms": lambda v: room_id == v, - "senders": lambda v: sender == v, - "types": lambda v: _matches_wildcard(event_type, v), - "labels": lambda v: v in labels, - } - - for name, match_func in literal_keys.items(): + + for name, match_func in field_matchers.items(): + # If the event matches one of the disallowed values, reject it. not_name = "not_%s" % (name,) disallowed_values = getattr(self, not_name) if any(map(match_func, disallowed_values)): return False + # Other the event does not match at least one of the allowed values, + # reject it. allowed_values = getattr(self, name) if allowed_values is not None: if not any(map(match_func, allowed_values)): return False - contains_url_filter = self.filter_json.get("contains_url") - if contains_url_filter is not None: - if contains_url_filter != contains_url: - return False - + # Otherwise, accept it. return True def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]: @@ -385,10 +411,10 @@ class Filter: """ room_ids = set(room_ids) - disallowed_rooms = set(self.filter_json.get("not_rooms", [])) + disallowed_rooms = set(self.not_rooms) room_ids -= disallowed_rooms - allowed_rooms = self.filter_json.get("rooms", None) + allowed_rooms = self.rooms if allowed_rooms is not None: room_ids &= set(allowed_rooms) @@ -397,15 +423,6 @@ class Filter: def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return list(filter(self.check, events)) - def limit(self) -> int: - return self.filter_json.get("limit", 10) - - def lazy_load_members(self) -> bool: - return self.filter_json.get("lazy_load_members", False) - - def include_redundant_members(self) -> bool: - return self.filter_json.get("include_redundant_members", False) - def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ca2e051e4..f4c3f867a8 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -31,6 +31,7 @@ import twisted from twisted.internet import defer, error, reactor from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.python.threadpool import ThreadPool import synapse from synapse.api.constants import MAX_PDU_SIZE @@ -48,6 +49,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process +from synapse.util.gai_resolver import GAIResolver from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string @@ -338,9 +340,19 @@ async def start(hs: "HomeServer"): Args: hs: homeserver instance """ + reactor = hs.get_reactor() + + # We want to use a separate thread pool for the resolver so that large + # numbers of DNS requests don't starve out other users of the threadpool. + resolver_threadpool = ThreadPool(name="gai_resolver") + resolver_threadpool.start() + reactor.addSystemEventTrigger("during", "shutdown", resolver_threadpool.stop) + reactor.installNameResolver( + GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool) + ) + # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): - reactor = hs.get_reactor() @wrap_as_background_process("sighup") def handle_sighup(*args, **kwargs): diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 2a6dabdab6..8816ef4b76 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -36,6 +36,7 @@ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] +ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -152,6 +153,7 @@ class ThirdPartyEventRules: self._check_visibility_can_be_modified_callbacks: List[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] + self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] def register_third_party_rules_callbacks( self, @@ -163,6 +165,7 @@ class ThirdPartyEventRules: check_visibility_can_be_modified: Optional[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, + on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -181,6 +184,9 @@ class ThirdPartyEventRules: check_visibility_can_be_modified, ) + if on_new_event is not None: + self._on_new_event_callbacks.append(on_new_event) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -321,6 +327,31 @@ class ThirdPartyEventRules: return True + async def on_new_event(self, event_id: str) -> None: + """Let modules act on events after they've been sent (e.g. auto-accepting + invites, etc.) + + Args: + event_id: The ID of the event. + + Raises: + ModuleFailureError if a callback raised any exception. + """ + # Bail out early without hitting the store if we don't have any callbacks + if len(self._on_new_event_callbacks) == 0: + return + + event = await self.store.get_event(event_id) + state_events = await self._get_state_map_for_room(event.room_id) + + for callback in self._on_new_event_callbacks: + try: + await callback(event, state_events) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2ab4dec88f..670186f548 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -227,7 +227,7 @@ class FederationClient(FederationBase): ) async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> Optional[List[EventBase]]: """Requests some more historic PDUs for the given room from the given destination server. @@ -237,6 +237,8 @@ class FederationClient(FederationBase): room_id: The room_id to backfill. limit: The maximum number of events to return. extremities: our current backwards extremities, to backfill from + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) """ logger.debug("backfill extrem=%s", extremities) @@ -250,11 +252,22 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) + if not isinstance(transaction_data, dict): + # TODO we probably want an exception type specific to federation + # client validation. + raise TypeError("Backfill transaction_data is not a dict.") + + transaction_data_pdus = transaction_data.get("pdus") + if not isinstance(transaction_data_pdus, list): + # TODO we probably want an exception type specific to federation + # client validation. + raise TypeError("transaction_data.pdus is not a list.") + room_version = await self.store.get_room_version(room_id) pdus = [ event_from_pdu_json(p, room_version, outlier=False) - for p in transaction_data["pdus"] + for p in transaction_data_pdus ] # Check signatures and hash of pdus, removing any from the list that fail checks diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 0d66034f44..32a75993d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -295,14 +295,16 @@ class FederationServer(FederationBase): Returns: HTTP response code and body """ - response = await self.transaction_actions.have_responded(origin, transaction) + existing_response = await self.transaction_actions.have_responded( + origin, transaction + ) - if response: + if existing_response: logger.debug( "[%s] We've already responded to this request", transaction.transaction_id, ) - return response + return existing_response logger.debug("[%s] Transaction is new", transaction.transaction_id) @@ -632,7 +634,7 @@ class FederationServer(FederationBase): async def on_make_knock_request( self, origin: str, room_id: str, user_id: str, supported_versions: List[str] - ) -> Dict[str, Union[EventBase, str]]: + ) -> JsonDict: """We've received a /make_knock/ request, so we create a partial knock event for the room and hand that back, along with the room version, to the knocking homeserver. We do *not* persist or process this event until the other server has diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index dc555cca0b..ab935e5a7e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -149,7 +149,6 @@ class TransactionManager: ) except HttpResponseException as e: code = e.code - response = e.response set_tag(tags.ERROR, True) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8b247fe206..d963178838 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,19 @@ import logging import urllib -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import attr import ijson @@ -100,7 +112,7 @@ class TransportLayerClient: @log_function async def backfill( - self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int + self, destination: str, room_id: str, event_tuples: Collection[str], limit: int ) -> Optional[JsonDict]: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -108,7 +120,9 @@ class TransportLayerClient: Args: destination room_id - event_tuples + event_tuples: + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) limit Returns: @@ -786,7 +800,7 @@ class TransportLayerClient: @log_function def join_group( self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: + ) -> Awaitable[JsonDict]: """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ebe75a9e9b..d508d7d32a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -62,7 +62,6 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.module_api import ModuleApi from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils @@ -73,6 +72,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: + from synapse.module_api import ModuleApi from synapse.rest.client.login import LoginResponse from synapse.server import HomeServer @@ -1818,7 +1818,9 @@ def load_legacy_password_auth_providers(hs: "HomeServer") -> None: def load_single_legacy_password_auth_provider( - module: Type, config: JsonDict, api: ModuleApi + module: Type, + config: JsonDict, + api: "ModuleApi", ) -> None: try: provider = module(config=config, account_handler=api) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14ed7d9879..8ca5f60b1c 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -145,7 +145,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_alias_creation_allowed( user_id, room_id, room_alias_str ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") @@ -245,7 +245,7 @@ class DirectoryHandler: servers = result.servers else: try: - fed_result = await self.federation.make_query( + fed_result: Optional[JsonDict] = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -461,7 +461,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3431a80ab4..e617db4c0d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -361,6 +361,7 @@ class FederationEventHandler: # need to. await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._check_for_soft_fail(event, None, origin=origin) await self._run_push_actions_and_persist_event(event, context) return event, context @@ -402,29 +403,28 @@ class FederationEventHandler: """Persists the events returned by a send_join Checks the auth chain is valid (and passes auth checks) for the - state and event. Then persists the auth chain and state atomically. - Persists the event separately. Notifies about the persisted events - where appropriate. - - Will attempt to fetch missing auth events. + state and event. Then persists all of the events. + Notifies about the persisted events where appropriate. Args: origin: Where the events came from - room_id, + room_id: auth_events state event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + + Returns: + The stream ID after which all events have been persisted. + + Raises: + SynapseError if the response is in some way invalid. """ - events_to_context = {} for e in itertools.chain(auth_events, state): e.internal_metadata.outlier = True - events_to_context[e.event_id] = EventContext.for_outlier() - event_map = { - e.event_id: e for e in itertools.chain(auth_events, state, [event]) - } + event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} create_event = None for e in auth_events: @@ -444,68 +444,40 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - missing_auth_events = set() - for e in itertools.chain(auth_events, state, [event]): - for e_id in e.auth_event_ids(): - if e_id not in event_map: - missing_auth_events.add(e_id) - - for e_id in missing_auth_events: - m_ev = await self._federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - if m_ev and m_ev.event_id == e_id: - event_map[e_id] = m_ev - else: - logger.info("Failed to find auth event %r", e_id) - - for e in itertools.chain(auth_events, state, [event]): - auth_for_e = [ - event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map - ] - if create_event: - auth_for_e.append(create_event) - - try: - validate_event_for_room_version(room_version, e) - check_auth_rules_for_event(room_version, e, auth_for_e) - except SynapseError as err: - # we may get SynapseErrors here as well as AuthErrors. For - # instance, there are a couple of (ancient) events in some - # rooms whose senders do not have the correct sigil; these - # cause SynapseErrors in auth.check. We don't want to give up - # the attempt to federate altogether in such cases. - - logger.warning("Rejecting %s because %s", e.event_id, err.msg) - - if e == event: - raise - events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - - if auth_events or state: - await self.persist_events_and_notify( - room_id, - [ - (e, events_to_context[e.event_id]) - for e in itertools.chain(auth_events, state) - ], + # filter out any events we have already seen + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + + # persist the auth chain and state events. + # + # any invalid events here will be marked as rejected, and we'll carry on. + # + # any events whose auth events are missing (ie, not in the send_join response, + # and not already in our db) will just be ignored. This is correct behaviour, + # because the reason that auth_events are missing might be due to us being + # unable to validate their signatures. The fact that we can't validate their + # signatures right now doesn't mean that we will *never* be able to, so it + # is premature to reject them. + # + await self._auth_and_persist_outliers(room_id, event_map.values()) + + # and now persist the join event itself. + logger.info("Peristing join-via-remote %s", event) + with nested_logging_context(suffix=event.event_id): + context = await self._state_handler.compute_event_context( + event, old_state=state ) - new_event_context = await self._state_handler.compute_event_context( - event, old_state=state - ) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError(400, "Join event was rejected") - return await self.persist_events_and_notify( - room_id, [(event, new_event_context)] - ) + return await self.persist_events_and_notify(room_id, [(event, context)]) @log_function async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -974,9 +946,15 @@ class FederationEventHandler: ) -> None: """Called when we have a new non-outlier event. - This is called when we have a new event to add to the room DAG - either directly - via a /send request, retrieved via get_missing_events after a /send request, or - backfilled after a client request. + This is called when we have a new event to add to the room DAG. This can be + due to: + * events received directly via a /send request + * events retrieved via get_missing_events after a /send request + * events backfilled after a client request. + + It's not currently used for events received from incoming send_{join,knock,leave} + requests (which go via on_send_membership_event), nor for joins created by a + remote join dance (which go via process_remote_join). We need to do auth checks and put it through the StateHandler. @@ -1012,11 +990,19 @@ class FederationEventHandler: logger.exception("Unexpected AuthError from _check_event_auth") raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + if not backfilled and not context.rejected: + # For new (non-backfilled and non-outlier) events we check if the event + # passes auth based on the current state. If it doesn't then we + # "soft-fail" the event. + await self._check_for_soft_fail(event, state, origin=origin) + await self._run_push_actions_and_persist_event(event, context, backfilled) - if backfilled: + if backfilled or context.rejected: return + await self._maybe_kick_guest_users(event) + # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -1317,14 +1303,14 @@ class FederationEventHandler: for auth_event_id in event.auth_event_ids(): ae = persisted_events.get(auth_event_id) if not ae: + # the fact we can't find the auth event doesn't mean it doesn't + # exist, which means it is premature to reject `event`. Instead we + # just ignore it for now. logger.warning( - "Event %s relies on auth_event %s, which could not be found.", + "Dropping event %s, which relies on auth_event %s, which could not be found", event, auth_event_id, ) - # the fact we can't find the auth event doesn't mean it doesn't - # exist, which means it is premature to reject `event`. Instead we - # just ignore it for now. return None auth.append(ae) @@ -1447,10 +1433,6 @@ class FederationEventHandler: except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR - return context - - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - await self._maybe_kick_guest_users(event) return context @@ -1470,7 +1452,6 @@ class FederationEventHandler: self, event: EventBase, state: Optional[Iterable[EventBase]], - backfilled: bool, origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1479,15 +1460,8 @@ class FederationEventHandler: Args: event state: The state at the event if we don't have all the event's prev events - backfilled: Whether the event is from backfill origin: The host the event originates from. """ - # For new (non-backfilled and non-outlier) events we check if the event - # passes auth based on the current state. If it doesn't then we - # "soft-fail" the event. - if backfilled or event.internal_metadata.is_outlier(): - return - extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) @@ -1942,7 +1916,7 @@ class FederationEventHandler: event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - self._notifier.on_new_room_event( + await self._notifier.on_new_room_event( event, event_pos, max_stream_token, extra_users=extra_users ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 7ef8698a5e..6a315117ba 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -879,6 +879,8 @@ class IdentityHandler: } if room_type is not None: + invite_config["room_type"] = room_type + # TODO The unstable field is deprecated and should be removed in the future. invite_config["org.matrix.msc3288.room_type"] = room_type # If a custom web client location is available, include it in the request. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2e024b551f..4a0fccfcc6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1537,13 +1537,16 @@ class EventCreationHandler: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - def _notify() -> None: + async def _notify() -> None: try: - self.notifier.on_new_room_event( + await self.notifier.on_new_room_event( event, event_pos, max_stream_token, extra_users=extra_users ) except Exception: - logger.exception("Error notifying about new room event") + logger.exception( + "Error notifying about new room event %s", + event.event_id, + ) run_in_background(_notify) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 60ff896386..abfe7be0e3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -438,7 +438,7 @@ class PaginationHandler: } state = None - if event_filter and event_filter.lazy_load_members() and len(events) > 0: + if event_filter and event_filter.lazy_load_members and len(events) > 0: # TODO: remove redundant members # FIXME: we also care about invite targets etc. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fdab50da37..3df872c578 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,6 +52,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function @@ -1551,6 +1552,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, + service: Optional[ApplicationService] = None, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e6c3cf585b..6b5a6ded8b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -456,7 +456,11 @@ class ProfileHandler: continue new_name = profile.get("displayname") + if not isinstance(new_name, str): + new_name = None new_avatar = profile.get("avatar_url") + if not isinstance(new_avatar, str): + new_avatar = None # We always hit update to update the last_check timestamp await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6f39e9446f..99e9b37344 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -773,6 +773,15 @@ class RoomCreationHandler: if not allowed_by_third_party_rules: raise SynapseError(403, "Room visibility value not allowed.") + if is_public: + if not self.config.roomdirectory.is_publishing_room_allowed( + user_id, room_id, room_alias + ): + # Let's just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError(403, "Not allowed to publish room") + directory_handler = self.hs.get_directory_handler() if room_alias: await directory_handler.create_association( @@ -783,15 +792,6 @@ class RoomCreationHandler: check_membership=False, ) - if is_public: - if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias - ): - # Lets just return a generic message, as there may be all sorts of - # reasons why we said no. TODO: Allow configurable error messages - # per alias creation rule? - raise SynapseError(403, "Not allowed to publish room") - preset_config = config.get( "preset", RoomCreationPreset.PRIVATE_CHAT @@ -1173,7 +1173,7 @@ class RoomContextHandler: else: last_event_id = event_id - if event_filter and event_filter.lazy_load_members(): + if event_filter and event_filter.lazy_load_members: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index a3ffa26be8..6e4dff8056 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -249,7 +249,7 @@ class SearchHandler: ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit()] + allowed_events = events[: search_filter.limit] for e in allowed_events: rm = room_groups.setdefault( @@ -271,13 +271,13 @@ class SearchHandler: # We keep looping and we keep filtering until we reach the limit # or we run out of things. # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit() and i < 5: + while len(room_events) < search_filter.limit and i < 5: i += 1 search_result = await self.store.search_rooms( room_ids, search_term, keys, - search_filter.limit() * 2, + search_filter.limit * 2, pagination_token=pagination_token, ) @@ -299,9 +299,9 @@ class SearchHandler: ) room_events.extend(events) - room_events = room_events[: search_filter.limit()] + room_events = room_events[: search_filter.limit] - if len(results) < search_filter.limit() * 2: + if len(results) < search_filter.limit * 2: pagination_token = None break else: @@ -311,7 +311,7 @@ class SearchHandler: group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) - if room_events and len(room_events) >= search_filter.limit(): + if room_events and len(room_events) >= search_filter.limit: last_event_id = room_events[-1].event_id pagination_token = results_map[last_event_id]["pagination_token"] diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 08895e72ee..4a01b902c2 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -16,6 +16,7 @@ import logging from functools import wraps from inspect import getcallargs +from typing import Callable, TypeVar, cast _TIME_FUNC_ID = 0 @@ -41,7 +42,10 @@ def _log_debug_as_f(f, msg, msg_args): logger.handle(record) -def log_function(f): +F = TypeVar("F", bound=Callable) + + +def log_function(f: F) -> F: """Function decorator that logs every call to that function.""" func_name = f.__name__ @@ -69,4 +73,4 @@ def log_function(f): return f(*args, **kwargs) wrapped.__name__ = func_name - return wrapped + return cast(F, wrapped) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ab7ef8f950..d707a9325d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse +from synapse.storage import DataStore from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -61,6 +62,7 @@ from synapse.util import Clock from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer """ @@ -111,7 +113,9 @@ class ModuleApi: def __init__(self, hs: "HomeServer", auth_handler): self._hs = hs - self._store = hs.get_datastore() + # TODO: Fix this type hint once the types for the data stores have been ironed + # out. + self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -150,27 +154,42 @@ class ModuleApi: @property def register_spam_checker_callbacks(self): - """Registers callbacks for spam checking capabilities.""" + """Registers callbacks for spam checking capabilities. + + Added in Synapse v1.37.0. + """ return self._spam_checker.register_callbacks @property def register_account_validity_callbacks(self): - """Registers callbacks for account validity capabilities.""" + """Registers callbacks for account validity capabilities. + + Added in Synapse v1.39.0. + """ return self._account_validity_handler.register_account_validity_callbacks @property def register_third_party_rules_callbacks(self): - """Registers callbacks for third party event rules capabilities.""" + """Registers callbacks for third party event rules capabilities. + + Added in Synapse v1.39.0. + """ return self._third_party_event_rules.register_third_party_rules_callbacks @property def register_presence_router_callbacks(self): - """Registers callbacks for presence router capabilities.""" + """Registers callbacks for presence router capabilities. + + Added in Synapse v1.42.0. + """ return self._presence_router.register_presence_router_callbacks @property def register_password_auth_provider_callbacks(self): - """Registers callbacks for password auth provider capabilities.""" + """Registers callbacks for password auth provider capabilities. + + Added in Synapse v1.46.0. + """ return self._password_auth_provider.register_password_auth_provider_callbacks def register_web_resource(self, path: str, resource: IResource): @@ -181,6 +200,8 @@ class ModuleApi: If multiple modules register a resource for the same path, the module that appears the highest in the configuration file takes priority. + Added in Synapse v1.37.0. + Args: path: The path to register the resource for. resource: The resource to attach to this path. @@ -195,6 +216,8 @@ class ModuleApi: """Allows making outbound HTTP requests to remote resources. An instance of synapse.http.client.SimpleHttpClient + + Added in Synapse v1.22.0. """ return self._http_client @@ -204,22 +227,32 @@ class ModuleApi: public room list. An instance of synapse.module_api.PublicRoomListManager + + Added in Synapse v1.22.0. """ return self._public_room_list_manager @property def public_baseurl(self) -> str: - """The configured public base URL for this homeserver.""" + """The configured public base URL for this homeserver. + + Added in Synapse v1.39.0. + """ return self._hs.config.server.public_baseurl @property def email_app_name(self) -> str: - """The application name configured in the homeserver's configuration.""" + """The application name configured in the homeserver's configuration. + + Added in Synapse v1.39.0. + """ return self._hs.config.email.email_app_name async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get user info by user_id + Added in Synapse v1.41.0. + Args: user_id: Fully qualified user id. Returns: @@ -235,6 +268,8 @@ class ModuleApi: ) -> Requester: """Check the access_token provided for a request + Added in Synapse v1.39.0. + Args: req: Incoming HTTP request allow_guest: True if guest users should be allowed. If this @@ -260,6 +295,8 @@ class ModuleApi: async def is_user_admin(self, user_id: str) -> bool: """Checks if a user is a server admin. + Added in Synapse v1.39.0. + Args: user_id: The Matrix ID of the user to check. @@ -274,6 +311,8 @@ class ModuleApi: Takes a user id provided by the user and adds the @ and :domain to qualify it, if necessary + Added in Synapse v0.25.0. + Args: username (str): provided user id @@ -287,6 +326,8 @@ class ModuleApi: async def get_profile_for_user(self, localpart: str) -> ProfileInfo: """Look up the profile info for the user with the given localpart. + Added in Synapse v1.39.0. + Args: localpart: The localpart to look up profile information for. @@ -299,6 +340,8 @@ class ModuleApi: """Look up the threepids (email addresses and phone numbers) associated with the given Matrix user ID. + Added in Synapse v1.39.0. + Args: user_id: The Matrix user ID to look up threepids for. @@ -313,6 +356,8 @@ class ModuleApi: def check_user_exists(self, user_id): """Check if user exists. + Added in Synapse v0.25.0. + Args: user_id (str): Complete @user:id @@ -332,6 +377,8 @@ class ModuleApi: return that device to the user. Prefer separate calls to register_user and register_device. + Added in Synapse v0.25.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -352,6 +399,8 @@ class ModuleApi: ): """Registers a new user with given localpart and optional displayname, emails. + Added in Synapse v1.2.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -375,6 +424,8 @@ class ModuleApi: def register_device(self, user_id, device_id=None, initial_display_name=None): """Register a device for a user and generate an access token. + Added in Synapse v1.2.0. + Args: user_id (str): full canonical @user:id device_id (str|None): The device ID to check, or None to generate @@ -398,6 +449,8 @@ class ModuleApi: ) -> defer.Deferred: """Record a mapping from an external user id to a mxid + Added in Synapse v1.9.0. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system @@ -417,6 +470,8 @@ class ModuleApi: ) -> str: """Generate a login token suitable for m.login.token authentication + Added in Synapse v1.9.0. + Args: user_id: gives the ID of the user that the token is for @@ -436,6 +491,8 @@ class ModuleApi: def invalidate_access_token(self, access_token): """Invalidate an access token for a user + Added in Synapse v0.25.0. + Args: access_token(str): access token @@ -466,6 +523,8 @@ class ModuleApi: def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection + Added in Synapse v0.25.0. + Args: desc (str): description for the transaction, for metrics etc func (func): function to be run. Passed a database cursor object @@ -489,6 +548,8 @@ class ModuleApi: This is deprecated in favor of complete_sso_login_async. + Added in Synapse v1.11.1. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -515,6 +576,8 @@ class ModuleApi: want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. + Added in Synapse v1.13.0. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -543,6 +606,8 @@ class ModuleApi: (This is exposed for compatibility with the old SpamCheckerApi. We should probably deprecate it and replace it with an async method in a subclass.) + Added in Synapse v1.22.0. + Args: room_id: The room ID to get state events in. types: The event type and state key (using None @@ -563,6 +628,8 @@ class ModuleApi: async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase: """Create and send an event into a room. Membership events are currently not supported. + Added in Synapse v1.22.0. + Args: event_dict: A dictionary representing the event to send. Required keys are `type`, `room_id`, `sender` and `content`. @@ -603,6 +670,8 @@ class ModuleApi: Note that this method can only be run on the process that is configured to write to the presence stream. By default this is the main process. + + Added in Synapse v1.32.0. """ if self._hs._instance_name not in self._hs.config.worker.writers.presence: raise Exception( @@ -657,6 +726,8 @@ class ModuleApi: Waits `msec` initially before calling `f` for the first time. + Added in Synapse v1.39.0. + Args: f: The function to call repeatedly. f can be either synchronous or asynchronous, and must follow Synapse's logcontext rules. @@ -696,6 +767,8 @@ class ModuleApi: ): """Send an email on behalf of the homeserver. + Added in Synapse v1.39.0. + Args: recipient: The email address for the recipient. subject: The email's subject. @@ -719,6 +792,8 @@ class ModuleApi: By default, Synapse will look for these templates in its configured template directory, but another directory to search in can be provided. + Added in Synapse v1.39.0. + Args: filenames: The name of the template files to look for. custom_template_directory: An additional directory to look for the files in. @@ -736,13 +811,13 @@ class ModuleApi: """ Checks whether an ID (user id, room, ...) comes from this homeserver. + Added in Synapse v1.44.0. + Args: id: any Matrix id (e.g. user id, room id, ...), either as a raw id, e.g. string "@user:example.com" or as a parsed UserID, RoomID, ... Returns: True if id comes from this homeserver, False otherwise. - - Added in Synapse v1.44.0. """ if isinstance(id, DomainSpecificString): return self._hs.is_mine(id) @@ -755,6 +830,8 @@ class ModuleApi: """ Return the list of user IPs and agents for a user. + Added in Synapse v1.44.0. + Args: user_id: the id of a user, local or remote since_ts: a timestamp in seconds since the epoch, @@ -763,8 +840,6 @@ class ModuleApi: The list of all UserIpAndAgent that the user has used to connect to this homeserver since `since_ts`. If the user is remote, this list is empty. - - Added in Synapse v1.44.0. """ # Don't hit the db if this is not a local user. is_mine = False @@ -803,6 +878,8 @@ class PublicRoomListManager: async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. @@ -819,6 +896,8 @@ class PublicRoomListManager: async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ @@ -827,6 +906,8 @@ class PublicRoomListManager: async def remove_room_from_public_room_list(self, room_id: str) -> None: """Removes a room from the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index 1acd899fab..1882fffd2a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -220,6 +220,8 @@ class Notifier: # down. self.remote_server_up_callbacks: List[Callable[[str], None]] = [] + self._third_party_rules = hs.get_third_party_event_rules() + self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() self._pusher_pool = hs.get_pusherpool() @@ -267,7 +269,7 @@ class Notifier: """ self.replication_callbacks.append(cb) - def on_new_room_event( + async def on_new_room_event( self, event: EventBase, event_pos: PersistedEventPosition, @@ -275,9 +277,10 @@ class Notifier: extra_users: Optional[Collection[UserID]] = None, ): """Unwraps event and calls `on_new_room_event_args`.""" - self.on_new_room_event_args( + await self.on_new_room_event_args( event_pos=event_pos, room_id=event.room_id, + event_id=event.event_id, event_type=event.type, state_key=event.get("state_key"), membership=event.content.get("membership"), @@ -285,9 +288,10 @@ class Notifier: extra_users=extra_users or [], ) - def on_new_room_event_args( + async def on_new_room_event_args( self, room_id: str, + event_id: str, event_type: str, state_key: Optional[str], membership: Optional[str], @@ -302,7 +306,10 @@ class Notifier: listening to the room, and any listeners for the users in the `extra_users` param. - The events can be peristed out of order. The notifier will wait + This also notifies modules listening on new events via the + `on_new_event` callback. + + The events can be persisted out of order. The notifier will wait until all previous events have been persisted before notifying the client streams. """ @@ -318,6 +325,8 @@ class Notifier: ) self._notify_pending_new_room_events(max_room_stream_token) + await self._third_party_rules.on_new_event(event_id) + self.notify_replication() def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 961c17762e..e29ae1e375 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -207,11 +207,12 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - self.notifier.on_new_room_event_args( + await self.notifier.on_new_room_event_args( event_pos=event_pos, max_room_stream_token=max_token, extra_users=extra_users, room_id=row.data.room_id, + event_id=row.data.event_id, event_type=row.data.type, state_key=row.data.state_key, membership=row.data.membership, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c0bebc3cf0..d14fafbbc9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -326,6 +326,9 @@ class UserRestServletV2(RestServlet): target_user.to_string() ) + if "user_type" in body: + await self.store.set_user_type(target_user, user_type) + user = await self.admin_handler.get_user(target_user) assert user is not None diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5cf2e12575..98a0239759 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -26,6 +26,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -519,7 +520,7 @@ class StateResolutionHandler: self, room_id: str, room_version: str, - state_groups_ids: Dict[int, StateMap[str]], + state_groups_ids: Mapping[int, StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> _StateCacheEntry: @@ -703,7 +704,7 @@ class StateResolutionHandler: def _make_state_cache_entry( - new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] + new_state: StateMap[str], state_groups_ids: Mapping[int, StateMap[str]] ) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index b81d9218ce..1dc7f0ebe3 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): return {(d["user_id"], d["device_id"]): d for d in res} + async def get_user_ip_and_agents( + self, user: UserID, since_ts: int = 0 + ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. + """ + user_id = user.to_string() + + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: + txn.execute( + """ + SELECT access_token, ip, user_agent, last_seen FROM user_ips + WHERE last_seen >= ? AND user_id = ? + ORDER BY last_seen + DESC + """, + (since_ts, user_id), + ) + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) + + rows = await self.db_pool.runInteraction( + desc="get_user_ip_and_agents", func=get_recent + ) + + return [ + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for access_token, ip, user_agent, last_seen in rows + ] + class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. """ - Fetch IP/User Agent connection since a given timestamp. - """ - user_id = user.to_string() - results: Dict[Tuple[str, str], Tuple[str, int]] = {} + results: Dict[Tuple[str, str], LastConnectionInfo] = { + (connection["access_token"], connection["ip"]): connection + for connection in await super().get_user_ip_and_agents(user, since_ts) + } + # Overlay data that is pending insertion on top of the results from the + # database. + user_id = user.to_string() for key in self._batch_row_update: - ( - uid, - access_token, - ip, - ) = key + uid, access_token, ip = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] if last_seen >= since_ts: - results[(access_token, ip)] = (user_agent, last_seen) - - def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: - txn.execute( - """ - SELECT access_token, ip, user_agent, last_seen FROM user_ips - WHERE last_seen >= ? AND user_id = ? - ORDER BY last_seen - DESC - """, - (since_ts, user_id), - ) - return cast(List[Tuple[str, str, str, int]], txn.fetchall()) - - rows = await self.db_pool.runInteraction( - desc="get_user_ip_and_agents", func=get_recent - ) + results[(access_token, ip)] = { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } - results.update( - ((access_token, ip), (user_agent, last_seen)) - for access_token, ip, user_agent, last_seen in rows - ) - return [ - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in results.items() - ] + return list(results.values()) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 8143168107..b0ccab0c9b 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -555,6 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox" def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -570,6 +572,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) + self.db_pool.updates.register_background_update_handler( + self.REMOVE_DELETED_DEVICES, + self._remove_deleted_devices_from_device_inbox, + ) + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() @@ -582,6 +589,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): return 1 + async def _remove_deleted_devices_from_device_inbox( + self, progress: JsonDict, batch_size: int + ) -> int: + """A background update that deletes all device_inboxes for deleted devices. + + This should only need to be run once (when users upgrade to v1.46.0) + + Args: + progress: JsonDict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + def _remove_deleted_devices_from_device_inbox_txn( + txn: LoggingTransaction, + ) -> int: + """stream_id is not unique + we need to use an inclusive `stream_id >= ?` clause, + since we might not have deleted all dead device messages for the stream_id + returned from the previous query + + Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, + to avoid problems of deleting a large number of rows all at once + due to a single device having lots of device messages. + """ + + last_stream_id = progress.get("stream_id", 0) + + sql = """ + SELECT device_id, user_id, stream_id + FROM device_inbox + WHERE + stream_id >= ? + AND (device_id, user_id) NOT IN ( + SELECT device_id, user_id FROM devices + ) + ORDER BY stream_id + LIMIT ? + """ + + txn.execute(sql, (last_stream_id, batch_size)) + rows = txn.fetchall() + + num_deleted = 0 + for row in rows: + num_deleted += self.db_pool.simple_delete_txn( + txn, + "device_inbox", + {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, + ) + + if rows: + # send more than stream_id to progress + # otherwise it can happen in large deployments that + # no change of status is visible in the log file + # it may be that the stream_id does not change in several runs + self.db_pool.updates._background_update_progress_txn( + txn, + self.REMOVE_DELETED_DEVICES, + { + "device_id": rows[-1][0], + "user_id": rows[-1][1], + "stream_id": rows[-1][2], + }, + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_deleted_devices_from_device_inbox", + _remove_deleted_devices_from_device_inbox_txn, + ) + + # The task is finished when no more lines are deleted. + if not number_deleted: + await self.db_pool.updates._end_background_update( + self.REMOVE_DELETED_DEVICES + ) + + return number_deleted + class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): pass diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a01bf2c5b7..b15cd030e0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): raise StoreError(500, "Problem storing device.") async def delete_device(self, user_id: str, device_id: str) -> None: - """Delete a device. + """Delete a device and its device_inbox. Args: user_id: The ID of the user which owns the device device_id: The ID of the device to delete """ - await self.db_pool.simple_delete_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - desc="delete_device", - ) - self.device_id_exists_cache.invalidate((user_id, device_id)) + await self.delete_devices(user_id, [device_id]) async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. @@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: The ID of the user which owns the devices device_ids: The IDs of the devices to delete """ - await self.db_pool.simple_delete_many( - table="devices", - column="device_id", - iterable=device_ids, - keyvalues={"user_id": user_id, "hidden": False}, - desc="delete_devices", - ) + + def _delete_devices_txn(txn: LoggingTransaction) -> None: + self.db_pool.simple_delete_many_txn( + txn, + table="devices", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id, "hidden": False}, + ) + + self.db_pool.simple_delete_many_txn( + txn, + table="device_inbox", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + + await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index fc49112063..ae3a8a63e4 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -17,11 +17,15 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, make_tuple_comparison_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -167,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + self.db_pool.updates.register_background_update_handler( + "event_thread_relation", self._event_thread_relation + ) + ################################################################################ # bg updates for replacing stream_ordering with a BIGINT @@ -1091,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result + async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int: + """Background update handler which will store thread relations for existing events.""" + last_event_id = progress.get("last_event_id", "") + + def _event_thread_relation_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT event_id, json FROM event_json + LEFT JOIN event_relations USING (event_id) + WHERE event_id > ? AND event_relations.event_id IS NULL + ORDER BY event_id LIMIT ? + """, + (last_event_id, batch_size), + ) + + results = list(txn) + missing_thread_relations = [] + for (event_id, event_json_raw) in results: + try: + event_json = db_to_json(event_json_raw) + except Exception as e: + logger.warning( + "Unable to load event %s (no relations will be updated): %s", + event_id, + e, + ) + continue + + # If there's no relation (or it is not a thread), skip! + relates_to = event_json["content"].get("m.relates_to") + if not relates_to or not isinstance(relates_to, dict): + continue + if relates_to.get("rel_type") != RelationTypes.THREAD: + continue + + # Get the parent ID. + parent_id = relates_to.get("event_id") + if not isinstance(parent_id, str): + continue + + missing_thread_relations.append((event_id, parent_id)) + + # Insert the missing data. + self.db_pool.simple_insert_many_txn( + txn=txn, + table="event_relations", + values=[ + { + "event_id": event_id, + "relates_to_Id": parent_id, + "relation_type": RelationTypes.THREAD, + } + for event_id, parent_id in missing_thread_relations + ], + ) + + if results: + latest_event_id = results[-1][0] + self.db_pool.updates._background_update_progress_txn( + txn, "event_thread_relation", {"last_event_id": latest_event_id} + ) + + return len(results) + + num_rows = await self.db_pool.runInteraction( + desc="event_thread_relation", func=_event_thread_relation_txn + ) + + if not num_rows: + await self.db_pool.updates._end_background_update("event_thread_relation") + + return num_rows + async def _background_populate_stream_ordering2( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index ba7075caa5..dd8e27e226 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def update_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str + self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] ) -> int: return await self.db_pool.simple_update( table="remote_profile_cache", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 37d47aa823..6c7d6ba508 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -499,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: + """Sets the user type. + + Args: + user: user ID of the user. + user_type: type of the user or None for a user without a type. + """ + + def set_user_type_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"user_type": user_type} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + await self.db_pool.runInteraction("set_user_type", set_user_type_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/synapse/storage/schema/main/delta/64/02remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/64/02remove_deleted_devices_from_device_inbox.sql new file mode 100644 index 0000000000..efe702f621 --- /dev/null +++ b/synapse/storage/schema/main/delta/64/02remove_deleted_devices_from_device_inbox.sql @@ -0,0 +1,22 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Remove messages from the device_inbox table which were orphaned +-- when a device was deleted using Synapse earlier than 1.46.0. +-- This runs as background task, but may take a bit to finish. + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6402, 'remove_deleted_devices_from_device_inbox', '{}'); diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/02_thread_relations.sql new file mode 100644 index 0000000000..d60517f7b4 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/02_thread_relations.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Check old events for thread relations. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6502, 'event_thread_relation', '{}'); diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py new file mode 100644 index 0000000000..a447ce4e55 --- /dev/null +++ b/synapse/util/gai_resolver.py @@ -0,0 +1,136 @@ +# This is a direct lift from +# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/_resolver.py. +# We copy it here as we need to instantiate `GAIResolver` manually, but it is a +# private class. + + +from socket import ( + AF_INET, + AF_INET6, + AF_UNSPEC, + SOCK_DGRAM, + SOCK_STREAM, + gaierror, + getaddrinfo, +) + +from zope.interface import implementer + +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.interfaces import IHostnameResolver, IHostResolution +from twisted.internet.threads import deferToThreadPool + + +@implementer(IHostResolution) +class HostResolution: + """ + The in-progress resolution of a given hostname. + """ + + def __init__(self, name): + """ + Create a L{HostResolution} with the given name. + """ + self.name = name + + def cancel(self): + # IHostResolution.cancel + raise NotImplementedError() + + +_any = frozenset([IPv4Address, IPv6Address]) + +_typesToAF = { + frozenset([IPv4Address]): AF_INET, + frozenset([IPv6Address]): AF_INET6, + _any: AF_UNSPEC, +} + +_afToType = { + AF_INET: IPv4Address, + AF_INET6: IPv6Address, +} + +_transportToSocket = { + "TCP": SOCK_STREAM, + "UDP": SOCK_DGRAM, +} + +_socktypeToType = { + SOCK_STREAM: "TCP", + SOCK_DGRAM: "UDP", +} + + +@implementer(IHostnameResolver) +class GAIResolver: + """ + L{IHostnameResolver} implementation that resolves hostnames by calling + L{getaddrinfo} in a thread. + """ + + def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): + """ + Create a L{GAIResolver}. + @param reactor: the reactor to schedule result-delivery on + @type reactor: L{IReactorThreads} + @param getThreadPool: a function to retrieve the thread pool to use for + scheduling name resolutions. If not supplied, the use the given + C{reactor}'s thread pool. + @type getThreadPool: 0-argument callable returning a + L{twisted.python.threadpool.ThreadPool} + @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly + parameterized for testing. + @type getaddrinfo: callable with the same signature as L{getaddrinfo} + """ + self._reactor = reactor + self._getThreadPool = ( + reactor.getThreadPool if getThreadPool is None else getThreadPool + ) + self._getaddrinfo = getaddrinfo + + def resolveHostName( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): + """ + See L{IHostnameResolver.resolveHostName} + @param resolutionReceiver: see interface + @param hostName: see interface + @param portNumber: see interface + @param addressTypes: see interface + @param transportSemantics: see interface + @return: see interface + """ + pool = self._getThreadPool() + addressFamily = _typesToAF[ + _any if addressTypes is None else frozenset(addressTypes) + ] + socketType = _transportToSocket[transportSemantics] + + def get(): + try: + return self._getaddrinfo( + hostName, portNumber, addressFamily, socketType + ) + except gaierror: + return [] + + d = deferToThreadPool(self._reactor, pool, get) + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + + @d.addCallback + def deliverResults(result): + for family, socktype, _proto, _cannoname, sockaddr in result: + addrType = _afToType[family] + resolutionReceiver.addressResolved( + addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr) + ) + resolutionReceiver.resolutionComplete() + + return resolution |