diff options
Diffstat (limited to 'synapse')
57 files changed, 1608 insertions, 901 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index c38666da18..6324df883b 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = { "users": ["shadow_banned"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], + "device_lists_changes_in_room": ["converted_to_destinations"], } diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 6f8e33a156..2b0d92cbae 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore @@ -61,7 +60,6 @@ class AdminCmdSlavedStore( SlavedDeviceStore, SlavedPushRuleStore, SlavedEventStore, - SlavedClientIpStore, BaseSlavedStore, RoomWorkerStore, ): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b6f510ed30..1865c671f4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.directory import DirectoryStore @@ -247,7 +246,6 @@ class GenericWorkerSlavedStore( SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedProfileStore, - SlavedClientIpStore, SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 07ec95f1d6..d23d9221bc 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +23,13 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.types import ( + DeviceListUpdates, + GroupID, + JsonDict, + UserID, + get_domain_from_id, +) from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -400,6 +407,7 @@ class AppServiceTransaction: to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ): self.service = service self.id = id @@ -408,6 +416,7 @@ class AppServiceTransaction: self.to_device_messages = to_device_messages self.one_time_key_counts = one_time_key_counts self.unused_fallback_keys = unused_fallback_keys + self.device_list_summary = device_list_summary async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -424,6 +433,7 @@ class AppServiceTransaction: to_device_messages=self.to_device_messages, one_time_key_counts=self.one_time_key_counts, unused_fallback_keys=self.unused_fallback_keys, + device_list_summary=self.device_list_summary, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 98fe354014..0cdbb04bfb 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,7 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient): to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, ) -> bool: """ @@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: if one_time_key_counts: body[ @@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient): body[ "org.matrix.msc3202.device_unused_fallback_keys" ] = unused_fallback_keys + if device_list_summary: + body["org.matrix.msc3202.device_lists"] = { + "changed": list(device_list_summary.changed), + "left": list(device_list_summary.left), + } try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index a6084b9c35..3b49e60716 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -72,7 +72,7 @@ from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock if TYPE_CHECKING: @@ -122,6 +122,7 @@ class ApplicationServiceScheduler: events: Optional[Collection[EventBase]] = None, ephemeral: Optional[Collection[JsonDict]] = None, to_device_messages: Optional[Collection[JsonDict]] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Enqueue some data to be sent off to an application service. @@ -133,10 +134,18 @@ class ApplicationServiceScheduler: to_device_messages: The to-device messages to send. These differ from normal to-device messages sent to clients, as they have 'to_device_id' and 'to_user_id' fields. + device_list_summary: A summary of users that the application service either needs + to refresh the device lists of, or those that the application service need no + longer track the device lists of. """ # We purposefully allow this method to run with empty events/ephemeral # collections, so that callers do not need to check iterable size themselves. - if not events and not ephemeral and not to_device_messages: + if ( + not events + and not ephemeral + and not to_device_messages + and not device_list_summary + ): return if events: @@ -147,6 +156,10 @@ class ApplicationServiceScheduler: self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( to_device_messages ) + if device_list_summary: + self.queuer.queued_device_list_summaries.setdefault( + appservice.id, [] + ).append(device_list_summary) # Kick off a new application service transaction self.queuer.start_background_request(appservice) @@ -169,6 +182,8 @@ class _ServiceQueuer: self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # dict of {service_id: [to_device_message_json]} self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [device_list_summary]} + self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() @@ -212,7 +227,35 @@ class _ServiceQueuer: ] del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] - if not events and not ephemeral and not to_device_messages_to_send: + # Consolidate any pending device list summaries into a single, up-to-date + # summary. + # Note: this code assumes that in a single DeviceListUpdates, a user will + # never be in both "changed" and "left" sets. + device_list_summary = DeviceListUpdates() + for summary in self.queued_device_list_summaries.get(service.id, []): + # For every user in the incoming "changed" set: + # * Remove them from the existing "left" set if necessary + # (as we need to start tracking them again) + # * Add them to the existing "changed" set if necessary. + device_list_summary.left.difference_update(summary.changed) + device_list_summary.changed.update(summary.changed) + + # For every user in the incoming "left" set: + # * Remove them from the existing "changed" set if necessary + # (we no longer need to track them) + # * Add them to the existing "left" set if necessary. + device_list_summary.changed.difference_update(summary.left) + device_list_summary.left.update(summary.left) + self.queued_device_list_summaries.clear() + + if ( + not events + and not ephemeral + and not to_device_messages_to_send + # DeviceListUpdates is True if either the 'changed' or 'left' sets have + # at least one entry, otherwise False + and not device_list_summary + ): return one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None @@ -240,6 +283,7 @@ class _ServiceQueuer: to_device_messages_to_send, one_time_key_counts, unused_fallback_keys, + device_list_summary, ) except Exception: logger.exception("AS request failed") @@ -322,6 +366,7 @@ class _TransactionController: to_device_messages: Optional[List[JsonDict]] = None, one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -336,6 +381,7 @@ class _TransactionController: appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -345,6 +391,7 @@ class _TransactionController: to_device_messages=to_device_messages or [], one_time_key_counts=one_time_key_counts or {}, unused_fallback_keys=unused_fallback_keys or {}, + device_list_summary=device_list_summary or DeviceListUpdates(), ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 439bfe1526..ada165f238 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -170,6 +170,7 @@ def _load_appservice( # When enabled, appservice transactions contain the following information: # - device One-Time Key counts # - device unused fallback key usage states + # - device list changes msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) if not isinstance(msc3202_transaction_extensions, bool): raise ValueError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 064db4487c..43db5fcdd9 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -59,8 +59,9 @@ class ExperimentalConfig(Config): "msc3202_device_masquerading", False ) - # Portion of MSC3202 related to transaction extensions: - # sending one-time key counts and fallback key usage to application services. + # The portion of MSC3202 related to transaction extensions: + # sending device list changes, one-time key counts and fallback key + # usage to application services. self.msc3202_transaction_extensions: bool = experimental.get( "msc3202_transaction_extensions", False ) @@ -77,3 +78,6 @@ class ExperimentalConfig(Config): # The deprecated groups feature. self.groups_enabled: bool = experimental.get("groups_enabled", True) + + # MSC2654: Unread counts + self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) diff --git a/synapse/config/key.py b/synapse/config/key.py index ee83c6c06b..f5377e7d9c 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,7 +16,7 @@ import hashlib import logging import os -from typing import Any, Dict, Iterator, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional import attr import jsonschema @@ -38,6 +38,9 @@ from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError +if TYPE_CHECKING: + from signedjson.key import VerifyKeyWithExpiry + INSECURE_NOTARY_ERROR = """\ Your server is configured to accept key server responses without signature validation or TLS certificate validation. This is likely to be very insecure. If @@ -300,7 +303,7 @@ class KeyConfig(Config): def read_old_signing_keys( self, old_signing_keys: Optional[JsonDict] - ) -> Dict[str, VerifyKey]: + ) -> Dict[str, "VerifyKeyWithExpiry"]: if old_signing_keys is None: return {} keys = {} @@ -308,8 +311,8 @@ class KeyConfig(Config): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_key.expired_ts = key_data["expired_ts"] + verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment] + verify_key.expired = key_data["expired_ts"] keys[key_id] = verify_key else: raise ConfigError( @@ -422,7 +425,7 @@ def _parse_key_servers( server_name = server["server_name"] result = TrustedKeyServer(server_name=server_name) - verify_keys = server.get("verify_keys") + verify_keys: Optional[Dict[str, str]] = server.get("verify_keys") if verify_keys is not None: result.verify_keys = {} for key_id, key_base64 in verify_keys.items(): diff --git a/synapse/config/server.py b/synapse/config/server.py index 38de4b8000..b3a9e50752 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -680,6 +680,18 @@ class ServerConfig(Config): config.get("use_account_validity_in_account_status") or False ) + # This is a temporary option that enables fully using the new + # `device_lists_changes_in_room` without the backwards compat code. This + # is primarily for testing. If enabled the server should *not* be + # downgraded, as it may lead to missing device list updates. + self.use_new_device_lists_changes_in_room = ( + config.get("use_new_device_lists_changes_in_room") or False + ) + + self.rooms_to_exclude_from_sync: List[str] = ( + config.get("exclude_rooms_from_sync") or [] + ) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1234,6 +1246,15 @@ class ServerConfig(Config): # information about using custom templates. # #custom_template_directory: /path/to/custom/templates/ + + # List of rooms to exclude from sync responses. This is useful for server + # administrators wishing to group users into a room without these users being able + # to see it from their client. + # + # By default, no room is excluded. + # + #exclude_rooms_from_sync: + # - !foo:example.com """ % locals() ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6cf384f6a1..c88afb2986 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -176,7 +176,7 @@ class Keyring: self._local_verify_keys: Dict[str, FetchKeyResult] = {} for key_id, key in hs.config.key.old_signing_keys.items(): self._local_verify_keys[key_id] = FetchKeyResult( - verify_key=key, valid_until_ts=key.expired_ts + verify_key=key, valid_until_ts=key.expired ) vk = get_verify_key(hs.signing_key) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 1ea1bb7d37..98c203ada0 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr -from nacl.signing import SigningKey +from signedjson.types import SigningKey from synapse.api.constants import MAX_DEPTH from synapse.api.room_versions import ( diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index bfca454f51..ef68e20282 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -42,6 +42,7 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] +ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -169,6 +170,7 @@ class ThirdPartyEventRules: self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] + self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] def register_third_party_rules_callbacks( self, @@ -187,6 +189,7 @@ class ThirdPartyEventRules: on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, + on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -221,6 +224,9 @@ class ThirdPartyEventRules: on_user_deactivation_status_changed, ) + if on_threepid_bind is not None: + self._on_threepid_bind_callbacks.append(on_threepid_bind) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -479,3 +485,23 @@ class ThirdPartyEventRules: logger.exception( "Failed to run module API callback %s: %s", callback, e ) + + async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None: + """Called after a threepid association has been verified and stored. + + Note that this callback is called when an association is created on the + local homeserver, not when it's created on an identity server (and then kept track + of so that it can be unbound on the same IS later on). + + Args: + user_id: the user being associated with the threepid. + medium: the threepid's medium. + address: the threepid's address. + """ + for callback in self._on_threepid_bind_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index de6e5f44fe..d4b3cb3e98 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -481,7 +481,7 @@ class TransportLayerClient: if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: - data["limit"] = str(limit) + data["limit"] = limit if since_token: data["since"] = since_token @@ -509,7 +509,7 @@ class TransportLayerClient: if third_party_instance_id: args["third_party_instance_id"] = (third_party_instance_id,) if limit: - args["limit"] = [str(limit)] + args["limit"] = [limit] if since_token: args["since"] = [since_token] diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 177b4f8991..4af9fbc5d1 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import random -from typing import TYPE_CHECKING, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple from synapse.replication.http.account_data import ( ReplicationAddTagRestServlet, @@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + +ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[ + [str, Optional[str], str, JsonDict], Awaitable +] + class AccountDataHandler: def __init__(self, hs: "HomeServer"): @@ -40,6 +47,44 @@ class AccountDataHandler: self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) self._account_data_writers = hs.config.worker.writers.account_data + self._on_account_data_updated_callbacks: List[ + ON_ACCOUNT_DATA_UPDATED_CALLBACK + ] = [] + + def register_module_callbacks( + self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None + ) -> None: + """Register callbacks from modules.""" + if on_account_data_updated is not None: + self._on_account_data_updated_callbacks.append(on_account_data_updated) + + async def _notify_modules( + self, + user_id: str, + room_id: Optional[str], + account_data_type: str, + content: JsonDict, + ) -> None: + """Notifies modules about new account data changes. + + A change can be either a new account data type being added, or the content + associated with a type being changed. Account data for a given type is removed by + changing the associated content to an empty dictionary. + + Note that this is not called when the tags associated with a room change. + + Args: + user_id: The user whose account data is changing. + room_id: The ID of the room the account data change concerns, if any. + account_data_type: The type of the account data. + content: The content that is now associated with this type. + """ + for callback in self._on_account_data_updated_callbacks: + try: + await callback(user_id, room_id, account_data_type, content) + except Exception as e: + logger.exception("Failed to run module callback %s: %s", callback, e) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -63,6 +108,8 @@ class AccountDataHandler: "account_data_key", max_stream_id, users=[user_id] ) + await self._notify_modules(user_id, room_id, account_data_type, content) + return max_stream_id else: response = await self._room_data_client( @@ -96,6 +143,9 @@ class AccountDataHandler: self._notifier.on_new_event( "account_data_key", max_stream_id, users=[user_id] ) + + await self._notify_modules(user_id, None, account_data_type, content) + return max_stream_id else: response = await self._user_data_client( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bd913e524e..316c4b677c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import ( + DeviceListUpdates, + JsonDict, + RoomAlias, + RoomStreamToken, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -58,6 +64,9 @@ class ApplicationServicesHandler: self._msc2409_to_device_messages_enabled = ( hs.config.experimental.msc2409_to_device_messages_enabled ) + self._msc3202_transaction_extensions_enabled = ( + hs.config.experimental.msc3202_transaction_extensions + ) self.current_max = 0 self.is_processing = False @@ -204,9 +213,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key" or - "to_device_key". Any other value for `stream_key` will cause this function - to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value for `stream_key` + will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -230,6 +239,7 @@ class ApplicationServicesHandler: "receipt_key", "presence_key", "to_device_key", + "device_list_key", ): return @@ -253,15 +263,37 @@ class ApplicationServicesHandler: ): return + # Ignore device lists if the feature flag is not enabled + if ( + stream_key == "device_list_key" + and not self._msc3202_transaction_extensions_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # # Note that whether these events are actually relevant to these appservices # is decided later on. + services = self.store.get_app_services() services = [ service - for service in self.store.get_app_services() - if service.supports_ephemeral + for service in services + # Different stream keys require different support booleans + if ( + stream_key + in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ) + and service.supports_ephemeral + ) + or ( + stream_key == "device_list_key" + and service.msc3202_transaction_extensions + ) ] if not services: # Bail out early if none of the target appservices have explicitly registered @@ -336,6 +368,20 @@ class ApplicationServicesHandler: service, "to_device", new_token ) + elif stream_key == "device_list_key": + device_list_summary = await self._get_device_list_summary( + service, new_token + ) + if device_list_summary: + self.scheduler.enqueue_for_appservice( + service, device_list_summary=device_list_summary + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "device_list", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -542,6 +588,96 @@ class ApplicationServicesHandler: return message_payload + async def _get_device_list_summary( + self, + appservice: ApplicationService, + new_key: int, + ) -> DeviceListUpdates: + """ + Retrieve a list of users who have changed their device lists. + + Args: + appservice: The application service to retrieve device list changes for. + new_key: The stream key of the device list change that triggered this method call. + + Returns: + A set of device list updates, comprised of users that the appservices needs to: + * resync the device list of, and + * stop tracking the device list of. + """ + # Fetch the last successfully processed device list update stream ID + # for this appservice. + from_key = await self.store.get_type_stream_id_for_appservice( + appservice, "device_list" + ) + + # Fetch the users who have modified their device list since then. + users_with_changed_device_lists = ( + await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + ) + + # Filter out any users the application service is not interested in + # + # For each user who changed their device list, we want to check whether this + # appservice would be interested in the change. + filtered_users_with_changed_device_lists = { + user_id + for user_id in users_with_changed_device_lists + if await self._is_appservice_interested_in_device_lists_of_user( + appservice, user_id + ) + } + + # Create a summary of "changed" and "left" users. + # TODO: Calculate "left" users. + device_list_summary = DeviceListUpdates( + changed=filtered_users_with_changed_device_lists + ) + + return device_list_summary + + async def _is_appservice_interested_in_device_lists_of_user( + self, + appservice: ApplicationService, + user_id: str, + ) -> bool: + """ + Returns whether a given application service is interested in the device list + updates of a given user. + + The application service is interested in the user's device list updates if any + of the following are true: + * The user is the appservice's sender localpart user. + * The user is in the appservice's user namespace. + * At least one member of one room that the user is a part of is in the + appservice's user namespace. + * The appservice is explicitly (via room ID or alias) interested in at + least one room that the user is in. + + Args: + appservice: The application service to gauge interest of. + user_id: The ID of the user whose device list interest is in question. + + Returns: + True if the application service is interested in the user's device lists, False + otherwise. + """ + # This method checks against both the sender localpart user as well as if the + # user is in the appservice's user namespace. + if appservice.is_interested_in_user(user_id): + return True + + # Determine whether any of the rooms the user is in justifies sending this + # device list update to the application service. + room_ids = await self.store.get_rooms_for_user(user_id) + for room_id in room_ids: + # This method covers checking room members for appservice interest as well as + # room ID and alias checks. + if await appservice.is_interested_in_room(room_id, self.store): + return True + + return False + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3e29c96a49..86991d26ce 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -211,6 +211,7 @@ class AuthHandler: self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled + self._third_party_rules = hs.get_third_party_event_rules() # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -1505,6 +1506,8 @@ class AuthHandler: user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) + await self._third_party_rules.on_threepid_bind(user_id, medium, address) + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None ) -> bool: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d5ccaa0c37..c710c02cf9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -37,7 +37,10 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import ( JsonDict, StreamToken, @@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + # Whether `_handle_new_device_update_async` is currently processing. + self._handle_new_device_update_is_processing = False + + # If a new device update may have happened while the loop was + # processing. + self._handle_new_device_update_new_data = False + + # On start up check if there are any updates pending. + hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + + # Used to decide if we calculate outbound pokes up front or not. By + # default we do to allow safely downgrading Synapse. + self.use_new_device_lists_changes_in_room = ( + hs.config.server.use_new_device_lists_changes_in_room + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler): # No changes to notify about, so this is a no-op. return - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id - ) + room_ids = await self.store.get_rooms_for_user(user_id) + + hosts: Optional[Set[str]] = None + if not self.use_new_device_lists_changes_in_room: + hosts = set() - hosts: Set[str] = set() - if self.hs.is_mine_id(user_id): - hosts.update(get_domain_from_id(u) for u in users_who_share_room) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + joined_users = await self.store.get_users_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in joined_users) - set_tag("target_hosts", hosts) + set_tag("target_hosts", hosts) + + hosts.discard(self.server_name) position = await self.store.add_device_change_to_streams( - user_id, device_ids, list(hosts) + user_id, + device_ids, + hosts=hosts, + room_ids=room_ids, ) if not position: @@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler): # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - users_to_notify = users_who_share_room.union({user_id}) + self.notifier.on_new_event( + "device_list_key", position, users={user_id}, rooms=room_ids + ) - self.notifier.on_new_event("device_list_key", position, users=users_to_notify) + # We may need to do some processing asynchronously. + self._handle_new_device_update_async() if hosts: logger.info( @@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler): return {"success": True} + @wrap_as_background_process("_handle_new_device_update_async") + async def _handle_new_device_update_async(self) -> None: + """Called when we have a new local device list update that we need to + send out over federation. + + This happens in the background so as not to block the original request + that generated the device update. + """ + if self._handle_new_device_update_is_processing: + self._handle_new_device_update_new_data = True + return + + self._handle_new_device_update_is_processing = True + + # The stream ID we processed previous iteration (if any), and the set of + # hosts we've already poked about for this update. This is so that we + # don't poke the same remote server about the same update repeatedly. + current_stream_id = None + hosts_already_sent_to: Set[str] = set() + + try: + while True: + self._handle_new_device_update_new_data = False + rows = await self.store.get_uncoverted_outbound_room_pokes() + if not rows: + # If the DB returned nothing then there is nothing left to + # do, *unless* a new device list update happened during the + # DB query. + if self._handle_new_device_update_new_data: + continue + else: + return + + for user_id, device_id, room_id, stream_id, opentracing_context in rows: + joined_user_ids = await self.store.get_users_in_room(room_id) + hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts.discard(self.server_name) + + # Check if we've already sent this update to some hosts + if current_stream_id == stream_id: + hosts -= hosts_already_sent_to + + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + stream_id=stream_id, + hosts=hosts, + context=opentracing_context, + ) + + # Notify replication that we've updated the device list stream. + self.notifier.notify_replication() + + if hosts: + logger.info( + "Sending device list update notif for %r to: %r", + user_id, + hosts, + ) + for host in hosts: + self.federation_sender.send_device_messages( + host, immediate=False + ) + log_kv( + {"message": "sent device update to host", "host": host} + ) + + if current_stream_id != stream_id: + # Clear the set of hosts we've already sent to as we're + # processing a new update. + hosts_already_sent_to.clear() + + hosts_already_sent_to.update(hosts) + current_stream_id = stream_id + + finally: + self._handle_new_device_update_is_processing = False + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4bd87709f3..e7b9f15e13 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -469,6 +469,12 @@ class FederationEventHandler: if context.rejected: raise SynapseError(400, "Join event was rejected") + # the remote server is responsible for sending our join event to the rest + # of the federation. Indeed, attempting to do so will result in problems + # when we try to look up the state before the join (to get the server list) + # and discover that we do not have it. + event.internal_metadata.proactively_send = False + return await self.persist_events_and_notify(room_id, [(event, context)]) async def backfill( @@ -891,10 +897,24 @@ class FederationEventHandler: logger.debug("We are also missing %i auth events", len(missing_auth_events)) missing_events = missing_desired_events | missing_auth_events - logger.debug("Fetching %i events from remote", len(missing_events)) - await self._get_events_and_persist( - destination=destination, room_id=room_id, event_ids=missing_events - ) + + # Making an individual request for each of 1000s of events has a lot of + # overhead. On the other hand, we don't really want to fetch all of the events + # if we already have most of them. + # + # As an arbitrary heuristic, if we are missing more than 10% of the events, then + # we fetch the whole state. + # + # TODO: might it be better to have an API which lets us do an aggregate event + # request + if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids): + logger.debug("Requesting complete state from remote") + await self._get_state_and_persist(destination, room_id, event_id) + else: + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, event_ids=missing_events + ) # we need to make sure we re-load from the database to get the rejected # state correct. @@ -953,6 +973,27 @@ class FederationEventHandler: return remote_state + async def _get_state_and_persist( + self, destination: str, room_id: str, event_id: str + ) -> None: + """Get the complete room state at a given event, and persist any new events + as outliers""" + room_version = await self._store.get_room_version(room_id) + auth_events, state_events = await self._federation_client.get_room_state( + destination, room_id, event_id=event_id, room_version=room_version + ) + logger.info("/state returned %i events", len(auth_events) + len(state_events)) + + await self._auth_and_persist_outliers( + room_id, itertools.chain(auth_events, state_events) + ) + + # we also need the event itself. + if not await self._store.have_seen_event(room_id, event_id): + await self._get_events_and_persist( + destination=destination, room_id=room_id, event_ids=(event_id,) + ) + async def _process_received_pdu( self, origin: str, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 34d9411bbf..dace31d87e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # We'll actually pull the presence updates for these users at the end. interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() - if from_key: + if from_key is not None: # First get all users that have had a presence update updated_users = stream_change_cache.get_all_entities_changed(from_key) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 73217d135d..a36936b520 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +from typing import TYPE_CHECKING, Dict, Iterable, Optional import attr from frozendict import frozendict @@ -25,7 +25,6 @@ from synapse.visibility import filter_events_for_client if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -116,7 +115,7 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - pagination_chunk = await self._main_store.get_relations_for_event( + related_events, next_token = await self._main_store.get_relations_for_event( event_id=event_id, event=event, room_id=room_id, @@ -129,9 +128,7 @@ class RelationsHandler: to_token=to_token, ) - events = await self._main_store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) + events = await self._main_store.get_events_as_list(related_events) events = await filter_events_for_client( self._storage, user_id, events, is_peeking=(member_event_id is None) @@ -152,9 +149,16 @@ class RelationsHandler: events, now, bundle_aggregations=aggregations ) - return_value = await pagination_chunk.to_dict(self._main_store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event + return_value = { + "chunk": serialized_events, + "original_event": original_event, + } + + if next_token: + return_value["next_batch"] = await next_token.to_string(self._main_store) + + if from_token: + return_value["prev_batch"] = await from_token.to_string(self._main_store) return return_value @@ -193,16 +197,21 @@ class RelationsHandler: annotations = await self._main_store.get_aggregation_groups_for_event( event_id, room_id ) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) + if annotations: + aggregations.annotations = {"chunk": annotations} - references = await self._main_store.get_relations_for_event( + references, next_token = await self._main_store.get_relations_for_event( event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) + if references: + aggregations.references = { + "chunk": [{"event_id": event_id} for event_id in references] + } + + if next_token: + aggregations.references["next_batch"] = await next_token.to_string( + self._main_store + ) # Store the bundled aggregations in the event metadata for later use. return aggregations diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 092e185c99..51a08fd2c0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -771,7 +771,9 @@ class RoomCreationHandler: % (user_id,), ) - visibility = config.get("visibility", None) + # The spec says rooms should default to private visibility if + # `visibility` is not specified. + visibility = config.get("visibility", "private") is_public = visibility == "public" room_id = await self._generate_room_id( diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index a0255bd143..78e299d3a5 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -156,8 +156,8 @@ class RoomBatchHandler: ) -> List[str]: """Takes all `state_events_at_start` event dictionaries and creates/persists them in a floating state event chain which don't resolve into the current room - state. They are floating because they reference no prev_events and are marked - as outliers which disconnects them from the normal DAG. + state. They are floating because they reference no prev_events which disconnects + them from the normal DAG. Args: state_events_at_start: @@ -213,31 +213,23 @@ class RoomBatchHandler: room_id=room_id, action=membership, content=event_dict["content"], - # Mark as an outlier to disconnect it from the normal DAG - # and not show up between batches of history. - outlier=True, historical=True, # Only the first event in the state chain should be floating. # The rest should hang off each other in a chain. allow_no_prev_events=index == 0, prev_event_ids=prev_event_ids_for_state_chain, - # Since each state event is marked as an outlier, the - # `EventContext.for_outlier()` won't have any `state_ids` - # set and therefore can't derive any state even though the - # prev_events are set. Also since the first event in the - # state chain is floating with no `prev_events`, it can't - # derive state from anywhere automatically. So we need to - # set some state explicitly. + # The first event in the state chain is floating with no + # `prev_events` which means it can't derive state from + # anywhere automatically. So we need to set some state + # explicitly. # # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. + # reference and also update in the event when we append + # later. state_event_ids=state_event_ids.copy(), ) else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? ( event, _, @@ -246,21 +238,15 @@ class RoomBatchHandler: state_event["sender"], app_service_requester.app_service ), event_dict, - # Mark as an outlier to disconnect it from the normal DAG - # and not show up between batches of history. - outlier=True, historical=True, # Only the first event in the state chain should be floating. # The rest should hang off each other in a chain. allow_no_prev_events=index == 0, prev_event_ids=prev_event_ids_for_state_chain, - # Since each state event is marked as an outlier, the - # `EventContext.for_outlier()` won't have any `state_ids` - # set and therefore can't derive any state even though the - # prev_events are set. Also since the first event in the - # state chain is floating with no `prev_events`, it can't - # derive state from anywhere automatically. So we need to - # set some state explicitly. + # The first event in the state chain is floating with no + # `prev_events` which means it can't derive state from + # anywhere automatically. So we need to set some state + # explicitly. # # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6c569cfb1c..303c38c746 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,17 +13,7 @@ # limitations under the License. import itertools import logging -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - FrozenSet, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + DeviceListUpdates, JsonDict, MutableStateMap, Requester, @@ -184,21 +175,6 @@ class GroupsSyncResult: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceLists: - """ - Attributes: - changed: List of user_ids whose devices may have changed - left: List of user_ids whose devices we no longer track - """ - - changed: Collection[str] - left: Collection[str] - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -240,7 +216,7 @@ class SyncResult: knocked: List[KnockedSyncResult] archived: List[ArchivedSyncResult] to_device: List[JsonDict] - device_lists: DeviceLists + device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] groups: Optional[GroupsSyncResult] @@ -298,6 +274,8 @@ class SyncHandler: expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + async def wait_for_sync_for_user( self, requester: Requester, @@ -1262,8 +1240,8 @@ class SyncHandler: newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], - ) -> DeviceLists: - """Generate the DeviceLists section of sync + ) -> DeviceListUpdates: + """Generate the DeviceListUpdates section of sync Args: sync_result_builder @@ -1381,9 +1359,11 @@ class SyncHandler: if any(e.room_id in joined_rooms for e in entries): newly_left_users.discard(user_id) - return DeviceLists(changed=users_that_have_changed, left=newly_left_users) + return DeviceListUpdates( + changed=users_that_have_changed, left=newly_left_users + ) else: - return DeviceLists(changed=[], left=[]) + return DeviceListUpdates() async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" @@ -1607,13 +1587,15 @@ class SyncHandler: ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( - sync_result_builder, ignored_users + sync_result_builder, ignored_users, self.rooms_to_exclude ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + room_changes = await self._get_all_rooms( + sync_result_builder, ignored_users, self.rooms_to_exclude + ) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1689,7 +1671,10 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + excluded_rooms: List[str], ) -> _RoomChanges: """Determine the changes in rooms to report to the user. @@ -1721,7 +1706,7 @@ class SyncHandler: # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key + user_id, since_token.room_key, now_token.room_key, excluded_rooms ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} @@ -1922,7 +1907,10 @@ class SyncHandler: ) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + ignored_rooms: List[str], ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1933,7 +1921,7 @@ class SyncHandler: Args: sync_result_builder ignored_users: Set of users ignored by user. - + ignored_rooms: List of rooms to ignore. """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1944,6 +1932,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, + excluded_rooms=ignored_rooms, ) room_entries = [] diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ba9755f08b..9a61593ff5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -62,8 +62,10 @@ from synapse.events.third_party_rules import ( ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, + ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) +from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK from synapse.handlers.account_validity import ( IS_USER_EXPIRED_CALLBACK, ON_LEGACY_ADMIN_REQUEST, @@ -215,6 +217,7 @@ class ModuleApi: self._third_party_event_rules = hs.get_third_party_event_rules() self._password_auth_provider = hs.get_password_auth_provider() self._presence_router = hs.get_presence_router() + self._account_data_handler = hs.get_account_data_handler() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -293,6 +296,7 @@ class ModuleApi: on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, + on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, ) -> None: """Registers callbacks for third party event rules capabilities. @@ -308,6 +312,7 @@ class ModuleApi: check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, + on_threepid_bind=on_threepid_bind, ) def register_presence_router_callbacks( @@ -373,6 +378,19 @@ class ModuleApi: min_batch_size=min_batch_size, ) + def register_account_data_callbacks( + self, + *, + on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None, + ) -> None: + """Registers account data callbacks. + + Added in Synapse 1.57.0. + """ + return self._account_data_handler.register_module_callbacks( + on_account_data_updated=on_account_data_updated, + ) + def register_web_resource(self, path: str, resource: Resource) -> None: """Registers a web resource to be served at the given path. @@ -512,6 +530,17 @@ class ModuleApi: """ return await self._store.is_server_admin(UserID.from_string(user_id)) + async def set_user_admin(self, user_id: str, admin: bool) -> None: + """Sets if a user is a server admin. + + Added in Synapse v1.56.0. + + Args: + user_id: The Matrix ID of the user to set admin status for. + admin: True iff the user is to be a server admin, false otherwise. + """ + await self._store.set_server_admin(UserID.from_string(user_id), admin) + def get_qualified_user_id(self, username: str) -> str: """Qualify a user id, if necessary diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py deleted file mode 100644 index 14706a0817..0000000000 --- a/synapse/replication/slave/storage/client_ips.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.lrucache import LruCache - -from ._base import BaseSlavedStore - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedClientIpStore(BaseSlavedStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self.client_ip_last_seen: LruCache[tuple, int] = LruCache( - cache_name="client_ip_last_seen", max_size=50000 - ) - - async def insert_client_ip( - self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str - ) -> None: - now = int(self._clock.time_msec()) - key = (user_id, access_token, ip) - - try: - last_seen = self.client_ip_last_seen.get(key) - except KeyError: - last_seen = None - - # Rate-limited inserts - if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - return - - self.client_ip_last_seen.set(key, now) - - self.hs.get_replication_command_handler().send_user_ip( - user_id, access_token, ip, user_agent, device_id, now - ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 0ffd34f1da..30717c2bd0 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -20,7 +20,6 @@ from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatu from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: from synapse.server import HomeServer @@ -33,8 +32,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - super().__init__(database, db_conn, hs) - self.hs = hs self._device_list_id_gen = SlavedIdTracker( @@ -44,18 +41,11 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._user_signature_stream_cache = StreamChangeCache( - "UserSignatureStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) + + super().__init__(database, db_conn, hs) def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 3654f6c03c..fe34948168 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -356,7 +356,7 @@ class UserIpCommand(Command): access_token: str, ip: str, user_agent: str, - device_id: str, + device_id: Optional[str], last_seen: int, ): self.user_id = user_id @@ -389,6 +389,12 @@ class UserIpCommand(Command): ) ) + def __repr__(self) -> str: + return ( + f"UserIpCommand({self.user_id!r}, .., {self.ip!r}, " + f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})" + ) + class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b217c35f99..615f1828dd 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -235,6 +235,14 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + if hs.config.redis.redis_enabled: + # If we're using Redis, it's the background worker that should + # receive USER_IP commands and store the relevant client IPs. + self._should_insert_client_ips = hs.config.worker.run_background_tasks + else: + # If we're NOT using Redis, this must be handled by the master + self._should_insert_client_ips = hs.get_instance_name() == "master" + def _add_command_to_stream_queue( self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: @@ -401,23 +409,37 @@ class ReplicationCommandHandler: ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() - if self._is_master: + if self._is_master or self._should_insert_client_ips: + # We make a point of only returning an awaitable if there's actually + # something to do; on_USER_IP is not an async function, but + # _handle_user_ip is. + # If on_USER_IP returns an awaitable, it gets scheduled as a + # background process (see `BaseReplicationStreamProtocol.handle_command`). return self._handle_user_ip(cmd) else: + # Returning None when this process definitely has nothing to do + # reduces the overhead of handling the USER_IP command, which is + # currently broadcast to all workers regardless of utility. return None async def _handle_user_ip(self, cmd: UserIpCommand) -> None: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - assert self._server_notices_sender is not None - await self._server_notices_sender.on_user_ip(cmd.user_id) + """ + Handles a User IP, branching depending on whether we are the main process + and/or the background worker. + """ + if self._is_master: + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) + + if self._should_insert_client_ips: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: if cmd.instance_name == self._instance_name: @@ -698,7 +720,7 @@ class ReplicationCommandHandler: access_token: str, ip: str, user_agent: str, - device_id: str, + device_id: Optional[str], last_seen: int, ) -> None: """Tell the master that the user made a request.""" diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index c16078b187..55c96a2af3 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -12,22 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This class implements the proposed relation APIs from MSC 1849. - -Since the MSC has not been approved all APIs here are unstable and may change at -any time to reflect changes in the MSC. -""" - import logging from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import RelationTypes -from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -93,166 +84,5 @@ class RelationPaginationServlet(RestServlet): return 200, result -class RelationAggregationPaginationServlet(RestServlet): - """API to paginate aggregation groups of relations, e.g. paginate the - types and counts of the reactions on the events. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id} - - { - chunk: [ - { - "type": "m.reaction", - "key": "👍", - "count": 3 - } - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" - "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.event_handler = hs.get_event_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - - if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError( - 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" - ) - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, await pagination_chunk.to_dict(self.store) - - -class RelationAggregationGroupPaginationServlet(RestServlet): - """API to paginate within an aggregation group of relations, e.g. paginate - all the 👍 reactions on an event. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 - - { - chunk: [ - { - "type": "m.reaction", - "content": { - "m.relates_to": { - "rel_type": "m.annotation", - "key": "👍" - } - } - }, - ... - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" - "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self._relations_handler = hs.get_relations_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - key: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if relation_type != RelationTypes.ANNOTATION: - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - result = await self._relations_handler.get_relations( - requester=requester, - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - aggregation_key=key, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, result - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - RelationAggregationPaginationServlet(hs).register(http_server) - RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 0780485322..dd91dabedd 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -123,6 +123,19 @@ class RoomBatchSendEventRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + # Make sure that the prev_event_ids exist and aren't outliers - ie, they are + # regular parts of the room DAG where we know the state. + non_outlier_prev_events = await self.store.have_events_in_timeline( + prev_event_ids_from_query + ) + for prev_event_id in prev_event_ids_from_query: + if prev_event_id not in non_outlier_prev_events: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "prev_event %s does not exist, or is an outlier" % (prev_event_id,), + errcode=Codes.INVALID_PARAM, + ) + # For the event we are inserting next to (`prev_event_ids_from_query`), # find the most recent state events that allowed that message to be # sent. We will use that as a base to auth our historical messages @@ -131,14 +144,6 @@ class RoomBatchSendEventRestServlet(RestServlet): prev_event_ids_from_query ) - if not state_event_ids: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist." - % prev_event_ids_from_query, - errcode=Codes.INVALID_PARAM, - ) - state_event_ids_at_start = [] # Create and persist all of the state events that float off on their own # before the batch. These will most likely be all of the invite/member diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 53c385a86c..0bf32f873b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -99,6 +99,7 @@ class SyncRestServlet(RestServlet): self.presence_handler = hs.get_presence_handler() self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() + self._msc2654_enabled = hs.config.experimental.msc2654_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. @@ -521,7 +522,8 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count + if self._msc2654_enabled: + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index b9bfbea21b..0c9f042c84 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -76,17 +76,17 @@ class LocalKey(Resource): def response_json_object(self) -> JsonDict: verify_keys = {} - for key in self.config.key.signing_key: - verify_key_bytes = key.verify_key.encode() - key_id = "%s:%s" % (key.alg, key.version) + for signing_key in self.config.key.signing_key: + verify_key_bytes = signing_key.verify_key.encode() + key_id = "%s:%s" % (signing_key.alg, signing_key.version) verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)} old_verify_keys = {} - for key_id, key in self.config.key.old_signing_keys.items(): - verify_key_bytes = key.encode() + for key_id, old_signing_key in self.config.key.old_signing_keys.items(): + verify_key_bytes = old_signing_key.encode() old_verify_keys[key_id] = { "key": encode_base64(verify_key_bytes), - "expired_ts": key.expired_ts, + "expired_ts": old_signing_key.expired, } json_object = { diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3525d6ae54..f597157581 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Set from signedjson.sign import sign_json @@ -149,7 +149,7 @@ class RemoteKey(DirectServeJsonResource): cached = await self.store.get_server_keys_json(store_queries) - json_results = set() + json_results: Set[bytes] = set() time_now_ms = self.clock.time_msec() @@ -234,8 +234,8 @@ class RemoteKey(DirectServeJsonResource): await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] - for key_json in json_results: - key_json = json_decoder.decode(key_json.decode("utf-8")) + for key_json_raw in json_results: + key_json = json_decoder.decode(key_json_raw.decode("utf-8")) for signing_key in self.config.key.key_server_signing_keys: key_json = sign_json( key_json, self.config.server.server_name, signing_key diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index d47af8ead6..50383bdbd1 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -200,12 +200,17 @@ class PreviewUrlResource(DirectServeJsonResource): match = False continue + # Some attributes might not be parsed as strings by urlsplit (such as the + # port, which is parsed as an int). Because we use match functions that + # expect strings, we want to make sure that's what we give them. + value_str = str(value) + if pattern.startswith("^"): - if not re.match(pattern, getattr(url_tuple, attrib)): + if not re.match(pattern, value_str): match = False continue else: - if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): + if not fnmatch.fnmatch(value_str, pattern): match = False continue if match: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 3ef2bdd74b..12750d9b89 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -241,9 +241,17 @@ class LoggingTransaction: self.exception_callbacks = exception_callbacks def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. + """Call the given callback on the main twisted thread after the transaction has + finished. + + Mostly used to invalidate the caches on the correct thread. + + Note that transactions may be retried a few times if they encounter database + errors such as serialization failures. Callbacks given to `call_after` + will accumulate across transaction attempts and will _all_ be called once a + transaction attempt succeeds, regardless of whether previous transaction + attempts failed. Otherwise, if all transaction attempts fail, all + `call_on_exception` callbacks will be run instead. """ # if self.after_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that @@ -254,6 +262,15 @@ class LoggingTransaction: def call_on_exception( self, callback: Callable[..., object], *args: Any, **kwargs: Any ): + """Call the given callback on the main twisted thread after the transaction has + failed. + + Note that transactions may be retried a few times if they encounter database + errors such as serialization failures. Callbacks given to `call_on_exception` + will accumulate across transaction attempts and will _all_ be called once the + final transaction attempt fails. No `call_on_exception` callbacks will be run + if any transaction attempt succeeds. + """ # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. @@ -2013,29 +2030,40 @@ class DatabasePool: max_value: int, limit: int = 100000, ) -> Tuple[Dict[Any, int], int]: - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } + """Gets roughly the last N changes in the given stream table as a + map from entity to the stream ID of the most recent change. + + Also returns the minimum stream ID. + """ + + # This may return many rows for the same entity, but the `limit` is only + # a suggestion so we don't care that much. + # + # Note: Some stream tables can have multiple rows with the same stream + # ID. Instead of handling this with complicated SQL, we instead simply + # add one to the returned minimum stream ID to ensure correctness. + sql = f""" + SELECT {entity_column}, {stream_column} + FROM {table} + ORDER BY {stream_column} DESC + LIMIT ? + """ txn = db_conn.cursor(txn_name="get_cache_dict") - txn.execute(sql, (int(max_value),)) + txn.execute(sql, (limit,)) - cache = {row[0]: int(row[1]) for row in txn} + # The rows come out in reverse stream ID order, so we want to keep the + # stream ID of the first row for each entity. + cache: Dict[Any, int] = {} + for row in txn: + cache.setdefault(row[0], int(row[1])) txn.close() if cache: - min_val = min(cache.values()) + # We add one here as we don't know if we have all rows for the + # minimum stream ID. + min_val = min(cache.values()) + 1 else: min_val = max_value diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index f024761ba7..951031af50 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -33,7 +33,7 @@ from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .cache import CacheInvalidationWorkerStore from .censor_events import CensorEventsStore -from .client_ips import ClientIpStore +from .client_ips import ClientIpWorkerStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore from .directory import DirectoryStore @@ -49,7 +49,7 @@ from .keys import KeyStore from .lock import LockStore from .media_repository import MediaRepositoryStore from .metrics import ServerMetricsStore -from .monthly_active_users import MonthlyActiveUsersStore +from .monthly_active_users import MonthlyActiveUsersWorkerStore from .openid import OpenIdStore from .presence import PresenceStore from .profile import ProfileStore @@ -112,13 +112,13 @@ class DataStore( AccountDataStore, EventPushActionsStore, OpenIdStore, - ClientIpStore, + ClientIpWorkerStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, GroupServerStore, UserErasureStore, - MonthlyActiveUsersStore, + MonthlyActiveUsersWorkerStore, StatsStore, RelationsStore, CensorEventsStore, @@ -146,6 +146,7 @@ class DataStore( extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) @@ -182,17 +183,6 @@ class DataStore( super().__init__(database, db_conn, hs) - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._user_signature_stream_cache = StreamChangeCache( - "UserSignatureStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) - events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 0694446558..eb32c34a85 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -29,7 +29,9 @@ from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -72,6 +74,22 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + def get_max_as_txn_id(txn: Cursor) -> int: + logger.warning("Falling back to slow query, you should port to postgres") + txn.execute( + "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns" + ) + return txn.fetchone()[0] # type: ignore + + self._as_txn_seq_gen = build_sequence_generator( + db_conn, + database.engine, + get_max_as_txn_id, + "application_services_txn_id_seq", + table="application_services_txns", + id_column="txn_id", + ) + super().__init__(database, db_conn, hs) def get_app_services(self): @@ -217,6 +235,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -231,27 +250,14 @@ class ApplicationServiceTransactionWorkerStore( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. """ def _create_appservice_txn(txn): - # work out new txn id (highest txn id for this service += 1) - # The highest id may be the last one sent (in which case it is last_txn) - # or it may be the highest in the txns list (which are waiting to be/are - # being sent) - last_txn_id = self._get_last_txn(txn, service.id) - - txn.execute( - "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,), - ) - highest_txn_id = txn.fetchone()[0] - if highest_txn_id is None: - highest_txn_id = 0 - - new_txn_id = max(highest_txn_id, last_txn_id) + 1 + new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn) # Insert new txn into txn table event_ids = json_encoder.encode([e.event_id for e in events]) @@ -268,6 +274,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -283,25 +290,8 @@ class ApplicationServiceTransactionWorkerStore( txn_id: The transaction ID being completed. service: The application service which was sent this transaction. """ - txn_id = int(txn_id) def _complete_appservice_txn(txn): - # Debugging query: Make sure the txn being completed is EXACTLY +1 from - # what was there before. If it isn't, we've got problems (e.g. the AS - # has probably missed some events), so whine loudly but still continue, - # since it shouldn't fail completion of the transaction. - last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: - logger.error( - "appservice: Completing a transaction which has an ID > 1 from " - "the last ID sent to this AS. We've either dropped events or " - "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", - last_txn_id, - txn_id, - service.id, - ) - # Set current txn_id for AS to 'txn_id' self.db_pool.simple_upsert_txn( txn, @@ -359,8 +349,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. return AppServiceTransaction( service=service, @@ -370,19 +360,9 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) - def _get_last_txn(self, txn, service_id: Optional[str]) -> int: - txn.execute( - "SELECT last_txn FROM application_services_state WHERE as_id=?", - (service_id,), - ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists - return 0 - else: - return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( @@ -430,7 +410,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence", "to_device"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -446,7 +426,8 @@ class ApplicationServiceTransactionWorkerStore( ) last_stream_id = txn.fetchone() if last_stream_id is None or last_stream_id[0] is None: # no row exists - return 0 + # Stream tokens always start from 1, to avoid foot guns around `0` being falsey. + return 1 else: return int(last_stream_id[0]) @@ -457,7 +438,7 @@ class ApplicationServiceTransactionWorkerStore( async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence", "to_device"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 8b0c614ece..8480ea4e1c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -25,7 +25,9 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore +from synapse.storage.databases.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache @@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore): def __init__( self, database: DatabasePool, @@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) + if hs.config.redis.redis_enabled: + # If we're using Redis, we can shift this update process off to + # the background worker + self._update_on_this_worker = hs.config.worker.run_background_tasks + else: + # If we're NOT using Redis, this must be handled by the master + self._update_on_this_worker = hs.get_instance_name() == "master" + self.user_ips_max_age = hs.config.server.user_ips_max_age + # (user_id, access_token, ip,) -> last_seen + self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( + cache_name="client_ip_last_seen", max_size=50000 + ) + if hs.config.worker.run_background_tasks and self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + if self._update_on_this_worker: + # This is the designated worker that can write to the client IP + # tables. + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) + self._batch_row_update: Dict[ + Tuple[str, str, str], Tuple[str, Optional[str], int] + ] = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) + @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self) -> None: """Removes entries in user IPs older than the configured period.""" @@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): "_prune_old_user_ips", _prune_old_user_ips_txn ) - async def get_last_client_ip_by_device( + async def _get_last_client_ip_by_device_from_database( self, user_id: str, device_id: Optional[str] ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on. @@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): return {(d["user_id"], d["device_id"]): d for d in res} - async def get_user_ip_and_agents( + async def _get_user_ip_and_agents_from_database( self, user: UserID, since_ts: int = 0 ) -> List[LastConnectionInfo]: """Fetch the IPs and user agents for a user since the given timestamp. @@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): for access_token, ip, user_agent, last_seen in rows ] - -class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - - # (user_id, access_token, ip,) -> last_seen - self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( - cache_name="client_ip_last_seen", max_size=50000 - ) - - super().__init__(database, db_conn, hs) - - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update: Dict[ - Tuple[str, str, str], Tuple[str, Optional[str], int] - ] = {} - - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", self._update_client_ips_batch - ) - async def insert_client_ip( self, user_id: str, @@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - await self.populate_monthly_active_users(user_id) + # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return self.client_ip_last_seen.set(key, now) - self._batch_row_update[key] = (user_agent, device_id, now) + if self._update_on_this_worker: + await self.populate_monthly_active_users(user_id) + self._batch_row_update[key] = (user_agent, device_id, now) + else: + # We are not the designated writer-worker, so stream over replication + self.hs.get_replication_command_handler().send_user_ip( + user_id, access_token, ip, user_agent, device_id, now + ) @wrap_as_background_process("update_client_ips") async def _update_client_ips_batch(self) -> None: + assert ( + self._update_on_this_worker + ), "This worker is not designated to update client IPs" # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -612,6 +625,10 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): txn: LoggingTransaction, to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], ) -> None: + assert ( + self._update_on_this_worker + ), "This worker is not designated to update client IPs" + if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): @@ -662,7 +679,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): A dictionary mapping a tuple of (user_id, device_id) to dicts, with keys giving the column names from the devices table. """ - ret = await super().get_last_client_ip_by_device(user_id, device_id) + ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id) + + if not self._update_on_this_worker: + # Only the writing-worker has additional in-memory data to enhance + # the result + return ret # Update what is retrieved from the database with data which is pending # insertion, as if it has already been stored in the database. @@ -707,9 +729,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): Only the latest user agent for each access token and IP address combination is available. """ + rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts) + + if not self._update_on_this_worker: + # Only the writing-worker has additional in-memory data to enhance + # the result + return rows_from_db + results: Dict[Tuple[str, str], LastConnectionInfo] = { (connection["access_token"], connection["ip"]): connection - for connection in await super().get_user_ip_and_agents(user, since_ts) + for connection in rows_from_db } # Overlay data that is pending insertion on top of the results from the diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b76..dc8009b23d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) + device_list_max = self._device_list_id_gen.get_current_token() + device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( + db_conn, + "device_lists_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", + min_device_list_id, + prefilled_cache=device_list_prefill, + ) + + ( + user_signature_stream_prefill, + user_signature_stream_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "user_signature_stream", + entity_column="from_user_id", + stream_column="stream_id", + max_value=device_list_max, + limit=1000, + ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", + user_signature_stream_list_id, + prefilled_cache=user_signature_stream_prefill, + ) + + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) + if hs.config.worker.run_background_tasks: self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 @@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + if user_ids is None: + # Get set of all users that have had device list changes since 'from_key' + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + else: + # The same as above, but filter results to only those users in 'user_ids' + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) - if not to_check: + if not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + stream_id_where_clause = "stream_id > ?" + sql_args = [from_key] + + if to_key: + stream_id_where_clause += " AND stream_id <= ?" + sql_args.append(to_key) + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? + WHERE {stream_id_where_clause} AND """ - for chunk in batch_iter(to_check, 100): + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + txn.execute(sql + clause, sql_args + args) changes.update(user_id for user_id, in txn) return changes @@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? """ @@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + self, + user_id: str, + device_ids: Collection[str], + hosts: Optional[Collection[str]], + room_ids: Collection[str], ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: The ID of the user whose device changed. device_ids: The IDs of any changed devices. If empty, this function will return None. - hosts: The remote destinations that should be notified of the change. + hosts: The remote destinations that should be notified of the change. If + None then the set of hosts have *not* been calculated, and will be + calculated later by a background task. + room_ids: The rooms that the user is in Returns: The maximum stream ID of device list updates that were added to the database, or @@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return None - async with self._device_list_id_gen.get_next_mult( - len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_change_to_stream", - self._add_device_change_to_stream_txn, + context = get_active_span_text_map() + + def add_device_changes_txn( + txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes + ): + self._add_device_change_to_stream_txn( + txn, user_id, device_ids, - stream_ids, + stream_ids_for_device_change, ) - if not hosts: - return stream_ids[-1] + self._add_device_outbound_room_poke_txn( + txn, + user_id, + device_ids, + room_ids, + stream_ids_for_device_change, + context, + hosts_have_been_calculated=hosts is not None, + ) - context = get_active_span_text_map() - async with self._device_list_id_gen.get_next_mult( - len(hosts) * len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_outbound_poke_to_stream", - self._add_device_outbound_poke_to_stream_txn, + # If the set of hosts to send to has not been calculated yet (and so + # `hosts` is None) or there are no `hosts` to send to, then skip + # trying to persist them to the DB. + if not hosts: + return + + self._add_device_outbound_poke_to_stream_txn( + txn, user_id, device_ids, hosts, - stream_ids, + stream_ids_for_outbound_pokes, context, ) + # `device_lists_stream` wants a stream ID per device update. + num_stream_ids = len(device_ids) + + if hosts: + # `device_lists_outbound_pokes` wants a different stream ID for + # each row, which is a row per host per device update. + num_stream_ids += len(hosts) * len(device_ids) + + async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids: + stream_ids_for_device_change = stream_ids[: len(device_ids)] + stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :] + + await self.db_pool.runInteraction( + "add_device_change_to_stream", + add_device_changes_txn, + stream_ids_for_device_change, + stream_ids_for_outbound_pokes, + ) + return stream_ids[-1] def _add_device_change_to_stream_txn( @@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], hosts: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: for host in hosts: @@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) now = self._clock.time_msec() - next_stream_id = iter(stream_ids) + stream_id_iterator = iter(stream_ids) + encoded_context = json_encoder.encode(context) self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ ( destination, - next(next_stream_id), + next(stream_id_iterator), user_id, device_id, False, now, - json_encoder.encode(context) - if whitelisted_homeserver(destination) - else "{}", + encoded_context if whitelisted_homeserver(destination) else "{}", ) for destination in hosts for device_id in device_ids ], ) + + def _add_device_outbound_room_poke_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Iterable[str], + room_ids: Collection[str], + stream_ids: List[str], + context: Dict[str, str], + hosts_have_been_calculated: bool, + ) -> None: + """Record the user in the room has updated their device. + + Args: + hosts_have_been_calculated: True if `device_lists_outbound_pokes` + has been updated already with the updates. + """ + + # We only need to convert to outbound pokes if they are our user. + converted_to_destinations = ( + hosts_have_been_calculated or not self.hs.is_mine_id(user_id) + ) + + encoded_context = json_encoder.encode(context) + + # The `device_lists_changes_in_room.stream_id` column matches the + # corresponding `stream_id` of the update in the `device_lists_stream` + # table, i.e. all rows persisted for the same device update will have + # the same `stream_id` (but different room IDs). + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_changes_in_room", + keys=( + "user_id", + "device_id", + "room_id", + "stream_id", + "converted_to_destinations", + "opentracing_context", + ), + values=[ + ( + user_id, + device_id, + room_id, + stream_id, + converted_to_destinations, + encoded_context, + ) + for room_id in room_ids + for device_id, stream_id in zip(device_ids, stream_ids) + ], + ) + + async def get_uncoverted_outbound_room_pokes( + self, limit: int = 10 + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: + """Get device list changes by room that have not yet been handled and + written to `device_lists_outbound_pokes`. + + Returns: + A list of user ID, device ID, room ID, stream ID and optional opentracing context. + """ + + sql = """ + SELECT user_id, device_id, room_id, stream_id, opentracing_context + FROM device_lists_changes_in_room + WHERE NOT converted_to_destinations + ORDER BY stream_id + LIMIT ? + """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 59454a47df..a60e3f4fdd 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -22,7 +22,6 @@ from typing import ( Dict, Iterable, List, - NoReturn, Optional, Set, Tuple, @@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: - # this only exists for the benefit of the @cachedList descriptor on - # _have_seen_events_dict - raise NotImplementedError() + async def have_seen_event(self, room_id: str, event_id: str) -> bool: + res = await self._have_seen_events_dict(((room_id, event_id),)) + return res[(room_id, event_id)] def _get_current_state_event_counts_txn( self, txn: LoggingTransaction, room_id: str diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 216622964a..4f1c22c71b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,7 +15,6 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -36,7 +35,7 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 60 * 60 * 1000 -class MonthlyActiveUsersWorkerStore(SQLBaseStore): +class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): def __init__( self, database: DatabasePool, @@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + if hs.config.redis.redis_enabled: + # If we're using Redis, we can shift this update process off to + # the background worker + self._update_on_this_worker = hs.config.worker.run_background_tasks + else: + # If we're NOT using Redis, this must be handled by the master + self._update_on_this_worker = hs.get_instance_name() == "master" + self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau self._max_mau_value = hs.config.server.max_mau_value + self._mau_stats_only = hs.config.server.mau_stats_only + + if self._update_on_this_worker: + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], + ) + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): "reap_monthly_active_users", _reap_users, reserved_users ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self._mau_stats_only = hs.config.server.mau_stats_only - - # Do not add more reserved users than the total allowable number - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], - ) - def _initialise_reserved_users( self, txn: LoggingTransaction, threepids: List[dict] ) -> None: @@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS txn: threepids: List of threepid dicts to reserve """ + assert ( + self._update_on_this_worker + ), "This worker is not designated to update MAUs" # XXX what is this function trying to achieve? It upserts into # monthly_active_users for each *registered* reserved mau user, but why? @@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS Args: user_id: user to add/update """ + assert ( + self._update_on_this_worker + ), "This worker is not designated to update MAUs" + # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of # is_support_user which is complicated because I want to cache the result. @@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS txn (cursor): user_id (str): user to add/update """ + assert ( + self._update_on_this_worker + ), "This worker is not designated to update MAUs" # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple @@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS Args: user_id(str): the user_id to query """ + assert ( + self._update_on_this_worker + ), "This worker is not designated to update MAUs" + if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e6f97aeece..332e901dda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore): super().__init__(database, db_conn, hs) + max_receipts_stream_id = self.get_max_receipt_stream_id() + receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict( + db_conn, + "receipts_linearized", + entity_column="room_id", + stream_column="stream_id", + max_value=max_receipts_stream_id, + limit=10000, + ) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() + "ReceiptsRoomChangeCache", + min_receipts_stream_id, + prefilled_cache=receipts_stream_prefill, ) def get_max_receipt_stream_id(self) -> int: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7f3d190e94..c7634c92fd 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1745,6 +1745,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1887,18 +1899,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2295fd51f..64a7808140 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -26,8 +26,6 @@ from typing import ( cast, ) -import attr - from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -39,8 +37,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import RoomStreamToken, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -73,7 +70,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -90,8 +87,10 @@ class RelationsWorkerStore(SQLBaseStore): to_token: Fetch rows up to the given token, or up to the end if None. Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. + A tuple of: + A list of related event IDs + + The next stream token, if one exists. """ # We don't use `event_id`, it's there so that we can cache based on # it. The `event_id` must match the `event.event_id`. @@ -146,7 +145,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -156,7 +155,7 @@ class RelationsWorkerStore(SQLBaseStore): # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append({"event_id": row[0]}) + events.append(row[0]) last_topo_id = row[2] last_stream_id = row[3] @@ -179,9 +178,7 @@ class RelationsWorkerStore(SQLBaseStore): groups_key=0, ) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token - ) + return events[:limit], next_token return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn @@ -252,15 +249,8 @@ class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) async def get_aggregation_groups_for_event( - self, - event_id: str, - room_id: str, - event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[AggregationPaginationToken] = None, - to_token: Optional[AggregationPaginationToken] = None, - ) -> PaginationChunk: + self, event_id: str, room_id: str, limit: int = 5 + ) -> List[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -270,79 +260,36 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. - event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. - direction: Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: List of groups of annotations that match. Each row is a dict with `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [ + where_args = [ event_id, room_id, RelationTypes.ANNOTATION, + limit, ] - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender) FROM event_relations INNER JOIN events USING (event_id) - WHERE {where_clause} + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + ORDER BY COUNT(*) DESC LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) + """ def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) + ) -> List[JsonDict]: + txn.execute(sql, where_args) - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) + return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3248da5356..98d09b3736 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): return None async def get_rooms_for_local_user_where_membership_is( - self, user_id: str, membership_list: Collection[str] + self, + user_id: str, + membership_list: Collection[str], + excluded_rooms: Optional[List[str]] = None, ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id: The user ID. membership_list: A list of synapse.api.constants.Membership values which the user must be in. + excluded_rooms: A list of rooms to ignore. Returns: The RoomsForUser that the user matches the membership types. @@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership_list, ) - # Now we filter out forgotten rooms - forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] + # Now we filter out forgotten and excluded rooms + rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + + if excluded_rooms is not None: + rooms_to_exclude.update(set(excluded_rooms)) + + return [room for room in rooms if room.room_id not in rooms_to_exclude] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id: str, membership_list: List[str] + self, + txn, + user_id: str, + membership_list: List[str], ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 28460fd364..4a461a0abb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import logging -from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple + +from frozendict import frozendict from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -29,7 +30,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter -from synapse.types import JsonDict, StateMap +from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version - async def get_room_predecessor(self, room_id: str) -> Optional[dict]: + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. @@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, collections.abc.Mapping): + if not isinstance(predecessor, (dict, frozendict)): return None + # The keys must be strings since the data is JSON. return predecessor async def get_create_event_for_room(self, room_id: str) -> EventBase: @@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list_name="event_ids", num_args=1, ) - async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict: - """Returns mapping event_id -> state_group""" + async def _get_state_group_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, int]: + """Returns mapping event_id -> state_group. + + Raises: + RuntimeError if the state is unknown at any of the given events + """ rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", @@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="_get_state_group_for_events", ) - return {row["event_id"]: row["state_group"] for row in rows} + res = {row["event_id"]: row["state_group"] for row in rows} + for e in event_ids: + if e not in res: + raise RuntimeError("No state group for unknown or outlier event %s" % e) + return res async def get_referenced_state_groups( self, state_groups: Iterable[int] @@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): ) for user_id in potentially_left_users - joined_users: - await self.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined] return batch_size diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 39e1efe373..8e764790db 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,7 +36,7 @@ what sort order was used: """ import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple import attr from frozendict import frozendict @@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key async def get_membership_changes_for_user( - self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_rooms: Optional[List[str]] = None, ) -> List[EventBase]: """Fetch membership events for a given user. @@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() + args: List[Any] = [user_id, min_from_id, max_to_id] + + ignore_room_clause = "" + if excluded_rooms is not None and len(excluded_rooms) > 0: + ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join( + "?" for _ in excluded_rooms + ) + args = args + excluded_rooms + sql = """ SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? AND e.stream_ordering > ? AND e.stream_ordering <= ? + %s ORDER BY e.stream_ordering ASC - """ - txn.execute( - sql, - ( - user_id, - min_from_id, - max_to_id, - ), + """ % ( + ignore_room_clause, ) + txn.execute(sql, args) + rows = [ _EventDictReturn(event_id, None, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py deleted file mode 100644 index fba270150b..0000000000 --- a/synapse/storage/relations.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import attr - -from synapse.api.errors import SynapseError -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.storage.databases.main import DataStore - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, auto_attribs=True) -class PaginationChunk: - """Returned by relation pagination APIs. - - Attributes: - chunk: The rows returned by pagination - next_batch: Token to fetch next set of results with, if - None then there are no more results. - prev_batch: Token to fetch previous set of results with, if - None then there are no previous results. - """ - - chunk: List[JsonDict] - next_batch: Optional[Any] = None - prev_batch: Optional[Any] = None - - async def to_dict(self, store: "DataStore") -> Dict[str, Any]: - d = {"chunk": self.chunk} - - if self.next_batch: - d["next_batch"] = await self.next_batch.to_string(store) - - if self.prev_batch: - d["prev_batch"] = await self.prev_batch.to_string(store) - - return d - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class AggregationPaginationToken: - """Pagination token for relation aggregation pagination API. - - As the results are order by count and then MAX(stream_ordering) of the - aggregation groups, we can just use them as our pagination token. - - Attributes: - count: The count of relations in the boundary group. - stream: The MAX stream ordering in the boundary group. - """ - - count: int - stream: int - - @staticmethod - def from_string(string: str) -> "AggregationPaginationToken": - try: - c, s = string.split("-") - return AggregationPaginationToken(int(c), int(s)) - except ValueError: - raise SynapseError(400, "Invalid aggregation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.count, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 7b21c1b96d..151f2aa9bb 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 68 # remember to update the list below when updating +SCHEMA_VERSION = 69 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -58,6 +58,10 @@ Changes in SCHEMA_VERSION = 68: - event_reference_hashes is no longer read. - `events` has `state_key` and `rejection_reason` columns, which are populated for new events. + +Changes in SCHEMA_VERSION = 69: + - We now write to `device_lists_changes_in_room` table. + - Use sequence to generate future `application_services_txns.txn_id`s """ diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql new file mode 100644 index 0000000000..7590e34b94 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what device list changes stream id that this application +-- service has been caught up to. + +-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible +-- state is useful for determining if we've ever sent traffic for a stream type +-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for +-- one way this can be used. +ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py new file mode 100644 index 0000000000..24bd4b391e --- /dev/null +++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py @@ -0,0 +1,44 @@ +# Copyright 2022 Beeper +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Adds a postgres SEQUENCE for generating application service transaction IDs. +""" + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # If we already have some AS TXNs we want to start from the current + # maximum value. There are two potential places this is stored - the + # actual TXNs themselves *and* the AS state table. At time of migration + # it is possible the TXNs table is empty so we must include the AS state + # last_txn as a potential option, and pick the maximum. + + cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns") + row = cur.fetchone() + txn_max = row[0] + + cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state") + row = cur.fetchone() + last_txn_max = row[0] + + start_val = max(last_txn_max, txn_max) + 1 + + cur.execute( + "CREATE SEQUENCE application_services_txn_id_seq START WITH %s", + (start_val,), + ) diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql new file mode 100644 index 0000000000..b5b1782b2a --- /dev/null +++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql @@ -0,0 +1,38 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_lists_changes_in_room ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- This initially matches `device_lists_stream.stream_id`. Note that we + -- delete older values from `device_lists_stream`, so we can't use a foreign + -- constraint here. + -- + -- The table will contain rows with the same `stream_id` but different + -- `room_id`, as for each device update we store a row per room the user is + -- joined to. Therefore `(stream_id, room_id)` gives a unique index. + stream_id BIGINT NOT NULL, + + -- We have a background process which goes through this table and converts + -- entries into rows in `device_lists_outbound_pokes`. Once we have processed + -- a row, we mark it as such by setting `converted_to_destinations=TRUE`. + converted_to_destinations BOOLEAN NOT NULL, + opentracing_context TEXT +); + +CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id); +CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations; diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 86f1a5373b..cda194e8c8 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -571,6 +571,10 @@ class StateGroupStorage: Returns: dict of state_group_id -> (dict of (type, state_key) -> event id) + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) """ if not event_ids: return {} @@ -659,6 +663,10 @@ class StateGroupStorage: Returns: A dict of (event_id) -> (type, state_key) -> [state_events] + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) """ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) @@ -696,6 +704,10 @@ class StateGroupStorage: Returns: A dict from event_id -> (type, state_key) -> event_id + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) """ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) @@ -723,6 +735,10 @@ class StateGroupStorage: Returns: A dict from (type, state_key) -> state_event + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) """ state_map = await self.get_state_for_events( [event_id], state_filter or StateFilter.all() @@ -741,6 +757,10 @@ class StateGroupStorage: Returns: A dict from (type, state_key) -> state_event_id + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) """ state_map = await self.get_state_ids_for_events( [event_id], state_filter or StateFilter.all() diff --git a/synapse/types.py b/synapse/types.py index 5ce2a5b0a5..6bbefb6faa 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -25,6 +25,7 @@ from typing import ( Match, MutableMapping, Optional, + Set, Tuple, Type, TypeVar, @@ -421,22 +422,44 @@ class RoomStreamToken: s0 s1 | | - [0] V [1] V [2] + [0] ▼ [1] ▼ [2] Tokens can either be a point in the live event stream or a cursor going through historic events. - When traversing the live event stream events are ordered by when they - arrived at the homeserver. + When traversing the live event stream, events are ordered by + `stream_ordering` (when they arrived at the homeserver). - When traversing historic events the events are ordered by their depth in - the event graph "topological_ordering" and then by when they arrived at the - homeserver "stream_ordering". + When traversing historic events, events are first ordered by their `depth` + (`topological_ordering` in the event graph) and tie-broken by + `stream_ordering` (when the event arrived at the homeserver). - Live tokens start with an "s" followed by the "stream_ordering" id of the - event it comes after. Historic tokens start with a "t" followed by the - "topological_ordering" id of the event it comes after, followed by "-", - followed by the "stream_ordering" id of the event it comes after. + If you're looking for more info about what a token with all of the + underscores means, ex. + `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring + for `StreamToken` below. + + --- + + Live tokens start with an "s" followed by the `stream_ordering` of the event + that comes before the position of the token. Said another way: + `stream_ordering` uniquely identifies a persisted event. The live token + means "the position just after the event identified by `stream_ordering`". + An example token is: + + s2633508 + + --- + + Historic tokens start with a "t" followed by the `depth` + (`topological_ordering` in the event graph) of the event that comes before + the position of the token, followed by "-", followed by the + `stream_ordering` of the event that comes before the position of the token. + An example token is: + + t426-2633508 + + --- There is also a third mode for live tokens where the token starts with "m", which is sometimes used when using sharded event persisters. In this case @@ -463,6 +486,8 @@ class RoomStreamToken: Note: The `RoomStreamToken` cannot have both a topological part and an instance map. + --- + For caching purposes, `RoomStreamToken`s and by extension, all their attributes, must be hashable. """ @@ -599,7 +624,57 @@ class RoomStreamToken: @attr.s(slots=True, frozen=True, auto_attribs=True) class StreamToken: - """A collection of positions within multiple streams. + """A collection of keys joined together by underscores in the following + order and which represent the position in their respective streams. + + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + 1. `room_key`: `s2633508` which is a `RoomStreamToken` + - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` + - See the docstring for `RoomStreamToken` for more details. + 2. `presence_key`: `17` + 3. `typing_key`: `338` + 4. `receipt_key`: `6732159` + 5. `account_data_key`: `1082514` + 6. `push_rules_key`: `541479` + 7. `to_device_key`: `274711` + 8. `device_list_key`: `265584` + 9. `groups_key`: `1` + + You can see how many of these keys correspond to the various + fields in a "/sync" response: + ```json + { + "next_batch": "s12_4_0_1_1_1_1_4_1", + "presence": { + "events": [] + }, + "device_lists": { + "changed": [] + }, + "rooms": { + "join": { + "!QrZlfIDQLNLdZHqTnt:hs1": { + "timeline": { + "events": [], + "prev_batch": "s10_4_0_1_1_1_1_4_1", + "limited": false + }, + "state": { + "events": [] + }, + "account_data": { + "events": [] + }, + "ephemeral": { + "events": [] + } + } + } + } + } + ``` + + --- For caching purposes, `StreamToken`s and by extension, all their attributes, must be hashable. @@ -748,6 +823,30 @@ class ReadReceipt: data: JsonDict +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + def get_verify_key_from_cross_signing_key(key_info): """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 1cbc180eda..42f6abb5e1 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -17,7 +17,7 @@ import logging import typing from enum import Enum, auto from sys import intern -from typing import Any, Callable, Dict, List, Optional, Sized +from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar import attr from prometheus_client.core import Gauge @@ -195,8 +195,10 @@ KNOWN_KEYS = { ) } +T = TypeVar("T", Optional[str], str) -def intern_string(string: Optional[str]) -> Optional[str]: + +def intern_string(string: T) -> T: """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None diff --git a/synapse/visibility.py b/synapse/visibility.py index 49519eb8f5..250f073597 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright (C) The Matrix.org Foundation C.I.C. 2022 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, FrozenSet, List, Optional +from typing import Collection, Dict, FrozenSet, List, Optional, Tuple + +from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase @@ -40,6 +43,8 @@ MEMBERSHIP_PRIORITY = ( Membership.BAN, ) +_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "") + async def filter_events_for_client( storage: Storage, @@ -74,7 +79,7 @@ async def filter_events_for_client( # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] - types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) # we exclude outliers at this point, and then handle them separately later event_id_to_state = await storage.state.get_state_for_events( @@ -157,7 +162,7 @@ async def filter_events_for_client( state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + visibility_event = state.get(_HISTORY_VIS_KEY, None) if visibility_event: visibility = visibility_event.content.get( "history_visibility", HistoryVisibility.SHARED @@ -293,67 +298,28 @@ async def filter_events_for_server( return True return False - def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool: - history = state.get((EventTypes.RoomHistoryVisibility, ""), None) - if history: - visibility = history.content.get( - "history_visibility", HistoryVisibility.SHARED - ) - if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.values(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except Exception: - continue - - if domain != server_name: - continue - - memtype = ev.membership - if memtype == Membership.JOIN: - return True - elif memtype == Membership.INVITE: - if visibility == HistoryVisibility.INVITED: - return True - else: - # server has no users in the room: redact - return False - - return True - - # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If that's the case then we don't - # need to check membership (as we know the server is in the room). - event_to_state_ids = await storage.state.get_state_ids_for_events( - frozenset(e.event_id for e in events), - state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""),) - ), - ) - - visibility_ids = set() - for sids in event_to_state_ids.values(): - hist = sids.get((EventTypes.RoomHistoryVisibility, "")) - if hist: - visibility_ids.add(hist) + def check_event_is_visible( + visibility: str, memberships: StateMap[EventBase] + ) -> bool: + if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED): + return True - # If we failed to find any history visibility events then the default - # is "shared" visibility. - if not visibility_ids: - all_open = True - else: - event_map = await storage.main.get_events(visibility_ids) - all_open = all( - e.content.get("history_visibility") - in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) - for e in event_map.values() - ) + # We now loop through all membership events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in memberships.values(): + assert get_domain_from_id(ev.state_key) == server_name + + memtype = ev.membership + if memtype == Membership.JOIN: + return True + elif memtype == Membership.INVITE: + if visibility == HistoryVisibility.INVITED: + return True + + # server has no users in the room: redact + return False if not check_history_visibility_only: erased_senders = await storage.main.are_users_erased(e.sender for e in events) @@ -362,34 +328,100 @@ async def filter_events_for_server( # to no users having been erased. erased_senders = {} - if all_open: - # all the history_visibility state affecting these events is open, so - # we don't need to filter by membership state. We *do* need to check - # for user erasure, though. - if erased_senders: - to_return = [] - for e in events: - if not is_sender_erased(e, erased_senders): - to_return.append(e) - elif redact: - to_return.append(prune_event(e)) - - return to_return - - # If there are no erased users then we can just return the given list - # of events without having to copy it. - return events - - # Ok, so we're dealing with events that have non-trivial visibility - # rules, so we need to also get the memberships of the room. - - # first, for each event we're wanting to return, get the event_ids - # of the history vis and membership state at those events. + # Let's check to see if all the events have a history visibility + # of "shared" or "world_readable". If that's the case then we don't + # need to check membership (as we know the server is in the room). + event_to_history_vis = await _event_to_history_vis(storage, events) + + # for any with restricted vis, we also need the memberships + event_to_memberships = await _event_to_memberships( + storage, + [ + e + for e in events + if event_to_history_vis[e.event_id] + not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) + ], + server_name, + ) + + to_return = [] + for e in events: + erased = is_sender_erased(e, erased_senders) + visible = check_event_is_visible( + event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) + ) + if visible and not erased: + to_return.append(e) + elif redact: + to_return.append(prune_event(e)) + + return to_return + + +async def _event_to_history_vis( + storage: Storage, events: Collection[EventBase] +) -> Dict[str, str]: + """Get the history visibility at each of the given events + + Returns a map from event id to history_visibility setting + """ + + # outliers get special treatment here. We don't have the state at that point in the + # room (and attempting to look it up will raise an exception), so all we can really + # do is assume that the requesting server is allowed to see the event. That's + # equivalent to there not being a history_visibility event, so we just exclude + # any outliers from the query. + event_to_state_ids = await storage.state.get_state_ids_for_events( + frozenset(e.event_id for e in events if not e.internal_metadata.is_outlier()), + state_filter=StateFilter.from_types(types=(_HISTORY_VIS_KEY,)), + ) + + visibility_ids = { + vis_event_id + for vis_event_id in ( + state_ids.get(_HISTORY_VIS_KEY) for state_ids in event_to_state_ids.values() + ) + if vis_event_id + } + vis_events = await storage.main.get_events(visibility_ids) + + result: Dict[str, str] = {} + for event in events: + vis = HistoryVisibility.SHARED + state_ids = event_to_state_ids.get(event.event_id) + + # if we didn't find any state for this event, it's an outlier, and we assume + # it's open + visibility_id = None + if state_ids: + visibility_id = state_ids.get(_HISTORY_VIS_KEY) + + if visibility_id: + vis_event = vis_events[visibility_id] + vis = vis_event.content.get("history_visibility", HistoryVisibility.SHARED) + assert isinstance(vis, str) + + result[event.event_id] = vis + return result + + +async def _event_to_memberships( + storage: Storage, events: Collection[EventBase], server_name: str +) -> Dict[str, StateMap[EventBase]]: + """Get the remote membership list at each of the given events + + Returns a map from event id to state map, which will contain only membership events + for the given server. + """ + + if not events: + return {} + + # for each event, get the event_ids of the membership state at those events. event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), - state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) - ), + state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)), ) # We only want to pull out member events that correspond to the @@ -405,10 +437,7 @@ async def filter_events_for_server( for key, event_id in key_to_eid.items() } - def include(typ, state_key): - if typ != EventTypes.Member: - return True - + def include(state_key: str) -> bool: # we avoid using get_domain_from_id here for efficiency. idx = state_key.find(":") if idx == -1: @@ -416,10 +445,14 @@ async def filter_events_for_server( return state_key[idx + 1 :] == server_name event_map = await storage.main.get_events( - [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] + [ + e_id + for e_id, (_, state_key) in event_id_to_state_key.items() + if include(state_key) + ] ) - event_to_state = { + return { e_id: { key: event_map[inner_e_id] for key, inner_e_id in key_to_eid.items() @@ -427,14 +460,3 @@ async def filter_events_for_server( } for e_id, key_to_eid in event_to_state_ids.items() } - - to_return = [] - for e in events: - erased = is_sender_erased(e, erased_senders) - visible = check_event_is_visible(e, event_to_state[e.event_id]) - if visible and not erased: - to_return.append(e) - elif redact: - to_return.append(prune_event(e)) - - return to_return |