From 2173785f0d9124037ca841b568349ad0424b39cd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Nov 2019 11:31:56 +0000 Subject: Propagate reason in remotely rejected invites --- synapse/handlers/room_member_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse/handlers/room_member_worker.py') diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 75e96ae1a2..69be86893b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -55,7 +55,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ return self._remote_reject_client( @@ -63,6 +65,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), + content=content, ) def _user_joined_room(self, target, room_id): -- cgit 1.4.1 From e9f3de0baba9be63b77fdaff996274e0abed8ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 May 2020 09:32:13 -0400 Subject: Update the room member handler to use async/await. (#7507) --- changelog.d/7507.misc | 1 + synapse/handlers/room_member.py | 111 +++++++++++++++------------------ synapse/handlers/room_member_worker.py | 21 +++---- 3 files changed, 59 insertions(+), 74 deletions(-) create mode 100644 changelog.d/7507.misc (limited to 'synapse/handlers/room_member_worker.py') diff --git a/changelog.d/7507.misc b/changelog.d/7507.misc new file mode 100644 index 0000000000..afc7a730b3 --- /dev/null +++ b/changelog.d/7507.misc @@ -0,0 +1 @@ +Convert the room member handler to async/await. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ccc9659454..4ddeba4c97 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -20,8 +20,6 @@ import logging from six.moves import http_client -from twisted.internet import defer - from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError @@ -76,7 +74,7 @@ class RoomMemberHandler(object): self.base_handler = BaseHandler(hs) @abc.abstractmethod - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Try and join a room that this server is not in Args: @@ -94,7 +92,7 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - def _remote_reject_invite( + async def _remote_reject_invite( self, requester, remote_room_hosts, room_id, target, content ): """Attempt to reject an invite for a room this server is not in. If we @@ -115,7 +113,7 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target, room_id): """Notifies distributor on master process that the user has joined the room. @@ -124,12 +122,12 @@ class RoomMemberHandler(object): room_id (str) Returns: - Deferred|None + None """ raise NotImplementedError() @abc.abstractmethod - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target, room_id): """Notifies distributor on master process that the user has left the room. @@ -138,7 +136,7 @@ class RoomMemberHandler(object): room_id (str) Returns: - Deferred|None + None """ raise NotImplementedError() @@ -214,8 +212,9 @@ class RoomMemberHandler(object): return event - @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): + async def copy_room_tags_and_direct_to_room( + self, old_room_id, new_room_id, user_id + ): """Copies the tags and direct room state from one room to another. Args: @@ -227,7 +226,7 @@ class RoomMemberHandler(object): Deferred[None] """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user(user_id) + user_account_data, _ = await self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -240,17 +239,17 @@ class RoomMemberHandler(object): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - yield self.store.add_account_data_for_user( + await self.store.add_account_data_for_user( user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) + room_tags = await self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) async def update_membership( self, @@ -487,8 +486,7 @@ class RoomMemberHandler(object): ) return res - @defer.inlineCallbacks - def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): """Upon our server becoming aware of an upgraded room, either by upgrading a room ourselves or joining one, we can transfer over information from the previous room. @@ -506,30 +504,29 @@ class RoomMemberHandler(object): logger.info("Transferring room state from %s to %s", old_room_id, room_id) # Find all local users that were in the old room and copy over each user's state - users = yield self.store.get_users_in_room(old_room_id) - yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + users = await self.store.get_users_in_room(old_room_id) + await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) # Add new room to the room directory if the old room was there # Remove old room from the room directory - old_room = yield self.store.get_room(old_room_id) + old_room = await self.store.get_room(old_room_id) if old_room and old_room["is_public"]: - yield self.store.set_room_is_public(old_room_id, False) - yield self.store.set_room_is_public(room_id, True) + await self.store.set_room_is_public(old_room_id, False) + await self.store.set_room_is_public(room_id, True) # Transfer alias mappings in the room directory - yield self.store.update_aliases_for_room(old_room_id, room_id) + await self.store.update_aliases_for_room(old_room_id, room_id) # Check if any groups we own contain the predecessor room - local_group_ids = yield self.store.get_local_groups_for_room(old_room_id) + local_group_ids = await self.store.get_local_groups_for_room(old_room_id) for group_id in local_group_ids: # Add new the new room to those groups - yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) + await self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) # Remove the old room from those groups - yield self.store.remove_room_from_group(group_id, old_room_id) + await self.store.remove_room_from_group(group_id, old_room_id) - @defer.inlineCallbacks - def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): """Copy user-specific information when they join a new room when that new room is the result of a room upgrade @@ -552,11 +549,11 @@ class RoomMemberHandler(object): for user_id in user_ids: try: # It is an upgraded room. Copy over old tags - yield self.copy_room_tags_and_direct_to_room( + await self.copy_room_tags_and_direct_to_room( old_room_id, new_room_id, user_id ) # Copy over push rules - yield self.store.copy_push_rules_from_room_to_room_for_user( + await self.store.copy_push_rules_from_room_to_room_for_user( old_room_id, new_room_id, user_id ) except Exception: @@ -639,8 +636,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - @defer.inlineCallbacks - def _can_guest_join(self, current_state_ids): + async def _can_guest_join(self, current_state_ids): """ Returns whether a guest can join a room based on its current state. """ @@ -648,7 +644,7 @@ class RoomMemberHandler(object): if not guest_access_id: return False - guest_access = yield self.store.get_event(guest_access_id) + guest_access = await self.store.get_event(guest_access_id) return ( guest_access @@ -657,8 +653,7 @@ class RoomMemberHandler(object): and guest_access.content["guest_access"] == "can_join" ) - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): + async def lookup_room_alias(self, room_alias): """ Get the room ID associated with a room alias. @@ -672,7 +667,7 @@ class RoomMemberHandler(object): SynapseError if room alias could not be found. """ directory_handler = self.directory_handler - mapping = yield directory_handler.get_association(room_alias) + mapping = await directory_handler.get_association(room_alias) if not mapping: raise SynapseError(404, "No such room alias") @@ -687,9 +682,8 @@ class RoomMemberHandler(object): return RoomID.from_string(room_id), servers - @defer.inlineCallbacks - def _get_inviter(self, user_id, room_id): - invite = yield self.store.get_invite_for_local_user_in_room( + async def _get_inviter(self, user_id, room_id): + invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: @@ -836,8 +830,7 @@ class RoomMemberHandler(object): txn_id=txn_id, ) - @defer.inlineCallbacks - def _is_host_in_room(self, current_state_ids): + async def _is_host_in_room(self, current_state_ids): # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) @@ -850,7 +843,7 @@ class RoomMemberHandler(object): continue event_id = current_state_ids[(etype, state_key)] - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event: continue @@ -859,11 +852,10 @@ class RoomMemberHandler(object): return False - @defer.inlineCallbacks - def _is_server_notice_room(self, room_id): + async def _is_server_notice_room(self, room_id): if self._server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self._server_notices_mxid in user_ids @@ -895,8 +887,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity return None - @defer.inlineCallbacks - def _is_local_room_too_complex(self, room_id): + async def _is_local_room_too_complex(self, room_id): """ Check if the complexity of a local room is too great. @@ -906,7 +897,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Returns: bool """ max_complexity = self.hs.config.limit_remote_rooms.complexity - complexity = yield self.store.get_room_complexity(room_id) + complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity @@ -969,18 +960,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) - @defer.inlineCallbacks - def _remote_reject_invite( + async def _remote_reject_invite( self, requester, remote_room_hosts, room_id, target, content ): """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = yield defer.ensureDeferred( - fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string(), content=content, - ) + ret = await fed_handler.do_remotely_reject_invite( + remote_room_hosts, room_id, target.to_string(), content=content, ) return ret except Exception as e: @@ -992,24 +980,23 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(target.to_string(), room_id) + await self.store.locally_reject_invite(target.to_string(), room_id) return {} - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return defer.succeed(user_joined_room(self.distributor, target, room_id)) + return user_joined_room(self.distributor, target, room_id) - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return defer.succeed(user_left_room(self.distributor, target, room_id)) + return user_left_room(self.distributor, target, room_id) - @defer.inlineCallbacks - def forget(self, user, room_id): + async def forget(self, user, room_id): user_id = user.to_string() - member = yield self.state_handler.get_current_state( + member = await self.state_handler.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None @@ -1021,4 +1008,4 @@ class RoomMemberMasterHandler(RoomMemberHandler): raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: - yield self.store.forget(user_id, room_id) + await self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 69be86893b..0fc54349ab 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler from synapse.replication.http.membership import ( @@ -36,14 +34,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler): self._remote_reject_client = ReplRejectInvite.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - ret = yield self._remote_join_client( + ret = await self._remote_join_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, @@ -51,16 +48,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) return ret - def _remote_reject_invite( + async def _remote_reject_invite( self, requester, remote_room_hosts, room_id, target, content ): """Implements RoomMemberHandler._remote_reject_invite """ - return self._remote_reject_client( + return await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, @@ -68,16 +65,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return self._notify_change_client( + return await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined" ) - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return self._notify_change_client( + return await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) -- cgit 1.4.1 From c29915bd05513a329e099d7e2970768113595830 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 May 2020 15:05:25 -0400 Subject: Add type hints to room member handlers (#7513) --- changelog.d/7513.misc | 1 + synapse/handlers/room_member.py | 284 ++++++++++++++++++--------------- synapse/handlers/room_member_worker.py | 28 +++- tox.ini | 2 + 4 files changed, 176 insertions(+), 139 deletions(-) create mode 100644 changelog.d/7513.misc (limited to 'synapse/handlers/room_member_worker.py') diff --git a/changelog.d/7513.misc b/changelog.d/7513.misc new file mode 100644 index 0000000000..2ea7373e29 --- /dev/null +++ b/changelog.d/7513.misc @@ -0,0 +1 @@ +Add type hints to room member handler. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4ddeba4c97..e51e1c32fe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,13 +17,16 @@ import abc import logging +from typing import Dict, Iterable, List, Optional, Tuple, Union from six.moves import http_client from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.types import Collection, RoomID, UserID +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -74,84 +77,84 @@ class RoomMemberHandler(object): self.base_handler = BaseHandler(hs) @abc.abstractmethod - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Try and join a room that this server is not in Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers that can be used - to join via. - room_id (str): Room that we are trying to join - user (UserID): User who is trying to join - content (dict): A dict that should be used as the content of the - join event. - - Returns: - Deferred + requester + remote_room_hosts: List of servers that can be used to join via. + room_id: Room that we are trying to join + user: User who is trying to join + content: A dict that should be used as the content of the join event. """ raise NotImplementedError() @abc.abstractmethod async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers to use to try and - reject invite - room_id (str) - target (UserID): The user rejecting the invite - content (dict): The content for the rejection event + requester + remote_room_hosts: List of servers to use to try and reject invite + room_id + target: The user rejecting the invite + content: The content for the rejection event Returns: - Deferred[dict]: A dictionary to be returned to the client, may + A dictionary to be returned to the client, may include event_id etc, or nothing if we locally rejected """ raise NotImplementedError() @abc.abstractmethod - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the room. Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() @abc.abstractmethod - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() async def _local_membership_update( self, - requester, - target, - room_id, - membership, + requester: Requester, + target: UserID, + room_id: str, + membership: str, prev_event_ids: Collection[str], - txn_id=None, - ratelimit=True, - content=None, - require_consent=True, - ): + txn_id: Optional[str] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> EventBase: user_id = target.to_string() if content is None: @@ -214,16 +217,13 @@ class RoomMemberHandler(object): async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id - ): + ) -> None: """Copies the tags and direct room state from one room to another. Args: - old_room_id (str) - new_room_id (str) - user_id (str) - - Returns: - Deferred[None] + old_room_id: The room ID of the old room. + new_room_id: The room ID of the new room. + user_id: The user's ID. """ # Retrieve user account data for predecessor room user_account_data, _ = await self.store.get_account_data_for_user(user_id) @@ -253,17 +253,17 @@ class RoomMemberHandler(object): async def update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -284,17 +284,17 @@ class RoomMemberHandler(object): async def _update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: content_specified = bool(content) if content is None: content = {} @@ -468,12 +468,11 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = await self._remote_reject_invite( + return await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) - return res - res = await self._local_membership_update( + return await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -484,9 +483,10 @@ class RoomMemberHandler(object): content=content, require_consent=require_consent, ) - return res - async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + async def transfer_room_state_on_room_upgrade( + self, old_room_id: str, room_id: str + ) -> None: """Upon our server becoming aware of an upgraded room, either by upgrading a room ourselves or joining one, we can transfer over information from the previous room. @@ -494,12 +494,8 @@ class RoomMemberHandler(object): well as migrating the room directory state. Args: - old_room_id (str): The ID of the old room - - room_id (str): The ID of the new room - - Returns: - Deferred + old_room_id: The ID of the old room + room_id: The ID of the new room """ logger.info("Transferring room state from %s to %s", old_room_id, room_id) @@ -526,17 +522,16 @@ class RoomMemberHandler(object): # Remove the old room from those groups await self.store.remove_room_from_group(group_id, old_room_id) - async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + async def copy_user_state_on_room_upgrade( + self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] + ) -> None: """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - old_room_id (str): The ID of upgraded room - new_room_id (str): The ID of the new room - user_ids (Iterable[str]): User IDs to copy state for - - Returns: - Deferred + old_room_id: The ID of upgraded room + new_room_id: The ID of the new room + user_ids: User IDs to copy state for """ logger.debug( @@ -566,17 +561,23 @@ class RoomMemberHandler(object): ) continue - async def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event( + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + ): """ Change the membership status of a user in a room. Args: - requester (Requester): The local user who requested the membership + requester: The local user who requested the membership event. If None, certain checks, like whether this homeserver can act as the sender, will be skipped. - event (SynapseEvent): The membership event. + event: The membership event. context: The context of the event. - ratelimit (bool): Whether to rate limit this request. + ratelimit: Whether to rate limit this request. Raises: SynapseError if there was a problem changing the membership. """ @@ -636,7 +637,9 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids): + async def _can_guest_join( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -653,12 +656,14 @@ class RoomMemberHandler(object): and guest_access.content["guest_access"] == "can_join" ) - async def lookup_room_alias(self, room_alias): + async def lookup_room_alias( + self, room_alias: RoomAlias + ) -> Tuple[RoomID, List[str]]: """ Get the room ID associated with a room alias. Args: - room_alias (RoomAlias): The alias to look up. + room_alias: The alias to look up. Returns: A tuple of: The room ID as a RoomID object. @@ -682,24 +687,25 @@ class RoomMemberHandler(object): return RoomID.from_string(room_id), servers - async def _get_inviter(self, user_id, room_id): + async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]: invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: return UserID.from_string(invite.sender) + return None async def do_3pid_invite( self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id, - id_access_token=None, - ): + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -748,15 +754,15 @@ class RoomMemberHandler(object): async def _make_and_store_3pid_invite( self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id, - id_access_token=None, - ): + requester: Requester, + id_server: str, + medium: str, + address: str, + room_id: str, + user: UserID, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -830,7 +836,9 @@ class RoomMemberHandler(object): txn_id=txn_id, ) - async def _is_host_in_room(self, current_state_ids): + async def _is_host_in_room( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) @@ -852,7 +860,7 @@ class RoomMemberHandler(object): return False - async def _is_server_notice_room(self, room_id): + async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False user_ids = await self.store.get_users_in_room(room_id) @@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") - async def _is_remote_room_too_complex(self, room_id, remote_room_hosts): + async def _is_remote_room_too_complex( + self, room_id: str, remote_room_hosts: List[str] + ) -> Optional[bool]: """ Check if complexity of a remote room is too great. Args: - room_id (str) - remote_room_hosts (list[str]) + room_id + remote_room_hosts Returns: bool of whether the complexity is too great, or None if unable to be fetched @@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity return None - async def _is_local_room_too_complex(self, room_id): + async def _is_local_room_too_complex(self, room_id: str) -> bool: """ Check if the complexity of a local room is too great. Args: - room_id (str) - - Returns: bool + room_id: The room ID to check for complexity. """ max_complexity = self.hs.config.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> None: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler @@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler): await self.store.locally_reject_invite(target.to_string(), room_id) return {} - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + user_joined_room(self.distributor, target, room_id) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + user_left_room(self.distributor, target, room_id) - async def forget(self, user, room_id): + async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() member = await self.state_handler.get_current_state( diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 0fc54349ab..5c776cc0be 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -22,6 +23,7 @@ from synapse.replication.http.membership import ( ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) +from synapse.types import Requester, UserID logger = logging.getLogger(__name__) @@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): self._remote_reject_client = ReplRejectInvite.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ return await self._remote_reject_client( @@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined" ) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/tox.ini b/tox.ini index a69bc04334..5a1fa610b6 100644 --- a/tox.ini +++ b/tox.ini @@ -188,6 +188,8 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ + synapse/handlers/room_member.py \ + synapse/handlers/room_member_worker.py \ synapse/handlers/saml_handler.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.4.1 From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- changelog.d/7542.misc | 1 + synapse/handlers/federation.py | 33 ++++++--- synapse/handlers/message.py | 36 ++++++--- synapse/handlers/room.py | 65 +++++++++++----- synapse/handlers/room_member.py | 65 ++++++++++------ synapse/handlers/room_member_worker.py | 11 +-- synapse/replication/http/federation.py | 13 +++- synapse/replication/http/membership.py | 14 ++-- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 10 ++- synapse/rest/client/v1/room.py | 20 +++-- synapse/rest/client/v2_alpha/relations.py | 2 +- synapse/server.py | 5 ++ synapse/server.pyi | 5 ++ synapse/server_notices/server_notices_manager.py | 6 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/roommember.py | 2 + tests/federation/test_complexity.py | 8 +- tests/handlers/test_typing.py | 5 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_event_metrics.py | 2 +- tests/test_federation.py | 4 +- 24 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7542.misc (limited to 'synapse/handlers/room_member_worker.py') diff --git a/changelog.d/7542.misc b/changelog.d/7542.misc new file mode 100644 index 0000000000..7dd9b4823b --- /dev/null +++ b/changelog.d/7542.misc @@ -0,0 +1 @@ +Add ability to wait for replication streams. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bb03cc9add..e354c803db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,6 +126,7 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._replication = hs.get_replication_data_handler() self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( hs @@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> None: + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler): room_id=room_id, room_version=room_version_obj, ) - await self._persist_auth_tree( + max_stream_id = await self._persist_auth_tree( origin, auth_chain, state, event, room_version_obj ) + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + "master", "events", max_stream_id + ) + # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = await self.store.get_room_predecessor(room_id) if not predecessor or not isinstance(predecessor.get("room_id"), str): - return + return event.event_id, max_stream_id old_room_id = predecessor["room_id"] logger.debug( "Found predecessor for %s during remote join: %s", room_id, old_room_id @@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler): ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler): async def do_remotely_reject_invite( self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict - ) -> EventBase: + ) -> Tuple[EventBase, int]: origin, event, room_version = await self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) @@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(target_hosts, event) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify([(event, context)]) - return event + return event, stream_id async def _make_and_verify_event( self, @@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> None: + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler): self, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> None: + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler): backfilling or not """ if self.config.worker_app: - await self._send_events_to_master( + result = await self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled @@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler): for event, _ in event_and_contexts: await self._notify_persisted_event(event, max_stream_id) + return max_stream_id + async def _notify_persisted_event( self, event: EventBase, max_stream_id: int ) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f362896a2..f445e2aa2a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -42,6 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -630,7 +631,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -639,6 +642,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -659,7 +665,7 @@ class EventCreationHandler(object): ) return prev_state - await self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -688,7 +694,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -711,10 +717,10 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - await self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -774,7 +780,7 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -787,6 +793,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -827,7 +836,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - await self.send_event_to_master( + result = await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -836,14 +845,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - await self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -886,7 +898,7 @@ class EventCreationHandler(object): async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1076,6 +1088,8 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + return event_stream_id + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 13850ba672..2698a129ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,6 +22,7 @@ import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types @@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler): async def create_room( self, requester, config, ratelimit=True, creator_join_profile=None - ): + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - await self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - await self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - await self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - return result + return result, last_stream_id async def _send_events_for_new_room( self, @@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler): room_alias=None, power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. + """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - await send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - await send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - await send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) + + return last_sent_stream_id async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e51e1c32fe..691b6705b2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client @@ -84,7 +84,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: @@ -104,7 +104,7 @@ class RoomMemberHandler(object): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -154,7 +154,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> EventBase: + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -187,9 +187,10 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - await self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) @@ -213,7 +214,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -263,7 +264,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -294,7 +295,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -398,7 +399,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -705,7 +712,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -737,11 +744,11 @@ class RoomMemberHandler(object): ) if invitee: - await self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - await self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -752,6 +759,8 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) + return stream_id + async def _make_and_store_3pid_invite( self, requester: Requester, @@ -762,7 +771,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -817,7 +826,10 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -835,6 +847,7 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id async def _is_host_in_room( self, current_state_ids: Dict[Tuple[str, str], str] @@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> None: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) await self._user_joined_room(user, room_id) @@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) @@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) + return event_id, stream_id + async def _remote_reject_invite( self, requester: Requester, @@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = await fed_handler.do_remotely_reject_invite( + event, stream_id = await fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.store.locally_reject_invite( + target.to_string(), room_id + ) + return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 5c776cc0be..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] async def _remote_reject_invite( self, @@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return await self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), content=content, ) + return ret["event_id"], ret["stream_id"] async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7d40001988..0a13e1ed34 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = await self._room_creation_handler.create_room( + info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet): # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position("master", "events", stream_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6b5830cc3f..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = await self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id ret = {} # type: dict - if event: - set_tag("event_id", event.event_id) - ret = {"event_id": event.event_id} + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 63f07b63da..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) diff --git a/synapse/server.py b/synapse/server.py index c530f1aa1a..ca2deb49bb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -210,6 +211,7 @@ class HomeServer(object): "storage", "replication_streamer", "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -583,6 +585,9 @@ class HomeServer(object): def build_replication_data_handler(self): return ReplicationDataHandler(self) + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9e7fad7e6e..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +from typing import Dict + import twisted.internet import synapse.api.auth @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property @@ -136,3 +139,5 @@ class HomeServer(object): pass def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 999c621b92..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -83,10 +83,10 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = await self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event @cached() async def get_or_create_notice_room_for_user(self, user_id): @@ -143,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9130b74eb5..b880a71782 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore): async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): + def get_event_ordering(self, event_id): res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 1e9c850152..7c5ca81ae0 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) + return stream_ordering + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 94980733c4..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, @@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 51e2b37218..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): reactor.pump((1000,)) hs = self.setup_test_homeserver( - notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, + replication_streams={}, ) hs.datastores = datastores diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/test_federation.py b/tests/test_federation.py index 13ff14863e..c5099dd039 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = ensureDeferred( + room_deferred = ensureDeferred( room_creator.create_room( our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] self.store = self.homeserver.get_datastore() -- cgit 1.4.1