diff options
author | Richard van der Hoff <richard@matrix.org> | 2020-05-06 11:43:11 +0100 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2020-05-06 11:43:11 +0100 |
commit | 30a19daa02cf61d2a811efb0fd619ff48ce6bcf6 (patch) | |
tree | c430aa536e76a2329831d04ae7d8d761ed2ec188 /synapse | |
parent | use an upsert to update device_lists_outbound_last_success (diff) | |
parent | Fix typing annotations in synapse/federation (#7382) (diff) | |
download | synapse-30a19daa02cf61d2a811efb0fd619ff48ce6bcf6.tar.xz |
Merge branch 'develop' into rav/upsert_for_device_list
Diffstat (limited to 'synapse')
46 files changed, 794 insertions, 651 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c1ade1333b..c5d1eb952b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,8 +537,7 @@ class Auth(object): return defer.succeed(auth_ids) - @defer.inlineCallbacks - def check_can_change_room_list(self, room_id: str, user: UserID): + async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the published room list. @@ -547,17 +546,17 @@ class Auth(object): user """ - is_admin = yield self.is_server_admin(user) + is_admin = await self.is_server_admin(user) if is_admin: return True user_id = user.to_string() - yield self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = yield self.state.get_current_state( + power_level_event = await self.state.get_current_state( room_id, EventTypes.PowerLevels, "" ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0ace7b787d..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -413,12 +413,6 @@ class GenericWorkerTyping(object): # map room IDs to sets of users currently typing self._room_typing = {} - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - def process_replication_rows(self, token, rows): if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just @@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - def get_streams_to_replicate(self): - args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - if self.send_handler: - args.update(self.send_handler.stream_positions()) - return args - - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( @@ -799,9 +784,6 @@ class FederationSenderHandler(object): def wake_destination(self, server: str): self.federation_sender.wake_destination(server) - def stream_positions(self): - return {"federation": self.federation_position} - async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a0071fec94..687cd841ac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -883,18 +883,37 @@ class FederationClient(FederationBase): def get_public_rooms( self, - destination, - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): - if destination == self.server_name: - return + """Get the list of public rooms from a remote homeserver + + Args: + remote_server: The name of the remote server + limit: Maximum amount of rooms to return + since_token: Used for result pagination + search_filter: A filter dictionary to send the remote homeserver + and filter the result set + include_all_networks: Whether to include results from all third party instances + third_party_instance_id: Whether to only include results from a specific third + party instance + + Returns: + Deferred[Dict[str, Any]]: The response from the remote server, or None if + `remote_server` is the same as the local server_name + Raises: + HttpResponseException: There was an exception returned from the remote server + SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom + requests over federation + + """ return self.transport_layer.get_public_rooms( - destination, + remote_server, limit, since_token, search_filter, @@ -957,14 +976,13 @@ class FederationClient(FederationBase): return signed_events - @defer.inlineCallbacks - def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: if destination == self.server_name: continue try: - yield self.transport_layer.exchange_third_party_invite( + await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) return None diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e1700ca8aa..52f4f54215 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,6 +31,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple +from typing import Dict, List, Tuple, Type from six import iteritems @@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = SortedDict() # Stream position -> list[user_id] + # Pending presence map user_id -> UserPresenceState + self.presence_map = {} # type: Dict[str, UserPresenceState] + + # Stream position -> list[user_id] + self.presence_changed = SortedDict() # type: SortedDict[int, List[str]] # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) - self.presence_destinations = SortedDict() + self.presence_destinations = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, List[str]]] + + # (destination, key) -> EDU + self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] - self.keyed_edu = {} # (destination, key) -> EDU - self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) + # stream position -> (destination, key) + self.keyed_edu_changed = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, tuple]] - self.edus = SortedDict() # stream position -> Edu + self.edus = SortedDict() # type: SortedDict[int, Edu] + # stream ID for the next entry into presence_changed/keyed_edu_changed/edus. self.pos = 1 - self.pos_time = SortedDict() + + # map from stream ID to the time that stream entry was generated, so that we + # can clear out entries after a while + self.pos_time = SortedDict() # type: SortedDict[int, int] # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object): for edu_key in self.keyed_edu_changed.values(): live_keys.add(edu_key) - to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] - for edu_key in to_del: + keys_to_del = [ + edu_key for edu_key in self.keyed_edu if edu_key not in live_keys + ] + for edu_key in keys_to_del: del self.keyed_edu[edu_key] # Delete things out of edu map @@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object): self._clear_queue_before_pos(token) async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: """Get rows to be sent over federation between the two tokens Args: - from_token (int) - to_token(int) - limit (int) - federation_ack (int): Optional. The position where the worker is - explicitly acknowledged it has handled. Allows us to drop - data from before that point + instance_name: the name of the current process + from_token: the previous stream token: the starting point for fetching the + updates + to_token: the new stream token: the point to get updates up to + target_row_count: a target for the number of rows to be returned. + + Returns: a triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of `(token, row)` entries. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. """ - # TODO: Handle limit. + # TODO: Handle target_row_count. # To handle restarts where we wrap around if from_token > self.pos: @@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object): # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. - rows = [] - - # There should be only one reader, so lets delete everything its - # acknowledged its seen. - if federation_ack: - self._clear_queue_before_pos(federation_ack) + rows = [] # type: List[Tuple[int, BaseFederationRow]] # Fetch changed presence i = self.presence_changed.bisect_right(from_token) @@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object): # Sort rows based on pos rows.sort() - return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + return ( + [(pos, (row.TypeId, row.to_data())) for pos, row in rows], + to_token, + False, + ) class BaseFederationRow(object): @@ -341,7 +361,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. @staticmethod def from_data(data): @@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu buff.edus.setdefault(self.edu.destination, []).append(self.edu) -TypeToRow = { - Row.TypeId: Row - for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,) -} +_rowtypes = ( + PresenceRow, + PresenceDestinationsRow, + KeyedEduRow, + EduRow, +) # type: Tuple[Type[BaseFederationRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _rowtypes} ParsedFederationStreamData = namedtuple( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a477578e44..d473576902 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Hashable, Iterable, List, Optional, Set +from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from six import itervalues @@ -498,14 +498,16 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self) -> int: + @staticmethod + def get_current_token() -> int: # Dummy implementation for case where federation sender isn't offloaded # to a worker. return 0 + @staticmethod async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: # Dummy implementation for case where federation sender isn't offloaded # to a worker. - return [] + return [], 0, False diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index e13cd20ffa..276a2b596f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,10 @@ # limitations under the License. import datetime import logging -from typing import Dict, Hashable, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +if TYPE_CHECKING: + import synapse.server + # This is defined in the Matrix spec and enforced by the receiver. MAX_EDUS_PER_TRANSACTION = 100 diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 3c2a02a3b3..a2752a54a5 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import TYPE_CHECKING, List from canonicaljson import json -import synapse.server from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -31,6 +30,9 @@ from synapse.logging.opentracing import ( ) from synapse.util.metrics import measure_func +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 383e3fdc8b..060bf07197 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,13 +15,14 @@ # limitations under the License. import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership +from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( FEDERATION_UNSTABLE_PREFIX, FEDERATION_V1_PREFIX, @@ -326,18 +327,25 @@ class TransportLayerClient(object): @log_function def get_public_rooms( self, - remote_server, - limit, - since_token, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): + """Get the list of public rooms from a remote homeserver + + See synapse.federation.federation_client.FederationClient.get_public_rooms for + more information. + """ if search_filter: # this uses MSC2197 (Search Filtering over Federation) path = _create_v1_path("/publicRooms") - data = {"include_all_networks": "true" if include_all_networks else "false"} + data = { + "include_all_networks": "true" if include_all_networks else "false" + } # type: Dict[str, Any] if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -347,9 +355,19 @@ class TransportLayerClient(object): data["filter"] = search_filter - response = yield self.client.post_json( - destination=remote_server, path=path, data=data, ignore_backoff=True - ) + try: + response = yield self.client.post_json( + destination=remote_server, path=path, data=data, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise else: path = _create_v1_path("/publicRooms") @@ -363,9 +381,19 @@ class TransportLayerClient(object): if since_token: args["since"] = [since_token] - response = yield self.client.get_json( - destination=remote_server, path=path, args=args, ignore_backoff=True - ) + try: + response = yield self.client.get_json( + destination=remote_server, path=path, args=args, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise return response diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4f0dc0a209..4acb4fa489 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise NotImplementedError() - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from the group; either a user is leaving or an admin kicked them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: - is_admin = yield self.store.is_user_admin_in_group( + is_admin = await self.store.is_user_admin_in_group( group_id, requester_user_id ) if not is_admin: @@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_kick = True - yield self.store.remove_user_from_group(group_id, user_id) + await self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) if not self.hs.is_mine_id(user_id): - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # Delete group if the last user has left - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) if not users: - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) return {} - @defer.inlineCallbacks - def create_group(self, group_id, requester_user_id, content): - group = yield self.check_group_is_ours(group_id, requester_user_id) + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) if not is_admin: @@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): long_description = profile.get("long_description") user_profile = content.get("user_profile", {}) - yield self.store.create_group( + await self.store.create_group( group_id, requester_user_id, name=name, @@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=requester_user_id, group_id=group_id ) @@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation = None remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, requester_user_id, is_admin=True, @@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) if not self.hs.is_mine_id(requester_user_id): - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"group_id": group_id} - @defer.inlineCallbacks - def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id, requester_user_id): """Deletes a group, kicking out all current members. Only group admins or server admins can call this request @@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): Deferred """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) @@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) - @defer.inlineCallbacks - def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # We kick users out in the order of: # 1. Non-admins @@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler): else: non_admins.append(u["user_id"]) - yield concurrently_execute(_kick_user_from_group, non_admins, 10) - yield concurrently_execute(_kick_user_from_group, admins, 10) - yield _kick_user_from_group(requester_user_id) + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) def _parse_join_policy_from_contents(content): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 51413d910e..3b781d9836 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -126,30 +126,28 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)) ) - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, context=None): + async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids() - current_state = yield self.store.get_events( + current_state_ids = await context.get_current_state_ids() + current_state = await self.store.get_events( list(current_state_ids.values()) ) else: - current_state = yield self.state_handler.get_current_state( + current_state = await self.state_handler.get_current_state( event.room_id ) current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) - yield self.kick_guest_users(current_state) + await self.kick_guest_users(current_state) - @defer.inlineCallbacks - def kick_guest_users(self, current_state): + async def kick_guest_users(self, current_state): for member_event in current_state: try: if member_event.type != EventTypes.Member: @@ -180,7 +178,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() - yield handler.update_membership( + await handler.update_membership( requester, target_user, member_event.room_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 53e5f585d9..f2f16b1e43 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler): room_alias, room_id, servers, creator=creator ) - @defer.inlineCallbacks - def create_association( + async def create_association( self, requester: Requester, room_alias: RoomAlias, @@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler): else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) if (self.require_membership and check_membership) and not is_admin: - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( 403, "You must be in the room to create an alias for it" @@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias(room_alias, user_id=user_id) + can_create = await self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) + await self._create_association(room_alias, room_id, servers, creator=user_id) - @defer.inlineCallbacks - def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() try: - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) + can_delete = await self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - room_id = yield self._delete_association(room_alias) + room_id = await self._delete_association(room_alias) try: - yield self._update_canonical_alias(requester, user_id, room_id, room_alias) + await self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - @defer.inlineCallbacks - def _update_canonical_alias( + async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. """ - alias_event = yield self.state.get_current_state( + alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) @@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler): del content["alt_aliases"] if send_update: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return defer.succeed(True) - @defer.inlineCallbacks - def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. One of the following must be true: @@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler): for the current room. """ - creator = yield self.store.get_room_alias_creator(alias.to_string()) + creator = await self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True # Resolve the alias to the corresponding room. - room_mapping = yield self.get_association(alias) + room_mapping = await self.get_association(alias) room_id = room_mapping["room_id"] if not room_id: return False - res = yield self.auth.check_can_change_room_list( + res = await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) return res - @defer.inlineCallbacks - def edit_published_room_list( + async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str ): """Edit the entry of the room in the published room list. @@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to publish rooms to the room list" ) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Unknown room") - can_change_room_list = yield self.auth.check_can_change_room_list( + can_change_room_list = await self.auth.check_can_change_room_list( room_id, requester.user ) if not can_change_room_list: @@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler): making_public = visibility == "public" if making_public: - room_aliases = yield self.store.get_aliases_for_room(room_id) - canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) + room_aliases = await self.store.get_aliases_for_room(room_id) + canonical_alias = await self.store.get_canonical_alias_for_room(room_id) if canonical_alias: room_aliases.append(canonical_alias) @@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") - yield self.store.set_room_is_public(room_id, making_public) + await self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks def edit_published_appservice_room_list( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 41b96c0a73..4e5c645525 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler): "missing": [e.event_id for e in missing_locals], } - @defer.inlineCallbacks @log_function - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): third_party_invite = {"signed": signed} @@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version_id(room_id) + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - yield self.federation_client.forward_third_party_invite( + await self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ad22415782..ca5c83811a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): + async def create_group(self, group_id, user_id, content): """Create a group """ logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - res = yield self.groups_server_handler.create_group( + res = await self.groups_server_handler.create_group( group_id, user_id, content ) local_attestation = None @@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = self.attestations.create_attestation(group_id, user_id) content["attestation"] = local_attestation - content["user_profile"] = yield self.profile_handler.get_profile(user_id) + content["user_profile"] = await self.profile_handler.get_profile(user_id) try: - res = yield self.transport_client.create_group( + res = await self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): ) is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from a group """ if user_id == requester_user_id: - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) @@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # retry if the group server is currently down. if self.is_mine_id(group_id): - res = yield self.groups_server_handler.remove_user_from_group( + res = await self.groups_server_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id try: - res = yield self.transport_client.remove_user_from_group( + res = await self.transport_client.remove_user_from_group( get_domain_from_id(group_id), group_id, requester_user_id, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 522271eed1..a324f09340 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -626,8 +626,7 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - @defer.inlineCallbacks - def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -647,7 +646,7 @@ class EventCreationHandler(object): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = yield self.deduplicate_state_event(event, context) + prev_state = await self.deduplicate_state_event(event, context) if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", @@ -656,7 +655,7 @@ class EventCreationHandler(object): ) return prev_state - yield self.handle_new_client_event( + await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -683,8 +682,7 @@ class EventCreationHandler(object): return prev_event return - @defer.inlineCallbacks - def create_and_send_nonmember_event( + async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None ): """ @@ -698,8 +696,8 @@ class EventCreationHandler(object): # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (yield self.limiter.queue(event_dict["room_id"])): - event, context = yield self.create_event( + with (await self.limiter.queue(event_dict["room_id"])): + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -709,7 +707,7 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) return event @@ -770,8 +768,7 @@ class EventCreationHandler(object): return (event, context) @measure_func("handle_new_client_event") - @defer.inlineCallbacks - def handle_new_client_event( + async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, @@ -794,9 +791,9 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -805,7 +802,7 @@ class EventCreationHandler(object): ) try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -818,7 +815,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event(event, context) + await self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -826,7 +823,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - yield self.send_event_to_master( + await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -838,7 +835,7 @@ class EventCreationHandler(object): success = True return - yield self.persist_and_notify_client_event( + await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) @@ -883,8 +880,7 @@ class EventCreationHandler(object): Codes.BAD_ALIAS, ) - @defer.inlineCallbacks - def persist_and_notify_client_event( + async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already @@ -901,7 +897,7 @@ class EventCreationHandler(object): # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -913,11 +909,11 @@ class EventCreationHandler(object): original_event and event.sender != original_event.sender ) - yield self.base_handler.ratelimit( + await self.base_handler.ratelimit( requester, is_admin_redaction=is_admin_redaction ) - yield self.base_handler.maybe_kick_guest_users(event, context) + await self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -927,7 +923,7 @@ class EventCreationHandler(object): original_event_id = event.unsigned.get("replaces_state") if original_event_id: - original_event = yield self.store.get_event(original_event_id) + original_event = await self.store.get_event(original_event_id) if original_event: original_alias = original_event.content.get("alias", None) @@ -937,7 +933,7 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id ) @@ -957,7 +953,7 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, alias_str, event.room_id ) @@ -969,7 +965,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids() + current_state_ids = await context.get_current_state_ids() state_to_include_ids = [ e_id @@ -978,7 +974,7 @@ class EventCreationHandler(object): or k == (EventTypes.Member, event.sender) ] - state_to_include = yield self.store.get_events(state_to_include_ids) + state_to_include = await self.store.get_events(state_to_include_ids) event.unsigned["invite_room_state"] = [ { @@ -996,8 +992,8 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield defer.ensureDeferred( - federation_handler.send_invite(invitee.domain, event) + returned_invite = await federation_handler.send_invite( + invitee.domain, event ) event.unsigned.pop("room_state", None) @@ -1005,7 +1001,7 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -1021,14 +1017,14 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids() - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( @@ -1047,11 +1043,11 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( + event_stream_id, max_stream_id = await self.storage.persistence.persist_event( event, context=context ) @@ -1059,7 +1055,7 @@ class EventCreationHandler(object): # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: @@ -1083,13 +1079,12 @@ class EventCreationHandler(object): except Exception: logger.exception("Error bumping presence active time") - @defer.inlineCallbacks - def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self): """Background task to send dummy events into rooms that have a large number of extremities """ self._expire_rooms_to_exclude_from_dummy_event_insertion() - room_ids = yield self.store.get_rooms_with_many_extremities( + room_ids = await self.store.get_rooms_with_many_extremities( min_count=10, limit=5, room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), @@ -1099,9 +1094,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = yield self.state.get_current_users_in_room( + members = await self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) dummy_event_sent = False @@ -1110,7 +1105,7 @@ class EventCreationHandler(object): continue requester = create_requester(user_id) try: - event, context = yield self.create_event( + event, context = await self.create_event( requester, { "type": "org.matrix.dummy_event", @@ -1123,7 +1118,7 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=False ) dummy_event_sent = True diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..302efc1b9a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler): return result["displayname"] - @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname, by_admin=False): + async def set_displayname( + self, target_user, requester, new_displayname, by_admin=False + ): """Set the displayname of a user Args: @@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( 400, @@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] - @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): + async def set_avatar_url( + self, target_user, requester, new_avatar_url, by_admin=False + ): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): @@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's avatar_url") if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN @@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler): return response - @defer.inlineCallbacks - def _update_join_states(self, requester, target_user): + async def _update_join_states(self, requester, target_user): if not self.hs.is_mine(target_user): return - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. - yield handler.update_membership( + await handler.update_membership( requester, target_user, room_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3a65b46ecd..1e6bdac0ad 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler): """Registers a new client on the server. Args: - localpart : The local part of the user ID to register. If None, + localpart: The local part of the user ID to register. If None, one will be generated. - password (unicode) : The password to assign to this user so they can + password (unicode): The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). user_type (str|None): type of user. One of the values from @@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield self._auto_join_rooms(user_id) + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler): return user_id - @defer.inlineCallbacks - def _auto_join_rooms(self, user_id): + async def _auto_join_rooms(self, user_id): """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler): # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = yield self.store.is_real_user(user_id) + is_real_user = await self.store.is_real_user(user_id) if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = yield self.store.count_real_users() + count = await self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) @@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler): # getting the RoomCreationHandler during init gives a dependency # loop - yield self.hs.get_room_creation_handler().create_room( + await self.hs.get_room_creation_handler().create_room( fake_requester, config={ "preset": "public_chat", @@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) else: - yield self._join_user_to_room(fake_requester, r) + await self._join_user_to_room(fake_requester, r) except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - @defer.inlineCallbacks - def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted Args: user_id (str): The user to join """ - yield self._auto_join_rooms(user_id) + await self._auto_join_rooms(user_id) @defer.inlineCallbacks def appservice_register(self, user_localpart, as_token): @@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): + async def _join_user_to_room(self, requester, room_identifier): room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( room_alias ) room_id = room_id.to_string() @@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield room_member_handler.update_membership( + await room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler): return (device_id, access_token) - @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions(self, user_id, auth_result, access_token): """A user has completed registration Args: @@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler): device, or None if `inhibit_login` enabled. """ if self.hs.config.worker_app: - yield self._post_registration_client( + await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) return @@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(user_id) + await self.store.upsert_monthly_active_user(user_id) - yield self._register_email_threepid(user_id, threepid, access_token) + await self._register_email_threepid(user_id, threepid, access_token) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid(user_id, threepid) + await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented(user_id, self.hs.config.user_consent_version) + await self._on_user_consented(user_id, self.hs.config.user_consent_version) - @defer.inlineCallbacks - def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id, consent_version): """A user consented to the terms on registration Args: @@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version(user_id, consent_version) - yield self.post_consent_actions(user_id) + await self.store.user_set_consent_version(user_id, consent_version) + await self.post_consent_actions(user_id) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d10e4b2d9..73f9eeb399 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,8 +25,6 @@ from collections import OrderedDict from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() - @defer.inlineCallbacks - def upgrade_room( + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): """Replace a room with a new room with a different version @@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler): Returns: Deferred[unicode]: the new room id """ - yield self.ratelimit(requester) + await self.ratelimit(requester) user_id = requester.user.to_string() @@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler): # If this user has sent multiple upgrade requests for the same room # and one of them is not complete yet, cache the response and # return it to all subsequent requests - ret = yield self._upgrade_response_cache.wrap( + ret = await self._upgrade_response_cache.wrap( (old_room_id, user_id), self._upgrade_room, requester, @@ -148,17 +145,16 @@ class RoomCreationHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _upgrade_room( + async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): user_id = requester.user.to_string() # start by allocating a new room id - r = yield self.store.get_room(old_room_id) + r = await self.store.get_room(old_room_id) if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( + new_room_id = await self._generate_room_id( creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) @@ -169,7 +165,7 @@ class RoomCreationHandler(BaseHandler): ( tombstone_event, tombstone_context, - ) = yield self.event_creation_handler.create_event( + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Tombstone, @@ -183,12 +179,12 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version_id(old_room_id) - yield self.auth.check_from_context( + old_room_version = await self.store.get_room_version_id(old_room_id) + await self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) - yield self.clone_existing_room( + await self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -197,32 +193,31 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( + await self.event_creation_handler.send_nonmember_event( requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids() + old_room_state = await tombstone_context.get_current_state_ids() # update any aliases - yield self._move_aliases_to_new_room( + await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state ) # Copy over user push rules, tags and migrate room directory state - yield self.room_member_handler.transfer_room_state_on_room_upgrade( + await self.room_member_handler.transfer_room_state_on_room_upgrade( old_room_id, new_room_id ) # finally, shut down the PLs in the old room, and update them in the new # room. - yield self._update_upgraded_room_pls( + await self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state, ) return new_room_id - @defer.inlineCallbacks - def _update_upgraded_room_pls( + async def _update_upgraded_room_pls( self, requester: Requester, old_room_id: str, @@ -249,7 +244,7 @@ class RoomCreationHandler(BaseHandler): ) return - old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally @@ -278,7 +273,7 @@ class RoomCreationHandler(BaseHandler): if updated: try: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -292,7 +287,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -304,8 +299,7 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def clone_existing_room( + async def clone_existing_room( self, requester: Requester, old_room_id: str, @@ -338,7 +332,7 @@ class RoomCreationHandler(BaseHandler): # Check if old room was non-federatable # Get old room's create event - old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) + old_room_create_event = await self.store.get_create_event_for_room(old_room_id) # Check if the create event specified a non-federatable room if not old_room_create_event.content.get("m.federate", True): @@ -361,11 +355,11 @@ class RoomCreationHandler(BaseHandler): (EventTypes.PowerLevels, ""), ) - old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent - old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + old_room_state_events = await self.store.get_events(old_room_state_ids.values()) for k, old_event_id in iteritems(old_room_state_ids): old_event = old_room_state_events.get(old_event_id) @@ -400,7 +394,7 @@ class RoomCreationHandler(BaseHandler): if current_power_level < needed_power_level: power_levels["users"][user_id] = needed_power_level - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, new_room_id, # we expect to override all the presets with initial_state, so this is @@ -412,12 +406,12 @@ class RoomCreationHandler(BaseHandler): ) # Transfer membership events - old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_member_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent - old_room_member_state_events = yield self.store.get_events( + old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): @@ -426,7 +420,7 @@ class RoomCreationHandler(BaseHandler): "membership" in old_event.content and old_event.content["membership"] == "ban" ): - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(old_event["state_key"]), new_room_id, @@ -438,8 +432,7 @@ class RoomCreationHandler(BaseHandler): # XXX invites/joins # XXX 3pid invites - @defer.inlineCallbacks - def _move_aliases_to_new_room( + async def _move_aliases_to_new_room( self, requester: Requester, old_room_id: str, @@ -448,13 +441,13 @@ class RoomCreationHandler(BaseHandler): ): directory_handler = self.hs.get_handlers().directory_handler - aliases = yield self.store.get_aliases_for_room(old_room_id) + aliases = await self.store.get_aliases_for_room(old_room_id) # check to see if we have a canonical alias. canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event_id: - canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) + canonical_alias_event = await self.store.get_event(canonical_alias_event_id) # first we try to remove the aliases from the old room (we suppress sending # the room_aliases event until the end). @@ -472,7 +465,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association(requester, alias) + await directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -485,7 +478,7 @@ class RoomCreationHandler(BaseHandler): # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: - yield directory_handler.create_association( + await directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, @@ -502,7 +495,7 @@ class RoomCreationHandler(BaseHandler): # alias event for the new room with a copy of the information. try: if canonical_alias_event: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -518,8 +511,9 @@ class RoomCreationHandler(BaseHandler): # we returned the new room to the client at this point. logger.error("Unable to send updated alias events in new room: %s", e) - @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): + async def create_room( + self, requester, config, ratelimit=True, creator_join_profile=None + ): """ Creates a new room. Args: @@ -547,7 +541,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) if ( self._server_notices_mxid is not None @@ -556,11 +550,11 @@ class RoomCreationHandler(BaseHandler): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. - event_allowed = yield self.third_party_event_rules.on_create_room( + event_allowed = await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) if not event_allowed: @@ -574,7 +568,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - yield self.ratelimit(requester) + await self.ratelimit(requester) room_version_id = config.get( "room_version", self.config.default_room_version.identifier @@ -597,7 +591,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(400, "Invalid characters in room alias") room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) - mapping = yield self.store.get_association_from_room_alias(room_alias) + mapping = await self.store.get_association_from_room_alias(room_alias) if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) @@ -612,7 +606,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy(requester) + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") if ( @@ -631,13 +625,13 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id( + room_id = await self._generate_room_id( creator_id=user_id, is_public=is_public, room_version=room_version, ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: - yield directory_handler.create_association( + await directory_handler.create_association( requester=requester, room_id=room_id, room_alias=room_alias, @@ -670,7 +664,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -684,7 +678,7 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -698,7 +692,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -716,7 +710,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -730,7 +724,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_room_member_handler().do_3pid_invite( + await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -748,8 +742,7 @@ class RoomCreationHandler(BaseHandler): return result - @defer.inlineCallbacks - def _send_events_for_new_room( + async def _send_events_for_new_room( self, creator, # A Requester object. room_id, @@ -769,11 +762,10 @@ class RoomCreationHandler(BaseHandler): return e - @defer.inlineCallbacks - def send(etype, content, **kwargs): + async def send(etype, content, **kwargs): event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) @@ -784,10 +776,10 @@ class RoomCreationHandler(BaseHandler): event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send(etype=EventTypes.Create, content=creation_content) + await send(etype=EventTypes.Create, content=creation_content) logger.debug("Sending %s in new room", EventTypes.Member) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -800,7 +792,7 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send(etype=EventTypes.PowerLevels, content=pl_content) + await send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { "users": {creator_id: 100}, @@ -833,36 +825,35 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send(etype=EventTypes.PowerLevels, content=power_level_content) + await send(etype=EventTypes.PowerLevels, content=power_level_content) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - yield send( + await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - yield send( + await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - yield send( + await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - yield send( + await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send(etype=etype, state_key=state_key, content=content) + await send(etype=etype, state_key=state_key, content=content) - @defer.inlineCallbacks - def _generate_room_id( + async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just @@ -874,7 +865,7 @@ class RoomCreationHandler(BaseHandler): gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): gen_room_id = gen_room_id.decode("utf-8") - yield self.store.store_room( + await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, is_public=is_public, @@ -893,8 +884,7 @@ class RoomContextHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state - @defer.inlineCallbacks - def get_event_context(self, user, room_id, event_id, limit, event_filter): + async def get_event_context(self, user, room_id, event_id, limit, event_filter): """Retrieves events, pagination tokens and state around a given event in a room. @@ -913,7 +903,7 @@ class RoomContextHandler(object): before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = yield self.store.get_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users def filter_evts(events): @@ -921,17 +911,17 @@ class RoomContextHandler(object): self.storage, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event( + event = await self.store.get_event( event_id, get_prev_content=True, allow_none=True ) if not event: return None - filtered = yield (filter_evts([event])) + filtered = await filter_evts([event]) if not filtered: raise AuthError(403, "You don't have permission to access that event.") - results = yield self.store.get_events_around( + results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) @@ -939,8 +929,8 @@ class RoomContextHandler(object): results["events_before"] = event_filter.filter(results["events_before"]) results["events_after"] = event_filter.filter(results["events_after"]) - results["events_before"] = yield filter_evts(results["events_before"]) - results["events_after"] = yield filter_evts(results["events_after"]) + results["events_before"] = await filter_evts(results["events_before"]) + results["events_after"] = await filter_evts(results["events_after"]) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. @@ -967,7 +957,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.state_store.get_state_for_events( + state = await self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) @@ -975,7 +965,7 @@ class RoomContextHandler(object): if event_filter: state_events = event_filter.filter(state_events) - results["state"] = yield filter_evts(state_events) + results["state"] = await filter_evts(state_events) # We use a dummy token here as we only care about the room portion of # the token, which we replace. @@ -994,13 +984,12 @@ class RoomEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events( + async def get_new_events( self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. - to_key = yield self.get_current_key() + to_key = await self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: @@ -1013,11 +1002,11 @@ class RoomEventSource(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - room_events = yield self.store.get_membership_changes_for_user( + room_events = await self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key ) - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=room_ids, from_key=from_key, to_key=to_key, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c3ee8db4f0..53b49bc15f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -142,8 +142,7 @@ class RoomMemberHandler(object): """ raise NotImplementedError() - @defer.inlineCallbacks - def _local_membership_update( + async def _local_membership_update( self, requester, target, @@ -164,7 +163,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_handler.create_event( + event, context = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -182,18 +181,18 @@ class RoomMemberHandler(object): ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_handler.deduplicate_state_event( + duplicate = await self.event_creation_handler.deduplicate_state_event( event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. return duplicate - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -203,15 +202,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target, room_id) + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target, room_id) + await self._user_left_room(target, room_id) return event @@ -253,8 +252,7 @@ class RoomMemberHandler(object): for tag, tag_content in room_tags.items(): yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) - @defer.inlineCallbacks - def update_membership( + async def update_membership( self, requester, target, @@ -269,8 +267,8 @@ class RoomMemberHandler(object): ): key = (room_id,) - with (yield self.member_linearizer.queue(key)): - result = yield self._update_membership( + with (await self.member_linearizer.queue(key)): + result = await self._update_membership( requester, target, room_id, @@ -285,8 +283,7 @@ class RoomMemberHandler(object): return result - @defer.inlineCallbacks - def _update_membership( + async def _update_membership( self, requester, target, @@ -321,7 +318,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.federation_handler.exchange_third_party_invite( + await self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -332,7 +329,7 @@ class RoomMemberHandler(object): remote_room_hosts = [] if effective_membership_state not in ("leave", "ban"): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -351,7 +348,7 @@ class RoomMemberHandler(object): is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -370,9 +367,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids ) @@ -381,7 +378,7 @@ class RoomMemberHandler(object): # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: - old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -413,7 +410,7 @@ class RoomMemberHandler(object): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = yield self._is_server_notice_room(room_id) + is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( http_client.FORBIDDEN, @@ -424,18 +421,18 @@ class RoomMemberHandler(object): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = yield self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(current_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -443,13 +440,13 @@ class RoomMemberHandler(object): profile = self.profile_handler if not content_specified: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + content["displayname"] = await profile.get_displayname(target) + content["avatar_url"] = await profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" - remote_join_response = yield self._remote_join( + remote_join_response = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) @@ -458,7 +455,7 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -472,12 +469,12 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = yield self._remote_reject_invite( + res = await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) return res - res = yield self._local_membership_update( + res = await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -572,8 +569,7 @@ class RoomMemberHandler(object): ) continue - @defer.inlineCallbacks - def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event(self, requester, event, context, ratelimit=True): """ Change the membership status of a user in a room. @@ -599,27 +595,27 @@ class RoomMemberHandler(object): else: requester = types.create_requester(target_user) - prev_event = yield self.event_creation_handler.deduplicate_state_event( + prev_event = await self.event_creation_handler.deduplicate_state_event( event, context ) if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(prev_state_ids) + guest_can_join = await self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if event.membership not in (Membership.LEAVE, Membership.BAN): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -633,15 +629,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target_user, room_id) + await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target_user, room_id) + await self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -699,8 +695,7 @@ class RoomMemberHandler(object): if invite: return UserID.from_string(invite.sender) - @defer.inlineCallbacks - def do_3pid_invite( + async def do_3pid_invite( self, room_id, inviter, @@ -712,7 +707,7 @@ class RoomMemberHandler(object): id_access_token=None, ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -720,9 +715,9 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - yield self.base_handler.ratelimit(requester) + await self.base_handler.ratelimit(requester) - can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id ) if not can_invite: @@ -737,16 +732,16 @@ class RoomMemberHandler(object): 403, "Looking up third-party identifiers is denied from this server" ) - invitee = yield self.identity_handler.lookup_3pid( + invitee = await self.identity_handler.lookup_3pid( id_server, medium, address, id_access_token ) if invitee: - yield self.update_membership( + await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - yield self._make_and_store_3pid_invite( + await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -757,8 +752,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - @defer.inlineCallbacks - def _make_and_store_3pid_invite( + async def _make_and_store_3pid_invite( self, requester, id_server, @@ -769,7 +763,7 @@ class RoomMemberHandler(object): txn_id, id_access_token=None, ): - room_state = yield self.state_handler.get_current_state(room_id) + room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -807,7 +801,7 @@ class RoomMemberHandler(object): public_keys, fallback_public_key, display_name, - ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + ) = await self.identity_handler.ask_id_server_for_third_party_invite( requester=requester, id_server=id_server, medium=medium, @@ -823,7 +817,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: # Fetch the room complexity - too_complex = yield self._is_remote_room_too_complex( + too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts ) if too_complex is True: @@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield defer.ensureDeferred( - self.federation_handler.do_invite_join( - remote_room_hosts, room_id, user.to_string(), content - ) + await self.federation_handler.do_invite_join( + remote_room_hosts, room_id, user.to_string(), content ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return # Check again, but with the local state events - too_complex = yield self._is_local_room_too_complex(room_id) + too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. @@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # The room is too large. Leave. requester = types.create_requester(user, None, False, None) - yield self.update_membership( + await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) raise SynapseError( @@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler): def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + return defer.succeed(user_joined_room(self.distributor, target, room_id)) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + return defer.succeed(user_left_room(self.distributor, target, room_id)) @defer.inlineCallbacks def forget(self, user, room_id): diff --git a/synapse/notifier.py b/synapse/notifier.py index 88a5a97caf..71d9ed62b0 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -273,10 +273,9 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id): try: - yield self.appservice_handler.notify_interested_services(room_stream_id) + await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") @@ -475,20 +474,18 @@ class Notifier(object): return result - @defer.inlineCallbacks - def _get_room_ids(self, user, explicit_room_id): - joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) + async def _get_room_ids(self, user, explicit_room_id): + joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: return [explicit_room_id], True - if (yield self._is_world_readable(explicit_room_id)): + if await self._is_world_readable(explicit_room_id): return [explicit_room_id], False raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - @defer.inlineCallbacks - def _is_world_readable(self, room_id): - state = yield self.state_handler.get_current_state( + async def _is_world_readable(self, room_id): + state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 751c799d94..5d7c8871a4 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Optional +from typing import Optional import six @@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self.hs = hs - def stream_positions(self) -> Dict[str, int]: - """ - Get the current positions of all the streams this store wants to subscribe to - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - pos = {} - if self._cache_id_gen: - pos["caches"] = self._cache_id_gen.get_current_token() - return pos - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index ebe94909cb..65e54b1c71 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedAccountDataStore, self).stream_positions() - position = self._account_data_id_gen.get_current_token() - result["user_account_data"] = position - result["room_account_data"] = position - result["tag_account_data"] = position - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 0c237c6e0f..c923751e50 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def stream_positions(self): - result = super(SlavedDeviceInboxStore, self).stream_positions() - result["to_device"] = self._device_inbox_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 23b1650e41..58fb0eaae3 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def stream_positions(self): - result = super(SlavedDeviceStore, self).stream_positions() - # The user signature stream uses the same stream ID generator as the - # device list stream, so set them both to the device list ID - # generator's current token. - current_token = self._device_list_id_gen.get_current_token() - result[DeviceListsStream.NAME] = current_token - result[UserSignatureStream.NAME] = current_token - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index e73342c657..15011259df 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,12 +93,6 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedEventStore, self).stream_positions() - result["events"] = self._stream_id_gen.get_current_token() - result["backfill"] = -self._backfill_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 2d4fd08cf5..01bcf0e882 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index ad8f0c15a9..fae3125072 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPresenceStore, self).stream_positions() - - if self.hs.config.use_presence: - position = self._presence_id_gen.get_current_token() - result["presence"] = position - - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index eebd5a1fb6..6138796da4 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPushRuleStore, self).stream_positions() - result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index bce8a3d115..67be337945 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def stream_positions(self): - result = super(SlavedPusherStore, self).stream_positions() - result["pushers"] = self._pushers_id_gen.get_current_token() - return result - def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index d40dc6e1f5..993432edcb 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedReceiptsStore, self).stream_positions() - result["receipts"] = self._receipts_id_gen.get_current_token() - return result - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id,)) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 3a20f45316..10dda8708f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def stream_positions(self): - result = super(RoomStore, self).stream_positions() - result["public_rooms"] = self._public_room_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d07b8b2d0..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,7 +16,7 @@ """ import logging -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from twisted.internet.protocol import ReconnectingClientFactory @@ -86,37 +86,22 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - return args - async def on_position(self, stream_name: str, token: int): self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -314,15 +319,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. - current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name - ) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, - ) - return + current_token = stream.current_token() # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -333,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -342,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 617e860f95..41c623d737 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): + super().connectionMade() logger.info("Connected to redis instance") self.subscribe(self.stream_name) self.send_command(ReplicateCommand()) @@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): + super().connectionLost(reason) logger.info("Lost connection to redis instance") self.handler.lost_connection(self) @@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.handler = self.handler p.outbound_redis_connection = self.outbound_redis_connection p.stream_name = self.stream_name + p.password = self.password return p diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 33d2f589ac..b690abedad 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -80,7 +80,7 @@ class ReplicationStreamer(object): for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: # We only support federation stream if federation sending - # hase been disabled on the master. + # has been disabled on the master. continue self.streams.append(stream(hs)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..084604e8b0 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -102,15 +104,18 @@ class Stream(object): implemented by subclasses. current_token_function is called to get the current token of the underlying - stream. + stream. It is only meaningful on the process that is the source of the + replication stream (ie, usually the master). update_function is called to get updates for this stream between a pair of stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +140,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +165,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +198,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +234,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +270,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +295,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +318,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +336,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +372,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +404,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +430,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +451,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +469,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +489,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +509,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +540,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +558,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..b0505b8a2c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,7 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function +from synapse.replication.tcp.streams._base import Stream, make_http_update_function class FederationStream(Stream): @@ -35,21 +35,33 @@ class FederationStream(Stream): ROW_TYPE = FederationStreamRow def __init__(self, hs): - # Not all synapse instances will have a federation sender instance, - # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, - # so we stub the stream out when that is the case. - if hs.config.worker_app is None or hs.should_send_federation(): + if hs.config.worker_app is None: + # master process: get updates from the FederationRemoteSendQueue. + # (if the master is configured to send federation itself, federation_sender + # will be a real FederationSender, which has stubs for current_token and + # get_replication_rows.) federation_sender = hs.get_federation_sender() current_token = federation_sender.get_current_token - update_function = db_query_to_update_function( - federation_sender.get_replication_rows - ) + update_function = federation_sender.get_replication_rows + + elif hs.should_send_federation(): + # federation sender: Query master process + update_function = make_http_update_function(hs, self.NAME) + current_token = self._stub_current_token + else: - current_token = lambda: 0 + # other worker: stub out the update function (we're not interested in + # any updates so when we get a POSITION we do nothing) update_function = self._stub_update_function + current_token = self._stub_current_token - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) + + @staticmethod + def _stub_current_token(): + # dummy current-token method for use on workers + return 0 @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5736c56032..3bf330da49 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -59,8 +57,7 @@ class ConsentServerNotices(object): self._consent_uri_builder = ConsentURIBuilder(hs.config) - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, and does so if so Args: @@ -78,7 +75,7 @@ class ConsentServerNotices(object): return self._users_in_progress.add(user_id) try: - u = yield self._store.get_user_by_id(user_id) + u = await self._store.get_user_by_id(user_id) if u["is_guest"] and not self._send_to_guests: # don't send to guests @@ -100,8 +97,8 @@ class ConsentServerNotices(object): content = copy_with_str_subst( self._server_notice_content, {"consent_uri": consent_uri} ) - yield self._server_notices_manager.send_notice(user_id, content) - yield self._store.user_set_consent_server_notice_sent( + await self._server_notices_manager.send_notice(user_id, content) + await self._store.user_set_consent_server_notice_sent( user_id, self._current_consent_version ) except SynapseError as e: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index ce4a828894..d97166351e 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.constants import ( EventTypes, LimitBlockingTypes, @@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. 1. The server has reached its limit does not reflect this @@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object): # Don't try and send server notices unless they've been enabled return - timestamp = yield self._store.user_last_seen_monthly_active(user_id) + timestamp = await self._store.user_last_seen_monthly_active(user_id) if timestamp is None: # This user will be blocked from receiving the notice anyway. # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( user_id ) @@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object): logger.warning("Failed to get server notices room") return - yield self._check_and_set_tags(user_id, room_id) + await self._check_and_set_tags(user_id, room_id) # Determine current state of room - currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) limit_msg = None limit_type = None @@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object): # Normally should always pass in user_id to check_auth_blocking # if you have it, but in this case are checking what would happen # to other users if they were to arrive. - yield self._auth.check_auth_blocking() + await self._auth.check_auth_blocking() except ResourceLimitError as e: limit_msg = e.msg limit_type = e.limit_type @@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return if currently_blocked: - self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) return if currently_blocked and not limit_msg: # Room is notifying of a block, when it ought not to be. - yield self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - yield self._apply_limit_block_notification( + await self._apply_limit_block_notification( user_id, limit_msg, limit_type ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - @defer.inlineCallbacks - def _remove_limit_block_notification(self, user_id, ref_events): + async def _remove_limit_block_notification(self, user_id, ref_events): """Utility method to remove limit block notifications from the server notices room. @@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object): limit blocking and need to be preserved. """ content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + async def _apply_limit_block_notification( + self, user_id, event_body, event_limit_type + ): """Utility method to apply limit block notifications in the server notices room. @@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object): "admin_contact": self._config.admin_contact, "limit_type": event_limit_type, } - event = yield self._server_notices_manager.send_notice( + event = await self._server_notices_manager.send_notice( user_id, content, EventTypes.Message ) content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _check_and_set_tags(self, user_id, room_id): + async def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, important to check that tags have been set correctly @@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object): user_id(str): the user in question room_id(str): the server notices room for that user """ - tags = yield self._store.get_tags_for_room(user_id, room_id) + tags = await self._store.get_tags_for_room(user_id, room_id) need_to_set_tag = True if tags: if SERVER_NOTICE_ROOM_TAG in tags: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - @defer.inlineCallbacks - def _is_room_currently_blocked(self, room_id): + async def _is_room_currently_blocked(self, room_id): """ Determines if the room is currently blocked @@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object): room_id(str): The room id of the server notices room Returns: - + Deferred[Tuple[bool, List]]: bool: Is the room currently blocked list: The list of pinned events that are unrelated to limit blocking This list can be used as a convenience in the case where the block @@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object): currently_blocked = False pinned_state_event = None try: - pinned_state_event = yield self._state.get_current_state( + pinned_state_event = await self._state.get_current_state( room_id, event_type=EventTypes.Pinned ) except AuthError: @@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object): if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) - events = yield self._store.get_events(referenced_events) + events = await self._store.get_events(referenced_events) for event_id, event in iteritems(events): if event.type != EventTypes.Message: continue diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index bf0943f265..999c621b92 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -14,11 +14,9 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.types import UserID, create_requester -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -51,8 +49,7 @@ class ServerNoticesManager(object): """ return self._config.server_notices_mxid is not None - @defer.inlineCallbacks - def send_notice( + async def send_notice( self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -68,8 +65,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_or_create_notice_room_for_user(user_id) - yield self.maybe_invite_user_to_room(user_id, room_id) + room_id = await self.get_or_create_notice_room_for_user(user_id) + await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -86,13 +83,13 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = yield self._event_creation_handler.create_and_send_nonmember_event( + res = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) return res - @cachedInlineCallbacks() - def get_or_create_notice_room_for_user(self, user_id): + @cached() + async def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user If we have not yet created a notice room for this user, create it, but don't @@ -109,7 +106,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in rooms: @@ -118,7 +115,7 @@ class ServerNoticesManager(object): # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = yield self._store.get_users_in_room(room.room_id) + user_ids = await self._store.get_users_in_room(room.room_id) if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user @@ -146,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -158,7 +155,7 @@ class ServerNoticesManager(object): ) room_id = info["room_id"] - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @@ -166,8 +163,7 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id - @defer.inlineCallbacks - def maybe_invite_user_to_room(self, user_id: str, room_id: str): + async def maybe_invite_user_to_room(self, user_id: str, room_id: str): """Invite the given user to the given server room, unless the user has already joined or been invited to it. @@ -179,14 +175,14 @@ class ServerNoticesManager(object): # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. - joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in joined_rooms: if room.room_id == room_id: return - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( requester=requester, target=UserID.from_string(user_id), room_id=room_id, diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 652bab58e3..be74e86641 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -36,18 +34,16 @@ class ServerNoticesSender(object): ResourceLimitsServerNotices(hs), ) - @defer.inlineCallbacks - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) - @defer.inlineCallbacks - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: @@ -57,4 +53,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 3e53c8568a..efcdd2100b 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) - @defer.inlineCallbacks - def is_server_admin(self, user): + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. Args: @@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql index 163529c071..bbdde121e8 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql @@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN ( 'populate_stats_cleanup' ); +-- this relies on current_state_events.membership having been populated, so add +-- a dependency on current_state_events_membership. INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', ''); + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); +-- this also relies on current_state_events.membership having been populated, but +-- we get that as a side-effect of depending on populate_stats_process_rooms. INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9d851beaa5..86d04ea9ac 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,6 +16,11 @@ import contextlib import threading from collections import deque +from typing import Dict, Set, Tuple + +from typing_extensions import Deque + +from synapse.storage.database import Database, LoggingTransaction class IdGenerator(object): @@ -87,7 +92,7 @@ class StreamIdGenerator(object): self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[int] def get_next(self): """ @@ -163,7 +168,7 @@ class ChainedIdGenerator(object): self.chained_generator = chained_generator self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] def get_next(self): """ @@ -198,3 +203,163 @@ class ChainedIdGenerator(object): return stream_id - 1, chained_id return self._current_max, self.chained_generator.get_current_token() + + +class MultiWriterIdGenerator: + """An ID generator that tracks a stream that can have multiple writers. + + Uses a Postgres sequence to coordinate ID assignment, but positions of other + writers will only get updated when `advance` is called (by replication). + + Note: Only works with Postgres. + + Args: + db_conn + db + instance_name: The name of this instance. + table: Database table associated with stream. + instance_column: Column that stores the row's writer's instance name + id_column: Column that stores the stream ID. + sequence_name: The name of the postgres sequence used to generate new + IDs. + """ + + def __init__( + self, + db_conn, + db: Database, + instance_name: str, + table: str, + instance_column: str, + id_column: str, + sequence_name: str, + ): + self._db = db + self._instance_name = instance_name + self._sequence_name = sequence_name + + # We lock as some functions may be called from DB threads. + self._lock = threading.Lock() + + self._current_positions = self._load_current_ids( + db_conn, table, instance_column, id_column + ) + + # Set of local IDs that we're still processing. The current position + # should be less than the minimum of this set (if not empty). + self._unfinished_ids = set() # type: Set[int] + + def _load_current_ids( + self, db_conn, table: str, instance_column: str, id_column: str + ) -> Dict[str, int]: + sql = """ + SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + GROUP BY %(instance)s + """ % { + "instance": instance_column, + "id": id_column, + "table": table, + } + + cur = db_conn.cursor() + cur.execute(sql) + + # `cur` is an iterable over returned rows, which are 2-tuples. + current_positions = dict(cur) + + cur.close() + + return current_positions + + def _load_next_id_txn(self, txn): + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + (next_id,) = txn.fetchone() + return next_id + + async def get_next(self): + """ + Usage: + with await stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) + + # Assert the fetched ID is actually greater than what we currently + # believe the ID to be. If not, then the sequence and table have got + # out of sync somehow. + assert self.get_current_token() < next_id + + with self._lock: + self._unfinished_ids.add(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + self._mark_id_as_finished(next_id) + + return manager() + + def get_next_txn(self, txn: LoggingTransaction): + """ + Usage: + + stream_id = stream_id_gen.get_next(txn) + # ... persist event ... + """ + + next_id = self._load_next_id_txn(txn) + + with self._lock: + self._unfinished_ids.add(next_id) + + txn.call_after(self._mark_id_as_finished, next_id) + txn.call_on_exception(self._mark_id_as_finished, next_id) + + return next_id + + def _mark_id_as_finished(self, next_id: int): + """The ID has finished being processed so we should advance the + current poistion if possible. + """ + + with self._lock: + self._unfinished_ids.discard(next_id) + + # Figure out if its safe to advance the position by checking there + # aren't any lower allocated IDs that are yet to finish. + if all(c > next_id for c in self._unfinished_ids): + curr = self._current_positions.get(self._instance_name, 0) + self._current_positions[self._instance_name] = max(curr, next_id) + + def get_current_token(self, instance_name: str = None) -> int: + """Gets the current position of a named writer (defaults to current + instance). + + Returns 0 if we don't have a position for the named writer (likely due + to it being a new writer). + """ + + if instance_name is None: + instance_name = self._instance_name + + with self._lock: + return self._current_positions.get(instance_name, 0) + + def get_positions(self) -> Dict[str, int]: + """Get a copy of the current positon map. + """ + + with self._lock: + return dict(self._current_positions) + + def advance(self, instance_name: str, new_id: int): + """Advance the postion of the named writer to the given ID, if greater + than existing entry. + """ + + with self._lock: + self._current_positions[instance_name] = max( + new_id, self._current_positions.get(instance_name, 0) + ) |