Diffstat (limited to 'synapse/handlers')
-rw-r--r--   synapse/handlers/account_data.py       |  52
-rw-r--r--   synapse/handlers/appservice.py         | 148
-rw-r--r--   synapse/handlers/auth.py               |   3
-rw-r--r--   synapse/handlers/device.py             | 132
-rw-r--r--   synapse/handlers/federation_event.py   |  49
-rw-r--r--   synapse/handlers/presence.py           |   2
-rw-r--r--   synapse/handlers/relations.py          |  41
-rw-r--r--   synapse/handlers/room.py               |   4
-rw-r--r--   synapse/handlers/room_batch.py         |  38
-rw-r--r--   synapse/handlers/sync.py               |  63
10 files changed, 428 insertions, 104 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 177b4f8991..4af9fbc5d1 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
-from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
 
 from synapse.replication.http.account_data import (
     ReplicationAddTagRestServlet,
@@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
+ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
+    [str, Optional[str], str, JsonDict], Awaitable
+]
+
 
 class AccountDataHandler:
     def __init__(self, hs: "HomeServer"):
@@ -40,6 +47,44 @@ class AccountDataHandler:
         self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
         self._account_data_writers = hs.config.worker.writers.account_data
 
+        self._on_account_data_updated_callbacks: List[
+            ON_ACCOUNT_DATA_UPDATED_CALLBACK
+        ] = []
+
+    def register_module_callbacks(
+        self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
+    ) -> None:
+        """Register callbacks from modules."""
+        if on_account_data_updated is not None:
+            self._on_account_data_updated_callbacks.append(on_account_data_updated)
+
+    async def _notify_modules(
+        self,
+        user_id: str,
+        room_id: Optional[str],
+        account_data_type: str,
+        content: JsonDict,
+    ) -> None:
+        """Notifies modules about new account data changes.
+
+        A change can be either a new account data type being added, or the content
+        associated with a type being changed. Account data for a given type is
+        removed by changing the associated content to an empty dictionary.
+
+        Note that this is not called when the tags associated with a room change.
+
+        Args:
+            user_id: The user whose account data is changing.
+            room_id: The ID of the room the account data change concerns, if any.
+            account_data_type: The type of the account data.
+            content: The content that is now associated with this type.
+        """
+        for callback in self._on_account_data_updated_callbacks:
+            try:
+                await callback(user_id, room_id, account_data_type, content)
+            except Exception as e:
+                logger.exception("Failed to run module callback %s: %s", callback, e)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:
@@ -63,6 +108,8 @@ class AccountDataHandler:
                 "account_data_key", max_stream_id, users=[user_id]
             )
 
+            await self._notify_modules(user_id, room_id, account_data_type, content)
+
             return max_stream_id
         else:
             response = await self._room_data_client(
@@ -96,6 +143,9 @@ class AccountDataHandler:
             self._notifier.on_new_event(
                 "account_data_key", max_stream_id, users=[user_id]
             )
+
+            await self._notify_modules(user_id, None, account_data_type, content)
+
             return max_stream_id
         else:
             response = await self._user_data_client(
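For context, the handler above only dispatches the callbacks; a homeserver module consumes them roughly as follows. A minimal sketch, assuming a `register_account_data_callbacks` entry point on `ModuleApi` (the module-API wiring is not part of this diff) and a hypothetical `AccountDataLogger` module:

# A module consuming on_account_data_updated. The registration method
# name is an assumption; only the handler-side dispatch appears above.
from typing import Optional

from synapse.module_api import ModuleApi
from synapse.types import JsonDict


class AccountDataLogger:  # hypothetical example module
    def __init__(self, config: dict, api: ModuleApi):
        api.register_account_data_callbacks(  # assumed entry point
            on_account_data_updated=self.on_account_data_updated,
        )

    async def on_account_data_updated(
        self,
        user_id: str,
        room_id: Optional[str],
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        # Per the docstring above, an empty dict means the account data of
        # this type was removed rather than updated.
        action = "removed" if not content else "updated"
        print(f"{user_id} {action} {account_data_type} (room={room_id})")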
+ """ + for callback in self._on_account_data_updated_callbacks: + try: + await callback(user_id, room_id, account_data_type, content) + except Exception as e: + logger.exception("Failed to run module callback %s: %s", callback, e) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -63,6 +108,8 @@ class AccountDataHandler: "account_data_key", max_stream_id, users=[user_id] ) + await self._notify_modules(user_id, room_id, account_data_type, content) + return max_stream_id else: response = await self._room_data_client( @@ -96,6 +143,9 @@ class AccountDataHandler: self._notifier.on_new_event( "account_data_key", max_stream_id, users=[user_id] ) + + await self._notify_modules(user_id, None, account_data_type, content) + return max_stream_id else: response = await self._user_data_client( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bd913e524e..316c4b677c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import ( + DeviceListUpdates, + JsonDict, + RoomAlias, + RoomStreamToken, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -58,6 +64,9 @@ class ApplicationServicesHandler: self._msc2409_to_device_messages_enabled = ( hs.config.experimental.msc2409_to_device_messages_enabled ) + self._msc3202_transaction_extensions_enabled = ( + hs.config.experimental.msc3202_transaction_extensions + ) self.current_max = 0 self.is_processing = False @@ -204,9 +213,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key" or - "to_device_key". Any other value for `stream_key` will cause this function - to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value for `stream_key` + will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -230,6 +239,7 @@ class ApplicationServicesHandler: "receipt_key", "presence_key", "to_device_key", + "device_list_key", ): return @@ -253,15 +263,37 @@ class ApplicationServicesHandler: ): return + # Ignore device lists if the feature flag is not enabled + if ( + stream_key == "device_list_key" + and not self._msc3202_transaction_extensions_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # # Note that whether these events are actually relevant to these appservices # is decided later on. 
+        services = self.store.get_app_services()
         services = [
             service
-            for service in self.store.get_app_services()
-            if service.supports_ephemeral
+            for service in services
+            # Different stream keys require different support booleans
+            if (
+                stream_key
+                in (
+                    "typing_key",
+                    "receipt_key",
+                    "presence_key",
+                    "to_device_key",
+                )
+                and service.supports_ephemeral
+            )
+            or (
+                stream_key == "device_list_key"
+                and service.msc3202_transaction_extensions
+            )
         ]
         if not services:
             # Bail out early if none of the target appservices have explicitly registered
@@ -336,6 +368,20 @@ class ApplicationServicesHandler:
                         service, "to_device", new_token
                     )
 
+                elif stream_key == "device_list_key":
+                    device_list_summary = await self._get_device_list_summary(
+                        service, new_token
+                    )
+                    if device_list_summary:
+                        self.scheduler.enqueue_for_appservice(
+                            service, device_list_summary=device_list_summary
+                        )
+
+                    # Persist the latest handled stream token for this appservice
+                    await self.store.set_appservice_stream_type_pos(
+                        service, "device_list", new_token
+                    )
+
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
@@ -542,6 +588,96 @@ class ApplicationServicesHandler:
 
         return message_payload
 
+    async def _get_device_list_summary(
+        self,
+        appservice: ApplicationService,
+        new_key: int,
+    ) -> DeviceListUpdates:
+        """
+        Retrieve a list of users who have changed their device lists.
+
+        Args:
+            appservice: The application service to retrieve device list changes for.
+            new_key: The stream key of the device list change that triggered this
+                method call.
+
+        Returns:
+            A set of device list updates, comprised of users that the appservice
+            needs to:
+              * resync the device list of, and
+              * stop tracking the device list of.
+        """
+        # Fetch the last successfully processed device list update stream ID
+        # for this appservice.
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            appservice, "device_list"
+        )
+
+        # Fetch the users who have modified their device list since then.
+        users_with_changed_device_lists = (
+            await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+        )
+
+        # Filter out any users the application service is not interested in.
+        #
+        # For each user who changed their device list, we want to check whether this
+        # appservice would be interested in the change.
+        filtered_users_with_changed_device_lists = {
+            user_id
+            for user_id in users_with_changed_device_lists
+            if await self._is_appservice_interested_in_device_lists_of_user(
+                appservice, user_id
+            )
+        }
+
+        # Create a summary of "changed" and "left" users.
+        # TODO: Calculate "left" users.
+        device_list_summary = DeviceListUpdates(
+            changed=filtered_users_with_changed_device_lists
+        )
+
+        return device_list_summary
+
+    async def _is_appservice_interested_in_device_lists_of_user(
+        self,
+        appservice: ApplicationService,
+        user_id: str,
+    ) -> bool:
+        """
+        Returns whether a given application service is interested in the device list
+        updates of a given user.
+
+        The application service is interested in the user's device list updates if any
+        of the following are true:
+            * The user is the appservice's sender localpart user.
+            * The user is in the appservice's user namespace.
+            * At least one member of one room that the user is a part of is in the
+              appservice's user namespace.
+            * The appservice is explicitly (via room ID or alias) interested in at
+              least one room that the user is in.
+
+        Args:
+            appservice: The application service to gauge interest of.
+            user_id: The ID of the user whose device list interest is in question.
+
+        Returns:
+            True if the application service is interested in the user's device lists,
+            False otherwise.
+        """
+        # This method checks against both the sender localpart user as well as if the
+        # user is in the appservice's user namespace.
+        if appservice.is_interested_in_user(user_id):
+            return True
+
+        # Determine whether any of the rooms the user is in justifies sending this
+        # device list update to the application service.
+        room_ids = await self.store.get_rooms_for_user(user_id)
+        for room_id in room_ids:
+            # This method covers checking room members for appservice interest as
+            # well as room ID and alias checks.
+            if await appservice.is_interested_in_room(room_id, self.store):
+                return True
+
+        return False
+
     async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
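The list comprehension above mixes two opt-in mechanisms in one predicate. Restated as a standalone helper (a sketch for clarity, not code from the diff):

EPHEMERAL_STREAM_KEYS = ("typing_key", "receipt_key", "presence_key", "to_device_key")


def service_wants_stream(service, stream_key: str) -> bool:
    # Ephemeral streams require the appservice to have opted in via
    # `push_ephemeral` (surfaced as `supports_ephemeral`), while device
    # list updates require the separate MSC3202 opt-in flag.
    if stream_key in EPHEMERAL_STREAM_KEYS:
        return bool(service.supports_ephemeral)
    if stream_key == "device_list_key":
        return bool(service.msc3202_transaction_extensions)
    return False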
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3e29c96a49..86991d26ce 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -211,6 +211,7 @@ class AuthHandler:
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
+        self._third_party_rules = hs.get_third_party_event_rules()
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
@@ -1505,6 +1506,8 @@ class AuthHandler:
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
+        await self._third_party_rules.on_threepid_bind(user_id, medium, address)
+
     async def delete_threepid(
         self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
     ) -> bool:
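A module would implement the new `on_threepid_bind` hook roughly as follows. A sketch, assuming registration goes through the existing third-party rules callback mechanism (the `ThreepidAuditor` module is hypothetical):

from synapse.module_api import ModuleApi


class ThreepidAuditor:  # hypothetical example module
    def __init__(self, config: dict, api: ModuleApi):
        api.register_third_party_rules_callbacks(
            on_threepid_bind=self.on_threepid_bind,
        )

    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
        # Invoked after a threepid (e.g. medium="email") has been validated
        # and bound to a local account.
        print(f"{user_id} bound {medium}: {address}")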
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
 
-        set_tag("target_hosts", hosts)
+            set_tag("target_hosts", hosts)
+
+            hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(
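The legacy branch above computes the federation fan-out synchronously: every remote server sharing a room with the user gets poked. Distilled into a standalone helper (a sketch; `store` stands in for the data store used above):

from typing import Set

from synapse.types import get_domain_from_id


async def hosts_sharing_a_room_with(
    store, user_id: str, my_server_name: str
) -> Set[str]:
    # Walk every room the user is in and collect the domains of all joined
    # members; this is the work the new code path defers to the background
    # loop in the next hunk.
    hosts: Set[str] = set()
    for room_id in await store.get_rooms_for_user(user_id):
        for member in await store.get_users_in_room(room_id):
            hosts.add(get_domain_from_id(member))
    hosts.discard(my_server_name)  # no point poking ourselves
    return hosts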
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
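The `is_processing`/`new_data` flag pair above is a general single-worker wake-up pattern. Stripped of the device-list specifics, it looks like this (a self-contained sketch with an in-memory list standing in for the database table):

import asyncio

pending: list = []     # stands in for the table of unconverted pokes
is_processing = False  # is the worker currently draining?
new_data = False       # did work arrive while the worker was busy?


async def handle_update() -> None:
    global is_processing, new_data
    if is_processing:
        # A worker is already running: flag it to loop once more rather
        # than starting a second, racing worker.
        new_data = True
        return
    is_processing = True
    try:
        while True:
            new_data = False
            await asyncio.sleep(0)  # the real code awaits a DB query here
            batch, pending[:] = list(pending), []
            if not batch:
                if new_data:
                    continue  # something arrived during the "query"
                return
            for _item in batch:
                await asyncio.sleep(0)  # per-row work goes here
    finally:
        is_processing = False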
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 4bd87709f3..e7b9f15e13 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -469,6 +469,12 @@ class FederationEventHandler:
         if context.rejected:
             raise SynapseError(400, "Join event was rejected")
 
+        # the remote server is responsible for sending our join event to the rest
+        # of the federation. Indeed, attempting to do so will result in problems
+        # when we try to look up the state before the join (to get the server list)
+        # and discover that we do not have it.
+        event.internal_metadata.proactively_send = False
+
         return await self.persist_events_and_notify(room_id, [(event, context)])
 
     async def backfill(
@@ -891,10 +897,24 @@ class FederationEventHandler:
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
         missing_events = missing_desired_events | missing_auth_events
-        logger.debug("Fetching %i events from remote", len(missing_events))
-        await self._get_events_and_persist(
-            destination=destination, room_id=room_id, event_ids=missing_events
-        )
+
+        # Making an individual request for each of 1000s of events has a lot of
+        # overhead. On the other hand, we don't really want to fetch all of the events
+        # if we already have most of them.
+        #
+        # As an arbitrary heuristic, if we are missing more than 10% of the events,
+        # then we fetch the whole state.
+        #
+        # TODO: might it be better to have an API which lets us do an aggregate event
+        # request
+        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+            logger.debug("Requesting complete state from remote")
+            await self._get_state_and_persist(destination, room_id, event_id)
+        else:
+            logger.debug("Fetching %i events from remote", len(missing_events))
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=missing_events
+            )
 
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
@@ -953,6 +973,27 @@ class FederationEventHandler:
 
         return remote_state
 
+    async def _get_state_and_persist(
+        self, destination: str, room_id: str, event_id: str
+    ) -> None:
+        """Get the complete room state at a given event, and persist any new events
+        as outliers"""
+        room_version = await self._store.get_room_version(room_id)
+        auth_events, state_events = await self._federation_client.get_room_state(
+            destination, room_id, event_id=event_id, room_version=room_version
+        )
+        logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state_events)
+        )
+
+        # we also need the event itself.
+        if not await self._store.have_seen_event(room_id, event_id):
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=(event_id,)
+            )
+
     async def _process_received_pdu(
         self,
         origin: str,
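The 10% threshold above is easier to see with numbers. A worked example (the figures are invented for illustration):

# Suppose the state at the event comprises 1000 auth events and 1000 state
# events, of which 250 are missing locally. 250 * 10 = 2500 >= 2000, so
# more than 10% is missing and one /state request beats 250 /event fetches.
auth_event_ids = [f"$auth{i}" for i in range(1000)]
state_event_ids = [f"$state{i}" for i in range(1000)]
missing_events = {f"$state{i}" for i in range(250)}

fetch_full_state = (len(missing_events) * 10) >= (
    len(auth_event_ids) + len(state_event_ids)
)
assert fetch_full_state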
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 34d9411bbf..dace31d87e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # We'll actually pull the presence updates for these users at the end.
                 interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
 
-                if from_key:
+                if from_key is not None:
                     # First get all users that have had a presence update
                     updated_users = stream_change_cache.get_all_entities_changed(from_key)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 73217d135d..a36936b520 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
 import attr
 from frozendict import frozendict
@@ -25,7 +25,6 @@ from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +115,7 @@ class RelationsHandler:
         if event is None:
             raise SynapseError(404, "Unknown parent event.")
 
-        pagination_chunk = await self._main_store.get_relations_for_event(
+        related_events, next_token = await self._main_store.get_relations_for_event(
             event_id=event_id,
             event=event,
             room_id=room_id,
@@ -129,9 +128,7 @@ class RelationsHandler:
             to_token=to_token,
         )
 
-        events = await self._main_store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
+        events = await self._main_store.get_events_as_list(related_events)
 
         events = await filter_events_for_client(
             self._storage, user_id, events, is_peeking=(member_event_id is None)
@@ -152,9 +149,16 @@ class RelationsHandler:
             events, now, bundle_aggregations=aggregations
         )
 
-        return_value = await pagination_chunk.to_dict(self._main_store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
+        return_value = {
+            "chunk": serialized_events,
+            "original_event": original_event,
+        }
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
 
         return return_value
 
@@ -193,16 +197,21 @@ class RelationsHandler:
         annotations = await self._main_store.get_aggregation_groups_for_event(
             event_id, room_id
         )
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
+        if annotations:
+            aggregations.annotations = {"chunk": annotations}
 
-        references = await self._main_store.get_relations_for_event(
+        references, next_token = await self._main_store.get_relations_for_event(
             event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
         )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
+        if references:
+            aggregations.references = {
+                "chunk": [{"event_id": event_id} for event_id in references]
+            }
+
+            if next_token:
+                aggregations.references["next_batch"] = await next_token.to_string(
+                    self._main_store
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
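The one-line presence fix above guards against a classic integer-truthiness bug: a stream position of 0 is a legitimate token but is falsy, so `if from_key:` wrongly treated it as an initial sync. In isolation:

from typing import Optional


def is_incremental_sync(from_key: Optional[int]) -> bool:
    # Only None means "no token"; 0 is a valid stream position.
    return from_key is not None


assert is_incremental_sync(0)          # correct with the `is not None` test
assert not is_incremental_sync(None)
assert not bool(0)                     # why the old `if from_key:` misfired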
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 092e185c99..51a08fd2c0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -771,7 +771,9 @@ class RoomCreationHandler:
                 % (user_id,),
             )
 
-        visibility = config.get("visibility", None)
+        # The spec says rooms should default to private visibility if
+        # `visibility` is not specified.
+        visibility = config.get("visibility", "private")
         is_public = visibility == "public"
 
         room_id = await self._generate_room_id(
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index a0255bd143..78e299d3a5 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -156,8 +156,8 @@ class RoomBatchHandler:
     ) -> List[str]:
         """Takes all `state_events_at_start` event dictionaries and creates/persists
         them in a floating state event chain which don't resolve into the current room
-        state. They are floating because they reference no prev_events and are marked
-        as outliers which disconnects them from the normal DAG.
+        state. They are floating because they reference no prev_events which
+        disconnects them from the normal DAG.
 
         Args:
             state_events_at_start:
@@ -213,31 +213,23 @@ class RoomBatchHandler:
                     room_id=room_id,
                     action=membership,
                     content=event_dict["content"],
-                    # Mark as an outlier to disconnect it from the normal DAG
-                    # and not show up between batches of history.
-                    outlier=True,
                     historical=True,
                     # Only the first event in the state chain should be floating.
                     # The rest should hang off each other in a chain.
                     allow_no_prev_events=index == 0,
                     prev_event_ids=prev_event_ids_for_state_chain,
-                    # Since each state event is marked as an outlier, the
-                    # `EventContext.for_outlier()` won't have any `state_ids`
-                    # set and therefore can't derive any state even though the
-                    # prev_events are set. Also since the first event in the
-                    # state chain is floating with no `prev_events`, it can't
-                    # derive state from anywhere automatically. So we need to
-                    # set some state explicitly.
+                    # The first event in the state chain is floating with no
+                    # `prev_events` which means it can't derive state from
+                    # anywhere automatically. So we need to set some state
+                    # explicitly.
                     #
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
-                    # reference and also update in the event when we append later.
+                    # reference and also update in the event when we append
+                    # later.
                     state_event_ids=state_event_ids.copy(),
                 )
             else:
-                # TODO: Add some complement tests that adds state that is not member joins
-                # and will use this code path. Maybe we only want to support join state events
-                # and can get rid of this `else`?
                 (
                     event,
                     _,
@@ -246,21 +238,15 @@ class RoomBatchHandler:
                 ) = await self.event_creation_handler.create_and_send_nonmember_event(
                     await self.create_requester_for_user_id_from_app_service(
                         state_event["sender"], app_service_requester.app_service
                     ),
                     event_dict,
-                    # Mark as an outlier to disconnect it from the normal DAG
-                    # and not show up between batches of history.
-                    outlier=True,
                     historical=True,
                     # Only the first event in the state chain should be floating.
                     # The rest should hang off each other in a chain.
                     allow_no_prev_events=index == 0,
                     prev_event_ids=prev_event_ids_for_state_chain,
-                    # Since each state event is marked as an outlier, the
-                    # `EventContext.for_outlier()` won't have any `state_ids`
-                    # set and therefore can't derive any state even though the
-                    # prev_events are set. Also since the first event in the
-                    # state chain is floating with no `prev_events`, it can't
-                    # derive state from anywhere automatically. So we need to
-                    # set some state explicitly.
+                    # The first event in the state chain is floating with no
+                    # `prev_events` which means it can't derive state from
+                    # anywhere automatically. So we need to set some state
+                    # explicitly.
                     #
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
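The batch-insertion hunks above drop the outlier flag but keep the floating chain: only the first state event may lack prev_events, and each later one hangs off its predecessor. The bookkeeping, reduced to stubs (the event IDs are invented for illustration):

from typing import List

state_event_ids: List[str] = []
prev_event_ids_for_state_chain: List[str] = []

for index, event_id in enumerate(["$create-stub", "$member-stub", "$power-stub"]):
    allow_no_prev_events = index == 0  # only the first event may float
    # The real code calls the event-creation handler here with
    # allow_no_prev_events, prev_event_ids=prev_event_ids_for_state_chain
    # and state_event_ids=state_event_ids.copy(); passing a copy matters
    # because both lists are mutated below.
    state_event_ids.append(event_id)
    prev_event_ids_for_state_chain = [event_id]  # next event hangs off this one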
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6c569cfb1c..303c38c746 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,17 +13,7 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Collection,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Set,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     MutableStateMap,
     Requester,
@@ -184,21 +175,6 @@ class GroupsSyncResult:
         return bool(self.join or self.invite or self.leave)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class DeviceLists:
-    """
-    Attributes:
-        changed: List of user_ids whose devices may have changed
-        left: List of user_ids whose devices we no longer track
-    """
-
-    changed: Collection[str]
-    left: Collection[str]
-
-    def __bool__(self) -> bool:
-        return bool(self.changed or self.left)
-
-
 @attr.s(slots=True, auto_attribs=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
@@ -240,7 +216,7 @@ class SyncResult:
     knocked: List[KnockedSyncResult]
     archived: List[ArchivedSyncResult]
     to_device: List[JsonDict]
-    device_lists: DeviceLists
+    device_lists: DeviceListUpdates
     device_one_time_keys_count: JsonDict
     device_unused_fallback_key_types: List[str]
     groups: Optional[GroupsSyncResult]
@@ -298,6 +274,8 @@ class SyncHandler:
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
+        self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -1262,8 +1240,8 @@ class SyncHandler:
         newly_joined_or_invited_or_knocked_users: Set[str],
         newly_left_rooms: Set[str],
         newly_left_users: Set[str],
-    ) -> DeviceLists:
-        """Generate the DeviceLists section of sync
+    ) -> DeviceListUpdates:
+        """Generate the DeviceListUpdates section of sync
 
         Args:
             sync_result_builder
@@ -1381,9 +1359,11 @@ class SyncHandler:
                 if any(e.room_id in joined_rooms for e in entries):
                     newly_left_users.discard(user_id)
 
-            return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
+            return DeviceListUpdates(
+                changed=users_that_have_changed, left=newly_left_users
+            )
         else:
-            return DeviceLists(changed=[], left=[])
+            return DeviceListUpdates()
 
     async def _generate_sync_entry_for_to_device(
         self, sync_result_builder: "SyncResultBuilder"
@@ -1607,13 +1587,15 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users
+                sync_result_builder, ignored_users, self.rooms_to_exclude
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+            room_changes = await self._get_all_rooms(
+                sync_result_builder, ignored_users, self.rooms_to_exclude
+            )
 
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1689,7 +1671,10 @@ class SyncHandler:
         return False
 
     async def _get_rooms_changed(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
@@ -1721,7 +1706,7 @@ class SyncHandler:
         # _have_rooms_changed. We could keep the results in memory to avoid a
         # second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key
+            user_id, since_token.room_key, now_token.room_key, excluded_rooms
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1922,7 +1907,10 @@ class SyncHandler:
         )
 
     async def _get_all_rooms(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1933,7 +1921,7 @@ class SyncHandler:
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
-
+            ignored_rooms: List of rooms to ignore.
         """
         user_id = sync_result_builder.sync_config.user.to_string()
 
@@ -1944,6 +1932,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
+            excluded_rooms=ignored_rooms,
        )
 
         room_entries = []
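Finally, the sync changes above thread a server-wide exclusion list through both the initial and incremental paths. Its effect, reduced to a pure function (a sketch; the storage-layer filtering itself is not shown in this diff, and the homeserver config option backing `rooms_to_exclude_from_sync` is assumed):

from typing import Iterable, List


def filter_excluded_rooms(room_ids: Iterable[str], excluded: List[str]) -> List[str]:
    # Rooms on the server-wide exclusion list never appear in a user's sync
    # response, whether on initial or incremental sync.
    excluded_set = set(excluded)  # O(1) membership tests
    return [room_id for room_id in room_ids if room_id not in excluded_set]


assert filter_excluded_rooms(
    ["!keep:example.org", "!noisy:example.org"], ["!noisy:example.org"]
) == ["!keep:example.org"]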