diff options
author | Richard van der Hoff <richard@matrix.org> | 2020-05-06 11:43:11 +0100 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2020-05-06 11:43:11 +0100 |
commit | 30a19daa02cf61d2a811efb0fd619ff48ce6bcf6 (patch) | |
tree | c430aa536e76a2329831d04ae7d8d761ed2ec188 | |
parent | use an upsert to update device_lists_outbound_last_success (diff) | |
parent | Fix typing annotations in synapse/federation (#7382) (diff) | |
download | synapse-30a19daa02cf61d2a811efb0fd619ff48ce6bcf6.tar.xz |
Merge branch 'develop' into rav/upsert_for_device_list
74 files changed, 1242 insertions, 777 deletions
diff --git a/MANIFEST.in b/MANIFEST.in index 156d6f04f7..120ce5b776 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -30,23 +30,24 @@ recursive-include synapse/static *.gif recursive-include synapse/static *.html recursive-include synapse/static *.js -exclude Dockerfile +exclude .codecov.yml +exclude .coveragerc exclude .dockerignore -exclude test_postgresql.sh exclude .editorconfig +exclude Dockerfile +exclude mypy.ini exclude sytest-blacklist +exclude test_postgresql.sh include pyproject.toml recursive-include changelog.d * prune .buildkite prune .circleci -prune .codecov.yml -prune .coveragerc prune .github +prune contrib prune debian prune demo/etc prune docker -prune mypy.ini prune snap prune stubs diff --git a/UPGRADE.rst b/UPGRADE.rst index 768d94a393..d1408be2af 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,37 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.13.0 +==================== + +Incorrect database migration in old synapse versions +---------------------------------------------------- + +A bug was introduced in Synapse 1.4.0 which could cause the room directory to +be incomplete or empty if Synapse was upgraded directly from v1.2.1 or earlier, +to versions between v1.4.0 and v1.12.x. + +This will *not* be a problem for Synapse installations which were: + * created at v1.4.0 or later, + * upgraded via v1.3.x, or + * upgraded straight from v1.2.1 or earlier to v1.13.0 or later. + +If completeness of the room directory is a concern, installations which are +affected can be repaired as follows: + +1. Run the following sql from a `psql` or `sqlite3` console: + + .. code:: sql + + INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); + + INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +2. Restart synapse. + + Upgrading to v1.12.0 ==================== diff --git a/changelog.d/7172.misc b/changelog.d/7172.misc new file mode 100644 index 0000000000..ffecdf97fe --- /dev/null +++ b/changelog.d/7172.misc @@ -0,0 +1 @@ +Use `stream.current_token()` and remove `stream_positions()`. diff --git a/changelog.d/7219.misc b/changelog.d/7219.misc index 4af5da8646..dbf7a530be 100644 --- a/changelog.d/7219.misc +++ b/changelog.d/7219.misc @@ -1 +1 @@ -Add typing information to federation server code. +Add typing annotations in `synapse.federation`. diff --git a/changelog.d/7281.misc b/changelog.d/7281.misc new file mode 100644 index 0000000000..86ad511e19 --- /dev/null +++ b/changelog.d/7281.misc @@ -0,0 +1 @@ +Add MultiWriterIdGenerator to support multiple concurrent writers of streams. diff --git a/changelog.d/7363.misc b/changelog.d/7363.misc new file mode 100644 index 0000000000..1e3cddde79 --- /dev/null +++ b/changelog.d/7363.misc @@ -0,0 +1 @@ +Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await. \ No newline at end of file diff --git a/changelog.d/7368.bugfix b/changelog.d/7368.bugfix new file mode 100644 index 0000000000..efa8a40b1f --- /dev/null +++ b/changelog.d/7368.bugfix @@ -0,0 +1 @@ +Improve error responses when accessing remote public room lists. \ No newline at end of file diff --git a/changelog.d/7369.misc b/changelog.d/7369.misc new file mode 100644 index 0000000000..060b09c888 --- /dev/null +++ b/changelog.d/7369.misc @@ -0,0 +1 @@ +Thread through instance name to replication client. diff --git a/changelog.d/7374.misc b/changelog.d/7374.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7374.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/changelog.d/7382.misc b/changelog.d/7382.misc new file mode 100644 index 0000000000..dbf7a530be --- /dev/null +++ b/changelog.d/7382.misc @@ -0,0 +1 @@ +Add typing annotations in `synapse.federation`. diff --git a/changelog.d/7387.bugfix b/changelog.d/7387.bugfix new file mode 100644 index 0000000000..a250517b49 --- /dev/null +++ b/changelog.d/7387.bugfix @@ -0,0 +1 @@ +Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information. diff --git a/changelog.d/7394.misc b/changelog.d/7394.misc new file mode 100644 index 0000000000..f1390308b3 --- /dev/null +++ b/changelog.d/7394.misc @@ -0,0 +1 @@ +Convert synapse.server_notices to async/await. diff --git a/changelog.d/7395.misc b/changelog.d/7395.misc new file mode 100644 index 0000000000..bc0ad59e2c --- /dev/null +++ b/changelog.d/7395.misc @@ -0,0 +1 @@ +Convert synapse.notifier to async/await. diff --git a/changelog.d/7396.misc b/changelog.d/7396.misc new file mode 100644 index 0000000000..290d2befc7 --- /dev/null +++ b/changelog.d/7396.misc @@ -0,0 +1 @@ +Convert the room handler to async/await. diff --git a/changelog.d/7401.feature b/changelog.d/7401.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7401.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/changelog.d/7404.misc b/changelog.d/7404.misc new file mode 100644 index 0000000000..9ac17958cc --- /dev/null +++ b/changelog.d/7404.misc @@ -0,0 +1 @@ +Fix issues with the Python package manifest. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 763d3fb404..cac689d4f3 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -22,7 +22,10 @@ class RedisProtocol: def publish(self, channel: str, message: bytes): ... class SubscriberProtocol: + password: Optional[str] def subscribe(self, channels: Union[str, List[str]]): ... + def connectionMade(self): ... + def connectionLost(self, reason): ... def lazyConnection( host: str = ..., diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c1ade1333b..c5d1eb952b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,8 +537,7 @@ class Auth(object): return defer.succeed(auth_ids) - @defer.inlineCallbacks - def check_can_change_room_list(self, room_id: str, user: UserID): + async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the published room list. @@ -547,17 +546,17 @@ class Auth(object): user """ - is_admin = yield self.is_server_admin(user) + is_admin = await self.is_server_admin(user) if is_admin: return True user_id = user.to_string() - yield self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = yield self.state.get_current_state( + power_level_event = await self.state.get_current_state( room_id, EventTypes.PowerLevels, "" ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0ace7b787d..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -413,12 +413,6 @@ class GenericWorkerTyping(object): # map room IDs to sets of users currently typing self._room_typing = {} - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - def process_replication_rows(self, token, rows): if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just @@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - def get_streams_to_replicate(self): - args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - if self.send_handler: - args.update(self.send_handler.stream_positions()) - return args - - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( @@ -799,9 +784,6 @@ class FederationSenderHandler(object): def wake_destination(self, server: str): self.federation_sender.wake_destination(server) - def stream_positions(self): - return {"federation": self.federation_position} - async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a0071fec94..687cd841ac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -883,18 +883,37 @@ class FederationClient(FederationBase): def get_public_rooms( self, - destination, - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): - if destination == self.server_name: - return + """Get the list of public rooms from a remote homeserver + + Args: + remote_server: The name of the remote server + limit: Maximum amount of rooms to return + since_token: Used for result pagination + search_filter: A filter dictionary to send the remote homeserver + and filter the result set + include_all_networks: Whether to include results from all third party instances + third_party_instance_id: Whether to only include results from a specific third + party instance + + Returns: + Deferred[Dict[str, Any]]: The response from the remote server, or None if + `remote_server` is the same as the local server_name + Raises: + HttpResponseException: There was an exception returned from the remote server + SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom + requests over federation + + """ return self.transport_layer.get_public_rooms( - destination, + remote_server, limit, since_token, search_filter, @@ -957,14 +976,13 @@ class FederationClient(FederationBase): return signed_events - @defer.inlineCallbacks - def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: if destination == self.server_name: continue try: - yield self.transport_layer.exchange_third_party_invite( + await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) return None diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e1700ca8aa..52f4f54215 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,6 +31,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple +from typing import Dict, List, Tuple, Type from six import iteritems @@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = SortedDict() # Stream position -> list[user_id] + # Pending presence map user_id -> UserPresenceState + self.presence_map = {} # type: Dict[str, UserPresenceState] + + # Stream position -> list[user_id] + self.presence_changed = SortedDict() # type: SortedDict[int, List[str]] # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) - self.presence_destinations = SortedDict() + self.presence_destinations = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, List[str]]] + + # (destination, key) -> EDU + self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] - self.keyed_edu = {} # (destination, key) -> EDU - self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) + # stream position -> (destination, key) + self.keyed_edu_changed = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, tuple]] - self.edus = SortedDict() # stream position -> Edu + self.edus = SortedDict() # type: SortedDict[int, Edu] + # stream ID for the next entry into presence_changed/keyed_edu_changed/edus. self.pos = 1 - self.pos_time = SortedDict() + + # map from stream ID to the time that stream entry was generated, so that we + # can clear out entries after a while + self.pos_time = SortedDict() # type: SortedDict[int, int] # EVERYTHING IS SAD. In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object): for edu_key in self.keyed_edu_changed.values(): live_keys.add(edu_key) - to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] - for edu_key in to_del: + keys_to_del = [ + edu_key for edu_key in self.keyed_edu if edu_key not in live_keys + ] + for edu_key in keys_to_del: del self.keyed_edu[edu_key] # Delete things out of edu map @@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object): self._clear_queue_before_pos(token) async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: """Get rows to be sent over federation between the two tokens Args: - from_token (int) - to_token(int) - limit (int) - federation_ack (int): Optional. The position where the worker is - explicitly acknowledged it has handled. Allows us to drop - data from before that point + instance_name: the name of the current process + from_token: the previous stream token: the starting point for fetching the + updates + to_token: the new stream token: the point to get updates up to + target_row_count: a target for the number of rows to be returned. + + Returns: a triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of `(token, row)` entries. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. """ - # TODO: Handle limit. + # TODO: Handle target_row_count. # To handle restarts where we wrap around if from_token > self.pos: @@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object): # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. - rows = [] - - # There should be only one reader, so lets delete everything its - # acknowledged its seen. - if federation_ack: - self._clear_queue_before_pos(federation_ack) + rows = [] # type: List[Tuple[int, BaseFederationRow]] # Fetch changed presence i = self.presence_changed.bisect_right(from_token) @@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object): # Sort rows based on pos rows.sort() - return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + return ( + [(pos, (row.TypeId, row.to_data())) for pos, row in rows], + to_token, + False, + ) class BaseFederationRow(object): @@ -341,7 +361,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. @staticmethod def from_data(data): @@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu buff.edus.setdefault(self.edu.destination, []).append(self.edu) -TypeToRow = { - Row.TypeId: Row - for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,) -} +_rowtypes = ( + PresenceRow, + PresenceDestinationsRow, + KeyedEduRow, + EduRow, +) # type: Tuple[Type[BaseFederationRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _rowtypes} ParsedFederationStreamData = namedtuple( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a477578e44..d473576902 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Hashable, Iterable, List, Optional, Set +from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from six import itervalues @@ -498,14 +498,16 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self) -> int: + @staticmethod + def get_current_token() -> int: # Dummy implementation for case where federation sender isn't offloaded # to a worker. return 0 + @staticmethod async def get_replication_rows( - self, from_token, to_token, limit, federation_ack=None - ): + instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: # Dummy implementation for case where federation sender isn't offloaded # to a worker. - return [] + return [], 0, False diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index e13cd20ffa..276a2b596f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,10 @@ # limitations under the License. import datetime import logging -from typing import Dict, Hashable, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +if TYPE_CHECKING: + import synapse.server + # This is defined in the Matrix spec and enforced by the receiver. MAX_EDUS_PER_TRANSACTION = 100 diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 3c2a02a3b3..a2752a54a5 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import TYPE_CHECKING, List from canonicaljson import json -import synapse.server from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -31,6 +30,9 @@ from synapse.logging.opentracing import ( ) from synapse.util.metrics import measure_func +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 383e3fdc8b..060bf07197 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,13 +15,14 @@ # limitations under the License. import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership +from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( FEDERATION_UNSTABLE_PREFIX, FEDERATION_V1_PREFIX, @@ -326,18 +327,25 @@ class TransportLayerClient(object): @log_function def get_public_rooms( self, - remote_server, - limit, - since_token, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): + """Get the list of public rooms from a remote homeserver + + See synapse.federation.federation_client.FederationClient.get_public_rooms for + more information. + """ if search_filter: # this uses MSC2197 (Search Filtering over Federation) path = _create_v1_path("/publicRooms") - data = {"include_all_networks": "true" if include_all_networks else "false"} + data = { + "include_all_networks": "true" if include_all_networks else "false" + } # type: Dict[str, Any] if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -347,9 +355,19 @@ class TransportLayerClient(object): data["filter"] = search_filter - response = yield self.client.post_json( - destination=remote_server, path=path, data=data, ignore_backoff=True - ) + try: + response = yield self.client.post_json( + destination=remote_server, path=path, data=data, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise else: path = _create_v1_path("/publicRooms") @@ -363,9 +381,19 @@ class TransportLayerClient(object): if since_token: args["since"] = [since_token] - response = yield self.client.get_json( - destination=remote_server, path=path, args=args, ignore_backoff=True - ) + try: + response = yield self.client.get_json( + destination=remote_server, path=path, args=args, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise return response diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4f0dc0a209..4acb4fa489 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise NotImplementedError() - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from the group; either a user is leaving or an admin kicked them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: - is_admin = yield self.store.is_user_admin_in_group( + is_admin = await self.store.is_user_admin_in_group( group_id, requester_user_id ) if not is_admin: @@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_kick = True - yield self.store.remove_user_from_group(group_id, user_id) + await self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) if not self.hs.is_mine_id(user_id): - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # Delete group if the last user has left - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) if not users: - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) return {} - @defer.inlineCallbacks - def create_group(self, group_id, requester_user_id, content): - group = yield self.check_group_is_ours(group_id, requester_user_id) + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) if not is_admin: @@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): long_description = profile.get("long_description") user_profile = content.get("user_profile", {}) - yield self.store.create_group( + await self.store.create_group( group_id, requester_user_id, name=name, @@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=requester_user_id, group_id=group_id ) @@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation = None remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, requester_user_id, is_admin=True, @@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) if not self.hs.is_mine_id(requester_user_id): - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"group_id": group_id} - @defer.inlineCallbacks - def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id, requester_user_id): """Deletes a group, kicking out all current members. Only group admins or server admins can call this request @@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): Deferred """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) @@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) - @defer.inlineCallbacks - def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # We kick users out in the order of: # 1. Non-admins @@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler): else: non_admins.append(u["user_id"]) - yield concurrently_execute(_kick_user_from_group, non_admins, 10) - yield concurrently_execute(_kick_user_from_group, admins, 10) - yield _kick_user_from_group(requester_user_id) + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) def _parse_join_policy_from_contents(content): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 51413d910e..3b781d9836 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -126,30 +126,28 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)) ) - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, context=None): + async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids() - current_state = yield self.store.get_events( + current_state_ids = await context.get_current_state_ids() + current_state = await self.store.get_events( list(current_state_ids.values()) ) else: - current_state = yield self.state_handler.get_current_state( + current_state = await self.state_handler.get_current_state( event.room_id ) current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) - yield self.kick_guest_users(current_state) + await self.kick_guest_users(current_state) - @defer.inlineCallbacks - def kick_guest_users(self, current_state): + async def kick_guest_users(self, current_state): for member_event in current_state: try: if member_event.type != EventTypes.Member: @@ -180,7 +178,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() - yield handler.update_membership( + await handler.update_membership( requester, target_user, member_event.room_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 53e5f585d9..f2f16b1e43 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler): room_alias, room_id, servers, creator=creator ) - @defer.inlineCallbacks - def create_association( + async def create_association( self, requester: Requester, room_alias: RoomAlias, @@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler): else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) if (self.require_membership and check_membership) and not is_admin: - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( 403, "You must be in the room to create an alias for it" @@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias(room_alias, user_id=user_id) + can_create = await self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) + await self._create_association(room_alias, room_id, servers, creator=user_id) - @defer.inlineCallbacks - def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() try: - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) + can_delete = await self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - room_id = yield self._delete_association(room_alias) + room_id = await self._delete_association(room_alias) try: - yield self._update_canonical_alias(requester, user_id, room_id, room_alias) + await self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - @defer.inlineCallbacks - def _update_canonical_alias( + async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. """ - alias_event = yield self.state.get_current_state( + alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) @@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler): del content["alt_aliases"] if send_update: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return defer.succeed(True) - @defer.inlineCallbacks - def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. One of the following must be true: @@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler): for the current room. """ - creator = yield self.store.get_room_alias_creator(alias.to_string()) + creator = await self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True # Resolve the alias to the corresponding room. - room_mapping = yield self.get_association(alias) + room_mapping = await self.get_association(alias) room_id = room_mapping["room_id"] if not room_id: return False - res = yield self.auth.check_can_change_room_list( + res = await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) return res - @defer.inlineCallbacks - def edit_published_room_list( + async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str ): """Edit the entry of the room in the published room list. @@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to publish rooms to the room list" ) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Unknown room") - can_change_room_list = yield self.auth.check_can_change_room_list( + can_change_room_list = await self.auth.check_can_change_room_list( room_id, requester.user ) if not can_change_room_list: @@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler): making_public = visibility == "public" if making_public: - room_aliases = yield self.store.get_aliases_for_room(room_id) - canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) + room_aliases = await self.store.get_aliases_for_room(room_id) + canonical_alias = await self.store.get_canonical_alias_for_room(room_id) if canonical_alias: room_aliases.append(canonical_alias) @@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") - yield self.store.set_room_is_public(room_id, making_public) + await self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks def edit_published_appservice_room_list( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 41b96c0a73..4e5c645525 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler): "missing": [e.event_id for e in missing_locals], } - @defer.inlineCallbacks @log_function - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): third_party_invite = {"signed": signed} @@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version_id(room_id) + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - yield self.federation_client.forward_third_party_invite( + await self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ad22415782..ca5c83811a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): + async def create_group(self, group_id, user_id, content): """Create a group """ logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - res = yield self.groups_server_handler.create_group( + res = await self.groups_server_handler.create_group( group_id, user_id, content ) local_attestation = None @@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = self.attestations.create_attestation(group_id, user_id) content["attestation"] = local_attestation - content["user_profile"] = yield self.profile_handler.get_profile(user_id) + content["user_profile"] = await self.profile_handler.get_profile(user_id) try: - res = yield self.transport_client.create_group( + res = await self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): ) is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from a group """ if user_id == requester_user_id: - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) @@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # retry if the group server is currently down. if self.is_mine_id(group_id): - res = yield self.groups_server_handler.remove_user_from_group( + res = await self.groups_server_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id try: - res = yield self.transport_client.remove_user_from_group( + res = await self.transport_client.remove_user_from_group( get_domain_from_id(group_id), group_id, requester_user_id, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 522271eed1..a324f09340 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -626,8 +626,7 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - @defer.inlineCallbacks - def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -647,7 +646,7 @@ class EventCreationHandler(object): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = yield self.deduplicate_state_event(event, context) + prev_state = await self.deduplicate_state_event(event, context) if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", @@ -656,7 +655,7 @@ class EventCreationHandler(object): ) return prev_state - yield self.handle_new_client_event( + await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -683,8 +682,7 @@ class EventCreationHandler(object): return prev_event return - @defer.inlineCallbacks - def create_and_send_nonmember_event( + async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None ): """ @@ -698,8 +696,8 @@ class EventCreationHandler(object): # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (yield self.limiter.queue(event_dict["room_id"])): - event, context = yield self.create_event( + with (await self.limiter.queue(event_dict["room_id"])): + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -709,7 +707,7 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) return event @@ -770,8 +768,7 @@ class EventCreationHandler(object): return (event, context) @measure_func("handle_new_client_event") - @defer.inlineCallbacks - def handle_new_client_event( + async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, @@ -794,9 +791,9 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -805,7 +802,7 @@ class EventCreationHandler(object): ) try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -818,7 +815,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event(event, context) + await self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -826,7 +823,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - yield self.send_event_to_master( + await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -838,7 +835,7 @@ class EventCreationHandler(object): success = True return - yield self.persist_and_notify_client_event( + await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) @@ -883,8 +880,7 @@ class EventCreationHandler(object): Codes.BAD_ALIAS, ) - @defer.inlineCallbacks - def persist_and_notify_client_event( + async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already @@ -901,7 +897,7 @@ class EventCreationHandler(object): # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -913,11 +909,11 @@ class EventCreationHandler(object): original_event and event.sender != original_event.sender ) - yield self.base_handler.ratelimit( + await self.base_handler.ratelimit( requester, is_admin_redaction=is_admin_redaction ) - yield self.base_handler.maybe_kick_guest_users(event, context) + await self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -927,7 +923,7 @@ class EventCreationHandler(object): original_event_id = event.unsigned.get("replaces_state") if original_event_id: - original_event = yield self.store.get_event(original_event_id) + original_event = await self.store.get_event(original_event_id) if original_event: original_alias = original_event.content.get("alias", None) @@ -937,7 +933,7 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id ) @@ -957,7 +953,7 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, alias_str, event.room_id ) @@ -969,7 +965,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids() + current_state_ids = await context.get_current_state_ids() state_to_include_ids = [ e_id @@ -978,7 +974,7 @@ class EventCreationHandler(object): or k == (EventTypes.Member, event.sender) ] - state_to_include = yield self.store.get_events(state_to_include_ids) + state_to_include = await self.store.get_events(state_to_include_ids) event.unsigned["invite_room_state"] = [ { @@ -996,8 +992,8 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield defer.ensureDeferred( - federation_handler.send_invite(invitee.domain, event) + returned_invite = await federation_handler.send_invite( + invitee.domain, event ) event.unsigned.pop("room_state", None) @@ -1005,7 +1001,7 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -1021,14 +1017,14 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids() - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( @@ -1047,11 +1043,11 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( + event_stream_id, max_stream_id = await self.storage.persistence.persist_event( event, context=context ) @@ -1059,7 +1055,7 @@ class EventCreationHandler(object): # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: @@ -1083,13 +1079,12 @@ class EventCreationHandler(object): except Exception: logger.exception("Error bumping presence active time") - @defer.inlineCallbacks - def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self): """Background task to send dummy events into rooms that have a large number of extremities """ self._expire_rooms_to_exclude_from_dummy_event_insertion() - room_ids = yield self.store.get_rooms_with_many_extremities( + room_ids = await self.store.get_rooms_with_many_extremities( min_count=10, limit=5, room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), @@ -1099,9 +1094,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = yield self.state.get_current_users_in_room( + members = await self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) dummy_event_sent = False @@ -1110,7 +1105,7 @@ class EventCreationHandler(object): continue requester = create_requester(user_id) try: - event, context = yield self.create_event( + event, context = await self.create_event( requester, { "type": "org.matrix.dummy_event", @@ -1123,7 +1118,7 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=False ) dummy_event_sent = True diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..302efc1b9a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler): return result["displayname"] - @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname, by_admin=False): + async def set_displayname( + self, target_user, requester, new_displayname, by_admin=False + ): """Set the displayname of a user Args: @@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( 400, @@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] - @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): + async def set_avatar_url( + self, target_user, requester, new_avatar_url, by_admin=False + ): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): @@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's avatar_url") if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN @@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler): return response - @defer.inlineCallbacks - def _update_join_states(self, requester, target_user): + async def _update_join_states(self, requester, target_user): if not self.hs.is_mine(target_user): return - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. - yield handler.update_membership( + await handler.update_membership( requester, target_user, room_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3a65b46ecd..1e6bdac0ad 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler): """Registers a new client on the server. Args: - localpart : The local part of the user ID to register. If None, + localpart: The local part of the user ID to register. If None, one will be generated. - password (unicode) : The password to assign to this user so they can + password (unicode): The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). user_type (str|None): type of user. One of the values from @@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield self._auto_join_rooms(user_id) + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler): return user_id - @defer.inlineCallbacks - def _auto_join_rooms(self, user_id): + async def _auto_join_rooms(self, user_id): """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler): # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = yield self.store.is_real_user(user_id) + is_real_user = await self.store.is_real_user(user_id) if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = yield self.store.count_real_users() + count = await self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) @@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler): # getting the RoomCreationHandler during init gives a dependency # loop - yield self.hs.get_room_creation_handler().create_room( + await self.hs.get_room_creation_handler().create_room( fake_requester, config={ "preset": "public_chat", @@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) else: - yield self._join_user_to_room(fake_requester, r) + await self._join_user_to_room(fake_requester, r) except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - @defer.inlineCallbacks - def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted Args: user_id (str): The user to join """ - yield self._auto_join_rooms(user_id) + await self._auto_join_rooms(user_id) @defer.inlineCallbacks def appservice_register(self, user_localpart, as_token): @@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): + async def _join_user_to_room(self, requester, room_identifier): room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( room_alias ) room_id = room_id.to_string() @@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield room_member_handler.update_membership( + await room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler): return (device_id, access_token) - @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions(self, user_id, auth_result, access_token): """A user has completed registration Args: @@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler): device, or None if `inhibit_login` enabled. """ if self.hs.config.worker_app: - yield self._post_registration_client( + await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) return @@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(user_id) + await self.store.upsert_monthly_active_user(user_id) - yield self._register_email_threepid(user_id, threepid, access_token) + await self._register_email_threepid(user_id, threepid, access_token) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid(user_id, threepid) + await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented(user_id, self.hs.config.user_consent_version) + await self._on_user_consented(user_id, self.hs.config.user_consent_version) - @defer.inlineCallbacks - def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id, consent_version): """A user consented to the terms on registration Args: @@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version(user_id, consent_version) - yield self.post_consent_actions(user_id) + await self.store.user_set_consent_version(user_id, consent_version) + await self.post_consent_actions(user_id) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d10e4b2d9..73f9eeb399 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,8 +25,6 @@ from collections import OrderedDict from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() - @defer.inlineCallbacks - def upgrade_room( + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): """Replace a room with a new room with a different version @@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler): Returns: Deferred[unicode]: the new room id """ - yield self.ratelimit(requester) + await self.ratelimit(requester) user_id = requester.user.to_string() @@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler): # If this user has sent multiple upgrade requests for the same room # and one of them is not complete yet, cache the response and # return it to all subsequent requests - ret = yield self._upgrade_response_cache.wrap( + ret = await self._upgrade_response_cache.wrap( (old_room_id, user_id), self._upgrade_room, requester, @@ -148,17 +145,16 @@ class RoomCreationHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _upgrade_room( + async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): user_id = requester.user.to_string() # start by allocating a new room id - r = yield self.store.get_room(old_room_id) + r = await self.store.get_room(old_room_id) if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( + new_room_id = await self._generate_room_id( creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) @@ -169,7 +165,7 @@ class RoomCreationHandler(BaseHandler): ( tombstone_event, tombstone_context, - ) = yield self.event_creation_handler.create_event( + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Tombstone, @@ -183,12 +179,12 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version_id(old_room_id) - yield self.auth.check_from_context( + old_room_version = await self.store.get_room_version_id(old_room_id) + await self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) - yield self.clone_existing_room( + await self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -197,32 +193,31 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( + await self.event_creation_handler.send_nonmember_event( requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids() + old_room_state = await tombstone_context.get_current_state_ids() # update any aliases - yield self._move_aliases_to_new_room( + await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state ) # Copy over user push rules, tags and migrate room directory state - yield self.room_member_handler.transfer_room_state_on_room_upgrade( + await self.room_member_handler.transfer_room_state_on_room_upgrade( old_room_id, new_room_id ) # finally, shut down the PLs in the old room, and update them in the new # room. - yield self._update_upgraded_room_pls( + await self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state, ) return new_room_id - @defer.inlineCallbacks - def _update_upgraded_room_pls( + async def _update_upgraded_room_pls( self, requester: Requester, old_room_id: str, @@ -249,7 +244,7 @@ class RoomCreationHandler(BaseHandler): ) return - old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally @@ -278,7 +273,7 @@ class RoomCreationHandler(BaseHandler): if updated: try: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -292,7 +287,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -304,8 +299,7 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def clone_existing_room( + async def clone_existing_room( self, requester: Requester, old_room_id: str, @@ -338,7 +332,7 @@ class RoomCreationHandler(BaseHandler): # Check if old room was non-federatable # Get old room's create event - old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) + old_room_create_event = await self.store.get_create_event_for_room(old_room_id) # Check if the create event specified a non-federatable room if not old_room_create_event.content.get("m.federate", True): @@ -361,11 +355,11 @@ class RoomCreationHandler(BaseHandler): (EventTypes.PowerLevels, ""), ) - old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent - old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + old_room_state_events = await self.store.get_events(old_room_state_ids.values()) for k, old_event_id in iteritems(old_room_state_ids): old_event = old_room_state_events.get(old_event_id) @@ -400,7 +394,7 @@ class RoomCreationHandler(BaseHandler): if current_power_level < needed_power_level: power_levels["users"][user_id] = needed_power_level - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, new_room_id, # we expect to override all the presets with initial_state, so this is @@ -412,12 +406,12 @@ class RoomCreationHandler(BaseHandler): ) # Transfer membership events - old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_member_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent - old_room_member_state_events = yield self.store.get_events( + old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): @@ -426,7 +420,7 @@ class RoomCreationHandler(BaseHandler): "membership" in old_event.content and old_event.content["membership"] == "ban" ): - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(old_event["state_key"]), new_room_id, @@ -438,8 +432,7 @@ class RoomCreationHandler(BaseHandler): # XXX invites/joins # XXX 3pid invites - @defer.inlineCallbacks - def _move_aliases_to_new_room( + async def _move_aliases_to_new_room( self, requester: Requester, old_room_id: str, @@ -448,13 +441,13 @@ class RoomCreationHandler(BaseHandler): ): directory_handler = self.hs.get_handlers().directory_handler - aliases = yield self.store.get_aliases_for_room(old_room_id) + aliases = await self.store.get_aliases_for_room(old_room_id) # check to see if we have a canonical alias. canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event_id: - canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) + canonical_alias_event = await self.store.get_event(canonical_alias_event_id) # first we try to remove the aliases from the old room (we suppress sending # the room_aliases event until the end). @@ -472,7 +465,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association(requester, alias) + await directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -485,7 +478,7 @@ class RoomCreationHandler(BaseHandler): # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: - yield directory_handler.create_association( + await directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, @@ -502,7 +495,7 @@ class RoomCreationHandler(BaseHandler): # alias event for the new room with a copy of the information. try: if canonical_alias_event: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -518,8 +511,9 @@ class RoomCreationHandler(BaseHandler): # we returned the new room to the client at this point. logger.error("Unable to send updated alias events in new room: %s", e) - @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): + async def create_room( + self, requester, config, ratelimit=True, creator_join_profile=None + ): """ Creates a new room. Args: @@ -547,7 +541,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) if ( self._server_notices_mxid is not None @@ -556,11 +550,11 @@ class RoomCreationHandler(BaseHandler): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. - event_allowed = yield self.third_party_event_rules.on_create_room( + event_allowed = await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) if not event_allowed: @@ -574,7 +568,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - yield self.ratelimit(requester) + await self.ratelimit(requester) room_version_id = config.get( "room_version", self.config.default_room_version.identifier @@ -597,7 +591,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(400, "Invalid characters in room alias") room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) - mapping = yield self.store.get_association_from_room_alias(room_alias) + mapping = await self.store.get_association_from_room_alias(room_alias) if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) @@ -612,7 +606,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy(requester) + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") if ( @@ -631,13 +625,13 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id( + room_id = await self._generate_room_id( creator_id=user_id, is_public=is_public, room_version=room_version, ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: - yield directory_handler.create_association( + await directory_handler.create_association( requester=requester, room_id=room_id, room_alias=room_alias, @@ -670,7 +664,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -684,7 +678,7 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -698,7 +692,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -716,7 +710,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -730,7 +724,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_room_member_handler().do_3pid_invite( + await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -748,8 +742,7 @@ class RoomCreationHandler(BaseHandler): return result - @defer.inlineCallbacks - def _send_events_for_new_room( + async def _send_events_for_new_room( self, creator, # A Requester object. room_id, @@ -769,11 +762,10 @@ class RoomCreationHandler(BaseHandler): return e - @defer.inlineCallbacks - def send(etype, content, **kwargs): + async def send(etype, content, **kwargs): event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) @@ -784,10 +776,10 @@ class RoomCreationHandler(BaseHandler): event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send(etype=EventTypes.Create, content=creation_content) + await send(etype=EventTypes.Create, content=creation_content) logger.debug("Sending %s in new room", EventTypes.Member) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -800,7 +792,7 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send(etype=EventTypes.PowerLevels, content=pl_content) + await send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { "users": {creator_id: 100}, @@ -833,36 +825,35 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send(etype=EventTypes.PowerLevels, content=power_level_content) + await send(etype=EventTypes.PowerLevels, content=power_level_content) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - yield send( + await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - yield send( + await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - yield send( + await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - yield send( + await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send(etype=etype, state_key=state_key, content=content) + await send(etype=etype, state_key=state_key, content=content) - @defer.inlineCallbacks - def _generate_room_id( + async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just @@ -874,7 +865,7 @@ class RoomCreationHandler(BaseHandler): gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): gen_room_id = gen_room_id.decode("utf-8") - yield self.store.store_room( + await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, is_public=is_public, @@ -893,8 +884,7 @@ class RoomContextHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state - @defer.inlineCallbacks - def get_event_context(self, user, room_id, event_id, limit, event_filter): + async def get_event_context(self, user, room_id, event_id, limit, event_filter): """Retrieves events, pagination tokens and state around a given event in a room. @@ -913,7 +903,7 @@ class RoomContextHandler(object): before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = yield self.store.get_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users def filter_evts(events): @@ -921,17 +911,17 @@ class RoomContextHandler(object): self.storage, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event( + event = await self.store.get_event( event_id, get_prev_content=True, allow_none=True ) if not event: return None - filtered = yield (filter_evts([event])) + filtered = await filter_evts([event]) if not filtered: raise AuthError(403, "You don't have permission to access that event.") - results = yield self.store.get_events_around( + results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) @@ -939,8 +929,8 @@ class RoomContextHandler(object): results["events_before"] = event_filter.filter(results["events_before"]) results["events_after"] = event_filter.filter(results["events_after"]) - results["events_before"] = yield filter_evts(results["events_before"]) - results["events_after"] = yield filter_evts(results["events_after"]) + results["events_before"] = await filter_evts(results["events_before"]) + results["events_after"] = await filter_evts(results["events_after"]) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. @@ -967,7 +957,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.state_store.get_state_for_events( + state = await self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) @@ -975,7 +965,7 @@ class RoomContextHandler(object): if event_filter: state_events = event_filter.filter(state_events) - results["state"] = yield filter_evts(state_events) + results["state"] = await filter_evts(state_events) # We use a dummy token here as we only care about the room portion of # the token, which we replace. @@ -994,13 +984,12 @@ class RoomEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events( + async def get_new_events( self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. - to_key = yield self.get_current_key() + to_key = await self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: @@ -1013,11 +1002,11 @@ class RoomEventSource(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - room_events = yield self.store.get_membership_changes_for_user( + room_events = await self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key ) - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=room_ids, from_key=from_key, to_key=to_key, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c3ee8db4f0..53b49bc15f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -142,8 +142,7 @@ class RoomMemberHandler(object): """ raise NotImplementedError() - @defer.inlineCallbacks - def _local_membership_update( + async def _local_membership_update( self, requester, target, @@ -164,7 +163,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_handler.create_event( + event, context = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -182,18 +181,18 @@ class RoomMemberHandler(object): ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_handler.deduplicate_state_event( + duplicate = await self.event_creation_handler.deduplicate_state_event( event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. return duplicate - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -203,15 +202,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target, room_id) + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target, room_id) + await self._user_left_room(target, room_id) return event @@ -253,8 +252,7 @@ class RoomMemberHandler(object): for tag, tag_content in room_tags.items(): yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) - @defer.inlineCallbacks - def update_membership( + async def update_membership( self, requester, target, @@ -269,8 +267,8 @@ class RoomMemberHandler(object): ): key = (room_id,) - with (yield self.member_linearizer.queue(key)): - result = yield self._update_membership( + with (await self.member_linearizer.queue(key)): + result = await self._update_membership( requester, target, room_id, @@ -285,8 +283,7 @@ class RoomMemberHandler(object): return result - @defer.inlineCallbacks - def _update_membership( + async def _update_membership( self, requester, target, @@ -321,7 +318,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.federation_handler.exchange_third_party_invite( + await self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -332,7 +329,7 @@ class RoomMemberHandler(object): remote_room_hosts = [] if effective_membership_state not in ("leave", "ban"): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -351,7 +348,7 @@ class RoomMemberHandler(object): is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -370,9 +367,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids ) @@ -381,7 +378,7 @@ class RoomMemberHandler(object): # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: - old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -413,7 +410,7 @@ class RoomMemberHandler(object): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = yield self._is_server_notice_room(room_id) + is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( http_client.FORBIDDEN, @@ -424,18 +421,18 @@ class RoomMemberHandler(object): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = yield self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(current_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -443,13 +440,13 @@ class RoomMemberHandler(object): profile = self.profile_handler if not content_specified: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + content["displayname"] = await profile.get_displayname(target) + content["avatar_url"] = await profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" - remote_join_response = yield self._remote_join( + remote_join_response = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) @@ -458,7 +455,7 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -472,12 +469,12 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = yield self._remote_reject_invite( + res = await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) return res - res = yield self._local_membership_update( + res = await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -572,8 +569,7 @@ class RoomMemberHandler(object): ) continue - @defer.inlineCallbacks - def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event(self, requester, event, context, ratelimit=True): """ Change the membership status of a user in a room. @@ -599,27 +595,27 @@ class RoomMemberHandler(object): else: requester = types.create_requester(target_user) - prev_event = yield self.event_creation_handler.deduplicate_state_event( + prev_event = await self.event_creation_handler.deduplicate_state_event( event, context ) if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(prev_state_ids) + guest_can_join = await self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if event.membership not in (Membership.LEAVE, Membership.BAN): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -633,15 +629,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target_user, room_id) + await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target_user, room_id) + await self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -699,8 +695,7 @@ class RoomMemberHandler(object): if invite: return UserID.from_string(invite.sender) - @defer.inlineCallbacks - def do_3pid_invite( + async def do_3pid_invite( self, room_id, inviter, @@ -712,7 +707,7 @@ class RoomMemberHandler(object): id_access_token=None, ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -720,9 +715,9 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - yield self.base_handler.ratelimit(requester) + await self.base_handler.ratelimit(requester) - can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id ) if not can_invite: @@ -737,16 +732,16 @@ class RoomMemberHandler(object): 403, "Looking up third-party identifiers is denied from this server" ) - invitee = yield self.identity_handler.lookup_3pid( + invitee = await self.identity_handler.lookup_3pid( id_server, medium, address, id_access_token ) if invitee: - yield self.update_membership( + await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - yield self._make_and_store_3pid_invite( + await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -757,8 +752,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - @defer.inlineCallbacks - def _make_and_store_3pid_invite( + async def _make_and_store_3pid_invite( self, requester, id_server, @@ -769,7 +763,7 @@ class RoomMemberHandler(object): txn_id, id_access_token=None, ): - room_state = yield self.state_handler.get_current_state(room_id) + room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -807,7 +801,7 @@ class RoomMemberHandler(object): public_keys, fallback_public_key, display_name, - ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + ) = await self.identity_handler.ask_id_server_for_third_party_invite( requester=requester, id_server=id_server, medium=medium, @@ -823,7 +817,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: # Fetch the room complexity - too_complex = yield self._is_remote_room_too_complex( + too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts ) if too_complex is True: @@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield defer.ensureDeferred( - self.federation_handler.do_invite_join( - remote_room_hosts, room_id, user.to_string(), content - ) + await self.federation_handler.do_invite_join( + remote_room_hosts, room_id, user.to_string(), content ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return # Check again, but with the local state events - too_complex = yield self._is_local_room_too_complex(room_id) + too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. @@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # The room is too large. Leave. requester = types.create_requester(user, None, False, None) - yield self.update_membership( + await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) raise SynapseError( @@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler): def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + return defer.succeed(user_joined_room(self.distributor, target, room_id)) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + return defer.succeed(user_left_room(self.distributor, target, room_id)) @defer.inlineCallbacks def forget(self, user, room_id): diff --git a/synapse/notifier.py b/synapse/notifier.py index 88a5a97caf..71d9ed62b0 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -273,10 +273,9 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id): try: - yield self.appservice_handler.notify_interested_services(room_stream_id) + await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") @@ -475,20 +474,18 @@ class Notifier(object): return result - @defer.inlineCallbacks - def _get_room_ids(self, user, explicit_room_id): - joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) + async def _get_room_ids(self, user, explicit_room_id): + joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: return [explicit_room_id], True - if (yield self._is_world_readable(explicit_room_id)): + if await self._is_world_readable(explicit_room_id): return [explicit_room_id], False raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - @defer.inlineCallbacks - def _is_world_readable(self, room_id): - state = yield self.state_handler.get_current_state( + async def _is_world_readable(self, room_id): + state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 751c799d94..5d7c8871a4 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Optional +from typing import Optional import six @@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self.hs = hs - def stream_positions(self) -> Dict[str, int]: - """ - Get the current positions of all the streams this store wants to subscribe to - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - pos = {} - if self._cache_id_gen: - pos["caches"] = self._cache_id_gen.get_current_token() - return pos - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index ebe94909cb..65e54b1c71 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedAccountDataStore, self).stream_positions() - position = self._account_data_id_gen.get_current_token() - result["user_account_data"] = position - result["room_account_data"] = position - result["tag_account_data"] = position - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 0c237c6e0f..c923751e50 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def stream_positions(self): - result = super(SlavedDeviceInboxStore, self).stream_positions() - result["to_device"] = self._device_inbox_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 23b1650e41..58fb0eaae3 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def stream_positions(self): - result = super(SlavedDeviceStore, self).stream_positions() - # The user signature stream uses the same stream ID generator as the - # device list stream, so set them both to the device list ID - # generator's current token. - current_token = self._device_list_id_gen.get_current_token() - result[DeviceListsStream.NAME] = current_token - result[UserSignatureStream.NAME] = current_token - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index e73342c657..15011259df 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,12 +93,6 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedEventStore, self).stream_positions() - result["events"] = self._stream_id_gen.get_current_token() - result["backfill"] = -self._backfill_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 2d4fd08cf5..01bcf0e882 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index ad8f0c15a9..fae3125072 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPresenceStore, self).stream_positions() - - if self.hs.config.use_presence: - position = self._presence_id_gen.get_current_token() - result["presence"] = position - - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index eebd5a1fb6..6138796da4 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPushRuleStore, self).stream_positions() - result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index bce8a3d115..67be337945 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def stream_positions(self): - result = super(SlavedPusherStore, self).stream_positions() - result["pushers"] = self._pushers_id_gen.get_current_token() - return result - def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index d40dc6e1f5..993432edcb 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedReceiptsStore, self).stream_positions() - result["receipts"] = self._receipts_id_gen.get_current_token() - return result - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id,)) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 3a20f45316..10dda8708f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def stream_positions(self): - result = super(RoomStore, self).stream_positions() - result["public_rooms"] = self._public_room_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d07b8b2d0..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,7 +16,7 @@ """ import logging -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from twisted.internet.protocol import ReconnectingClientFactory @@ -86,37 +86,22 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - return args - async def on_position(self, stream_name: str, token: int): self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -314,15 +319,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. - current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name - ) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, - ) - return + current_token = stream.current_token() # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -333,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -342,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 617e860f95..41c623d737 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): + super().connectionMade() logger.info("Connected to redis instance") self.subscribe(self.stream_name) self.send_command(ReplicateCommand()) @@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): + super().connectionLost(reason) logger.info("Lost connection to redis instance") self.handler.lost_connection(self) @@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.handler = self.handler p.outbound_redis_connection = self.outbound_redis_connection p.stream_name = self.stream_name + p.password = self.password return p diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 33d2f589ac..b690abedad 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -80,7 +80,7 @@ class ReplicationStreamer(object): for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: # We only support federation stream if federation sending - # hase been disabled on the master. + # has been disabled on the master. continue self.streams.append(stream(hs)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..084604e8b0 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -102,15 +104,18 @@ class Stream(object): implemented by subclasses. current_token_function is called to get the current token of the underlying - stream. + stream. It is only meaningful on the process that is the source of the + replication stream (ie, usually the master). update_function is called to get updates for this stream between a pair of stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +140,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +165,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +198,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +234,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +270,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +295,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +318,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +336,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +372,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +404,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +430,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +451,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +469,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +489,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +509,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +540,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +558,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..b0505b8a2c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,7 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function +from synapse.replication.tcp.streams._base import Stream, make_http_update_function class FederationStream(Stream): @@ -35,21 +35,33 @@ class FederationStream(Stream): ROW_TYPE = FederationStreamRow def __init__(self, hs): - # Not all synapse instances will have a federation sender instance, - # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, - # so we stub the stream out when that is the case. - if hs.config.worker_app is None or hs.should_send_federation(): + if hs.config.worker_app is None: + # master process: get updates from the FederationRemoteSendQueue. + # (if the master is configured to send federation itself, federation_sender + # will be a real FederationSender, which has stubs for current_token and + # get_replication_rows.) federation_sender = hs.get_federation_sender() current_token = federation_sender.get_current_token - update_function = db_query_to_update_function( - federation_sender.get_replication_rows - ) + update_function = federation_sender.get_replication_rows + + elif hs.should_send_federation(): + # federation sender: Query master process + update_function = make_http_update_function(hs, self.NAME) + current_token = self._stub_current_token + else: - current_token = lambda: 0 + # other worker: stub out the update function (we're not interested in + # any updates so when we get a POSITION we do nothing) update_function = self._stub_update_function + current_token = self._stub_current_token - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) + + @staticmethod + def _stub_current_token(): + # dummy current-token method for use on workers + return 0 @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5736c56032..3bf330da49 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -59,8 +57,7 @@ class ConsentServerNotices(object): self._consent_uri_builder = ConsentURIBuilder(hs.config) - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, and does so if so Args: @@ -78,7 +75,7 @@ class ConsentServerNotices(object): return self._users_in_progress.add(user_id) try: - u = yield self._store.get_user_by_id(user_id) + u = await self._store.get_user_by_id(user_id) if u["is_guest"] and not self._send_to_guests: # don't send to guests @@ -100,8 +97,8 @@ class ConsentServerNotices(object): content = copy_with_str_subst( self._server_notice_content, {"consent_uri": consent_uri} ) - yield self._server_notices_manager.send_notice(user_id, content) - yield self._store.user_set_consent_server_notice_sent( + await self._server_notices_manager.send_notice(user_id, content) + await self._store.user_set_consent_server_notice_sent( user_id, self._current_consent_version ) except SynapseError as e: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index ce4a828894..d97166351e 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.constants import ( EventTypes, LimitBlockingTypes, @@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. 1. The server has reached its limit does not reflect this @@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object): # Don't try and send server notices unless they've been enabled return - timestamp = yield self._store.user_last_seen_monthly_active(user_id) + timestamp = await self._store.user_last_seen_monthly_active(user_id) if timestamp is None: # This user will be blocked from receiving the notice anyway. # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( user_id ) @@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object): logger.warning("Failed to get server notices room") return - yield self._check_and_set_tags(user_id, room_id) + await self._check_and_set_tags(user_id, room_id) # Determine current state of room - currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) limit_msg = None limit_type = None @@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object): # Normally should always pass in user_id to check_auth_blocking # if you have it, but in this case are checking what would happen # to other users if they were to arrive. - yield self._auth.check_auth_blocking() + await self._auth.check_auth_blocking() except ResourceLimitError as e: limit_msg = e.msg limit_type = e.limit_type @@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return if currently_blocked: - self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) return if currently_blocked and not limit_msg: # Room is notifying of a block, when it ought not to be. - yield self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - yield self._apply_limit_block_notification( + await self._apply_limit_block_notification( user_id, limit_msg, limit_type ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - @defer.inlineCallbacks - def _remove_limit_block_notification(self, user_id, ref_events): + async def _remove_limit_block_notification(self, user_id, ref_events): """Utility method to remove limit block notifications from the server notices room. @@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object): limit blocking and need to be preserved. """ content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + async def _apply_limit_block_notification( + self, user_id, event_body, event_limit_type + ): """Utility method to apply limit block notifications in the server notices room. @@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object): "admin_contact": self._config.admin_contact, "limit_type": event_limit_type, } - event = yield self._server_notices_manager.send_notice( + event = await self._server_notices_manager.send_notice( user_id, content, EventTypes.Message ) content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _check_and_set_tags(self, user_id, room_id): + async def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, important to check that tags have been set correctly @@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object): user_id(str): the user in question room_id(str): the server notices room for that user """ - tags = yield self._store.get_tags_for_room(user_id, room_id) + tags = await self._store.get_tags_for_room(user_id, room_id) need_to_set_tag = True if tags: if SERVER_NOTICE_ROOM_TAG in tags: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - @defer.inlineCallbacks - def _is_room_currently_blocked(self, room_id): + async def _is_room_currently_blocked(self, room_id): """ Determines if the room is currently blocked @@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object): room_id(str): The room id of the server notices room Returns: - + Deferred[Tuple[bool, List]]: bool: Is the room currently blocked list: The list of pinned events that are unrelated to limit blocking This list can be used as a convenience in the case where the block @@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object): currently_blocked = False pinned_state_event = None try: - pinned_state_event = yield self._state.get_current_state( + pinned_state_event = await self._state.get_current_state( room_id, event_type=EventTypes.Pinned ) except AuthError: @@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object): if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) - events = yield self._store.get_events(referenced_events) + events = await self._store.get_events(referenced_events) for event_id, event in iteritems(events): if event.type != EventTypes.Message: continue diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index bf0943f265..999c621b92 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -14,11 +14,9 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.types import UserID, create_requester -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -51,8 +49,7 @@ class ServerNoticesManager(object): """ return self._config.server_notices_mxid is not None - @defer.inlineCallbacks - def send_notice( + async def send_notice( self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -68,8 +65,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_or_create_notice_room_for_user(user_id) - yield self.maybe_invite_user_to_room(user_id, room_id) + room_id = await self.get_or_create_notice_room_for_user(user_id) + await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -86,13 +83,13 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = yield self._event_creation_handler.create_and_send_nonmember_event( + res = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) return res - @cachedInlineCallbacks() - def get_or_create_notice_room_for_user(self, user_id): + @cached() + async def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user If we have not yet created a notice room for this user, create it, but don't @@ -109,7 +106,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in rooms: @@ -118,7 +115,7 @@ class ServerNoticesManager(object): # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = yield self._store.get_users_in_room(room.room_id) + user_ids = await self._store.get_users_in_room(room.room_id) if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user @@ -146,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -158,7 +155,7 @@ class ServerNoticesManager(object): ) room_id = info["room_id"] - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @@ -166,8 +163,7 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id - @defer.inlineCallbacks - def maybe_invite_user_to_room(self, user_id: str, room_id: str): + async def maybe_invite_user_to_room(self, user_id: str, room_id: str): """Invite the given user to the given server room, unless the user has already joined or been invited to it. @@ -179,14 +175,14 @@ class ServerNoticesManager(object): # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. - joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in joined_rooms: if room.room_id == room_id: return - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( requester=requester, target=UserID.from_string(user_id), room_id=room_id, diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 652bab58e3..be74e86641 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -36,18 +34,16 @@ class ServerNoticesSender(object): ResourceLimitsServerNotices(hs), ) - @defer.inlineCallbacks - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) - @defer.inlineCallbacks - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: @@ -57,4 +53,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 3e53c8568a..efcdd2100b 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) - @defer.inlineCallbacks - def is_server_admin(self, user): + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. Args: @@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql index 163529c071..bbdde121e8 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql @@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN ( 'populate_stats_cleanup' ); +-- this relies on current_state_events.membership having been populated, so add +-- a dependency on current_state_events_membership. INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', ''); + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); +-- this also relies on current_state_events.membership having been populated, but +-- we get that as a side-effect of depending on populate_stats_process_rooms. INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9d851beaa5..86d04ea9ac 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,6 +16,11 @@ import contextlib import threading from collections import deque +from typing import Dict, Set, Tuple + +from typing_extensions import Deque + +from synapse.storage.database import Database, LoggingTransaction class IdGenerator(object): @@ -87,7 +92,7 @@ class StreamIdGenerator(object): self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[int] def get_next(self): """ @@ -163,7 +168,7 @@ class ChainedIdGenerator(object): self.chained_generator = chained_generator self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] def get_next(self): """ @@ -198,3 +203,163 @@ class ChainedIdGenerator(object): return stream_id - 1, chained_id return self._current_max, self.chained_generator.get_current_token() + + +class MultiWriterIdGenerator: + """An ID generator that tracks a stream that can have multiple writers. + + Uses a Postgres sequence to coordinate ID assignment, but positions of other + writers will only get updated when `advance` is called (by replication). + + Note: Only works with Postgres. + + Args: + db_conn + db + instance_name: The name of this instance. + table: Database table associated with stream. + instance_column: Column that stores the row's writer's instance name + id_column: Column that stores the stream ID. + sequence_name: The name of the postgres sequence used to generate new + IDs. + """ + + def __init__( + self, + db_conn, + db: Database, + instance_name: str, + table: str, + instance_column: str, + id_column: str, + sequence_name: str, + ): + self._db = db + self._instance_name = instance_name + self._sequence_name = sequence_name + + # We lock as some functions may be called from DB threads. + self._lock = threading.Lock() + + self._current_positions = self._load_current_ids( + db_conn, table, instance_column, id_column + ) + + # Set of local IDs that we're still processing. The current position + # should be less than the minimum of this set (if not empty). + self._unfinished_ids = set() # type: Set[int] + + def _load_current_ids( + self, db_conn, table: str, instance_column: str, id_column: str + ) -> Dict[str, int]: + sql = """ + SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + GROUP BY %(instance)s + """ % { + "instance": instance_column, + "id": id_column, + "table": table, + } + + cur = db_conn.cursor() + cur.execute(sql) + + # `cur` is an iterable over returned rows, which are 2-tuples. + current_positions = dict(cur) + + cur.close() + + return current_positions + + def _load_next_id_txn(self, txn): + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + (next_id,) = txn.fetchone() + return next_id + + async def get_next(self): + """ + Usage: + with await stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) + + # Assert the fetched ID is actually greater than what we currently + # believe the ID to be. If not, then the sequence and table have got + # out of sync somehow. + assert self.get_current_token() < next_id + + with self._lock: + self._unfinished_ids.add(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + self._mark_id_as_finished(next_id) + + return manager() + + def get_next_txn(self, txn: LoggingTransaction): + """ + Usage: + + stream_id = stream_id_gen.get_next(txn) + # ... persist event ... + """ + + next_id = self._load_next_id_txn(txn) + + with self._lock: + self._unfinished_ids.add(next_id) + + txn.call_after(self._mark_id_as_finished, next_id) + txn.call_on_exception(self._mark_id_as_finished, next_id) + + return next_id + + def _mark_id_as_finished(self, next_id: int): + """The ID has finished being processed so we should advance the + current poistion if possible. + """ + + with self._lock: + self._unfinished_ids.discard(next_id) + + # Figure out if its safe to advance the position by checking there + # aren't any lower allocated IDs that are yet to finish. + if all(c > next_id for c in self._unfinished_ids): + curr = self._current_positions.get(self._instance_name, 0) + self._current_positions[self._instance_name] = max(curr, next_id) + + def get_current_token(self, instance_name: str = None) -> int: + """Gets the current position of a named writer (defaults to current + instance). + + Returns 0 if we don't have a position for the named writer (likely due + to it being a new writer). + """ + + if instance_name is None: + instance_name = self._instance_name + + with self._lock: + return self._current_positions.get(instance_name, 0) + + def get_positions(self) -> Dict[str, int]: + """Get a copy of the current positon map. + """ + + with self._lock: + return dict(self._current_positions) + + def advance(self, instance_name: str, new_id: int): + """Advance the postion of the named writer to the given ID, if greater + than existing entry. + """ + + with self._lock: + self._current_positions[instance_name] = max( + new_id, self._current_positions.get(instance_name, 0) + ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index be665262c6..8aa56f1496 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) # Set displayname again - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank" + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) ) self.assertEquals( @@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase): ) # Setting displayname a second time is forbidden - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) yield self.assertFailure(d, SynapseError) @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + ) ) yield self.assertFailure(d, AuthError) @@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( @@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar again - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/me.png", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) ) self.assertEquals( @@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar a second time is forbidden - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index f1dc51d6c9..1b7935cef2 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield defer.ensureDeferred( - self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - ) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 83e16cfe3d..9d4f0bbe44 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import attr @@ -22,13 +22,16 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel -from synapse.app.generic_worker import GenericWorkerServer +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) from synapse.http.site import SynapseRequest -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.http import streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer from synapse.util import Clock from tests import unittest @@ -40,6 +43,10 @@ logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + servlets = [ + streams.register_servlets, + ] + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -47,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.server = server_factory.buildProtocol(None) # Make a new HomeServer object for the worker - config = self.default_config() - config["worker_app"] = "synapse.app.generic_worker" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - self.reactor.lookups["testserv"] = "1.2.3.4" - self.worker_hs = self.setup_test_homeserver( http_client=None, homeserverToUse=GenericWorkerServer, - config=config, + config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -76,8 +77,15 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + def _build_replication_data_handler(self): - return TestReplicationDataHandler(self.worker_hs.get_datastore()) + return TestReplicationDataHandler(self.worker_hs) def reconnect(self): if self._client_transport: @@ -172,32 +180,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") -class TestReplicationDataHandler(ReplicationDataHandler): +class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, store: BaseSlavedStore): - super().__init__(store) - - # streams to subscribe to: map from stream id to position - self.stream_positions = {} # type: Dict[str, int] + def __init__(self, hs: HomeServer): + super().__init__(hs) # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - def get_streams_to_replicate(self): - return self.stream_positions - - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) - if ( - stream_name in self.stream_positions - and token > self.stream_positions[stream_name] - ): - self.stream_positions[stream_name] = token - @attr.s() class OneShotRequestFactory: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 1fa28084f9..8bd67bb9f1 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase): self.user_tok = self.login("u1", "pass") self.reconnect() - self.test_handler.stream_positions["events"] = 0 self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() @@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + for event in events: stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) @@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # now we should have received all the expected rows in the right order. + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) # # we expect: # @@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # of the states that got reverted. # - two rows for state2 - received_rows = self.test_handler.received_rdata_rows + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] # first check the first two rows, which should be state1 @@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] self.assertGreaterEqual(len(received_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py new file mode 100644 index 0000000000..eea4565da3 --- /dev/null +++ b/tests/replication/tcp/streams/test_federation.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.federation.send_queue import EduRow +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.replication.tcp.streams._base import BaseStreamTestCase + + +class FederationStreamTestCase(BaseStreamTestCase): + def _get_worker_hs_config(self) -> dict: + # enable federation sending on the worker + config = super()._get_worker_hs_config() + # TODO: make it so we don't need both of these + config["send_federation"] = True + config["worker_app"] = "synapse.app.federation_sender" + return config + + def test_catchup(self): + """Basic test of catchup on reconnect + + Makes sure that updates sent while we are offline are received later. + """ + fed_sender = self.hs.get_federation_sender() + received_rows = self.test_handler.received_rdata_rows + + fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + + self.reconnect() + self.reactor.advance(0) + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual(received_rows, []) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "federation") + + # we should have received an update row + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test_edu") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"a": "b"}) + + self.assertEqual(received_rows, []) + + # additional updates should be transferred without an HTTP hit + fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) + self.reactor.advance(0) + # there should be no http hit + self.assertEqual(len(self.reactor.tcpClients), 0) + # ... but we should have a row + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test1") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"c": "d"}) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c122b8589c..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.stream_positions.update({"receipts": 0}) - # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( @@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db8..125c63dab5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -15,7 +15,6 @@ from mock import Mock from synapse.handlers.typing import RoomMember -from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -24,10 +23,6 @@ USER_ID = "@feeling:blue" class TypingStreamTestCase(BaseStreamTestCase): - servlets = [ - streams.register_servlets, - ] - def _build_replication_data_handler(self): return Mock(wraps=super()._build_replication_data_handler()) @@ -38,9 +33,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) @@ -50,7 +42,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -77,7 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 93eb053b8c..406f29a7c0 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) ) - self._send_notice = self._rlsn._server_notices_manager.send_notice - self._rlsn._server_notices_manager.send_notice = Mock() - self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) - + self._rlsn._server_notices_manager.send_notice = Mock( + return_value=defer.succeed(Mock()) + ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" - # self.server_notices_mxid = "@server:test" - # self.server_notices_mxid_display_name = None - # self.server_notices_mxid_avatar_url = None - # self.server_notices_room_name = "Server Notices" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - returnValue="" + return_value=defer.succeed("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock() - self._rlsn._store.get_tags_for_room = Mock(return_value={}) + self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) self.hs.config.admin_contact = "mailto:user@test.com" def test_maybe_send_server_notice_to_user_flag_off(self): @@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() @@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) @@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self.assertTrue(self._send_notice.call_count == 0) + self.assertEqual(self._send_notice.call_count, 0) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( return_value=defer.succeed((True, [])) ) @@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py new file mode 100644 index 0000000000..55e9ecf264 --- /dev/null +++ b/tests/storage/test_id_generators.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import Database +from synapse.storage.util.id_generators import MultiWriterIdGenerator + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db = self.store.db # type: Database + + self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + ) + + return self.get_success(self.db.runWithConnection(_create)) + + def _insert_rows(self, instance_name: str, number: int): + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + + self.get_success(self.db.runInteraction("test_single_instance", _insert)) + + def test_empty(self): + """Test an ID generator against an empty database gives sensible + current positions. + """ + + id_gen = self._create_id_generator() + + # The table is empty so we expect an empty map for positions + self.assertEqual(id_gen.get_positions(), {}) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(_get_next_async()) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) + + def test_multi_instance(self): + """Test that reads and writes from multiple processes are handled + correctly. + """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first") + second_id_gen = self._create_id_generator("second") + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token("first"), 3) + self.assertEqual(first_id_gen.get_current_token("second"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) + + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + + # ... but calling `get_next` on the second instance should give a unique + # stream ID + + async def _get_next_async(): + with await second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_get_next_txn(self): + """Test that the `get_next_txn` function works correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + def _get_next_txn(txn): + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(self.db.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/test_federation.py b/tests/test_federation.py index 9b5cf562f3..f297de95f1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + room = ensureDeferred( + room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) ) self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] diff --git a/tox.ini b/tox.ini index eccc44e436..ad4ed8299e 100644 --- a/tox.ini +++ b/tox.ini @@ -181,11 +181,7 @@ commands = mypy \ synapse/appservice \ synapse/config \ synapse/events/spamcheck.py \ - synapse/federation/federation_base.py \ - synapse/federation/federation_client.py \ - synapse/federation/federation_server.py \ - synapse/federation/sender \ - synapse/federation/transport \ + synapse/federation \ synapse/handlers/auth.py \ synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ @@ -203,6 +199,7 @@ commands = mypy \ synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ + synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ |