From e26dbd82ef5f1d755be9a62165556ebce041af10 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jul 2018 16:32:05 +0100 Subject: Add replication APIs for persisting federation events --- synapse/handlers/federation.py | 44 +++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 533b82c783..0524dec942 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -44,6 +44,8 @@ from synapse.crypto.event_signing import ( compute_event_signature, ) from synapse.events.validator import EventValidator +from synapse.replication.http.federation import send_federation_events_to_master +from synapse.replication.http.membership import notify_user_membership_change from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError @@ -86,6 +88,8 @@ class FederationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._server_notices_mxid = hs.config.server_notices_mxid + self.config = hs.config + self.http_client = hs.get_simple_http_client() # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -2288,7 +2292,7 @@ class FederationHandler(BaseHandler): for revocation. """ try: - response = yield self.hs.get_simple_http_client().get_json( + response = yield self.http_client.get_json( url, {"public_key": public_key} ) @@ -2313,14 +2317,25 @@ class FederationHandler(BaseHandler): Returns: Deferred """ - max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, - ) + if self.config.worker_app: + yield send_federation_events_to_master( + clock=self.hs.get_clock(), + store=self.store, + client=self.http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + event_and_contexts=event_and_contexts, + backfilled=backfilled + ) + else: + max_stream_id = yield self.store.persist_events( + event_and_contexts, + backfilled=backfilled, + ) - if not backfilled: # Never notify for backfilled events - for event, _ in event_and_contexts: - self._notify_persisted_event(event, max_stream_id) + if not backfilled: # Never notify for backfilled events + for event, _ in event_and_contexts: + self._notify_persisted_event(event, max_stream_id) def _notify_persisted_event(self, event, max_stream_id): """Checks to see if notifier/pushers should be notified about the @@ -2359,9 +2374,20 @@ class FederationHandler(BaseHandler): ) def _clean_room_for_join(self, room_id): + # TODO move this out to master return self.store.clean_room_for_join(room_id) def user_joined_room(self, user, room_id): """Called when a new user has joined the room """ - return user_joined_room(self.distributor, user, room_id) + if self.config.worker_app: + return notify_user_membership_change( + client=self.http_client, + host=self.config.worker_replication_host, + port=self.config.worker_replication_http_port, + room_id=room_id, + user_id=user.to_string(), + change="joined", + ) + else: + return user_joined_room(self.distributor, user, room_id) -- cgit 1.5.1 From a3f5bf79a0fc0ea6d59069945f53717a3e9c6581 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jul 2018 11:44:22 +0100 Subject: Add EDU/query handling over replication --- synapse/federation/federation_server.py | 43 +++++++++++++++++++++++++++++++++ synapse/handlers/federation.py | 24 +++++++++--------- synapse/replication/http/federation.py | 2 +- synapse/server.py | 6 ++++- 4 files changed, 62 insertions(+), 13 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bf89d568af..941e30a596 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -33,6 +33,10 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.replication.http.federation import ( + ReplicationFederationSendEduRestServlet, + ReplicationGetQueryRestServlet, +) from synapse.types import get_domain_from_id from synapse.util import async from synapse.util.caches.response_cache import ResponseCache @@ -745,6 +749,8 @@ class FederationHandlerRegistry(object): if edu_type in self.edu_handlers: raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + logger.info("Registering federation EDU handler for %r", edu_type) + self.edu_handlers[edu_type] = handler def register_query_handler(self, query_type, handler): @@ -763,6 +769,8 @@ class FederationHandlerRegistry(object): "Already have a Query handler for %s" % (query_type,) ) + logger.info("Registering federation query handler for %r", query_type) + self.query_handlers[query_type] = handler @defer.inlineCallbacks @@ -785,3 +793,38 @@ class FederationHandlerRegistry(object): raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) + + +class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): + def __init__(self, hs): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + super(ReplicationFederationHandlerRegistry, self).__init__() + + def on_edu(self, edu_type, origin, content): + handler = self.edu_handlers.get(edu_type) + if handler: + return super(ReplicationFederationHandlerRegistry, self).on_edu( + edu_type, origin, content, + ) + + return self._send_edu( + edu_type=edu_type, + origin=origin, + content=content, + ) + + def on_query(self, query_type, args): + handler = self.query_handlers.get(query_type) + if handler: + return handler(args) + + return self._get_query_client( + query_type=query_type, + args=args, + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0524dec942..d2cbb12df3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -44,8 +44,10 @@ from synapse.crypto.event_signing import ( compute_event_signature, ) from synapse.events.validator import EventValidator -from synapse.replication.http.federation import send_federation_events_to_master -from synapse.replication.http.membership import notify_user_membership_change +from synapse.replication.http.federation import ( + ReplicationFederationSendEventsRestServlet, +) +from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError @@ -91,6 +93,13 @@ class FederationHandler(BaseHandler): self.config = hs.config self.http_client = hs.get_simple_http_client() + self._send_events_to_master = ( + ReplicationFederationSendEventsRestServlet.make_client(hs) + ) + self._notify_user_membership_change = ( + ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + ) + # When joining a room we need to queue any events for that room up self.room_queues = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") @@ -2318,12 +2327,8 @@ class FederationHandler(BaseHandler): Deferred """ if self.config.worker_app: - yield send_federation_events_to_master( - clock=self.hs.get_clock(), + yield self._send_events_to_master( store=self.store, - client=self.http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, event_and_contexts=event_and_contexts, backfilled=backfilled ) @@ -2381,10 +2386,7 @@ class FederationHandler(BaseHandler): """Called when a new user has joined the room """ if self.config.worker_app: - return notify_user_membership_change( - client=self.http_client, - host=self.config.worker_replication_host, - port=self.config.worker_replication_http_port, + return self._notify_user_membership_change( room_id=room_id, user_id=user.to_string(), change="joined", diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index f39aaa89be..3fa7bd64c7 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -59,8 +59,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.notifier = hs.get_notifier() self.pusher_pool = hs.get_pusherpool() - @defer.inlineCallbacks @staticmethod + @defer.inlineCallbacks def _serialize_payload(store, event_and_contexts, backfilled): """ Args: diff --git a/synapse/server.py b/synapse/server.py index 140be9ebe8..26228d8c72 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, + ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.transaction_queue import TransactionQueue @@ -423,7 +424,10 @@ class HomeServer(object): return RoomMemberMasterHandler(self) def build_federation_registry(self): - return FederationHandlerRegistry() + if self.config.worker_app: + return ReplicationFederationHandlerRegistry(self) + else: + return FederationHandlerRegistry() def build_server_notices_manager(self): if self.config.worker_app: -- cgit 1.5.1 From 360ba89c50ea5cbf824e54f04d536b89b57f3304 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Aug 2018 11:54:55 +0100 Subject: Don't fail requests to unbind 3pids for non supporting ID servers Older identity servers may not support the unbind 3pid request, so we shouldn't fail the requests if we received one of 400/404/501. The request still fails if we receive e.g. 500 responses, allowing clients to retry requests on transient identity server errors that otherwise do support the API. Fixes #3661 --- changelog.d/3661.bugfix | 1 + synapse/handlers/auth.py | 20 +++++++++++++++++--- synapse/handlers/deactivate_account.py | 13 +++++++++++-- synapse/handlers/identity.py | 30 +++++++++++++++++++++--------- synapse/rest/client/v1/admin.py | 11 +++++++++-- synapse/rest/client/v2_alpha/account.py | 22 ++++++++++++++++++---- 6 files changed, 77 insertions(+), 20 deletions(-) create mode 100644 changelog.d/3661.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/3661.bugfix b/changelog.d/3661.bugfix new file mode 100644 index 0000000000..f2b4703d80 --- /dev/null +++ b/changelog.d/3661.bugfix @@ -0,0 +1 @@ +Fix bug on deleting 3pid when using identity servers that don't support unbind API diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 184eef09d0..da17e73fdd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -828,12 +828,26 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def delete_threepid(self, user_id, medium, address): + """Attempts to unbind the 3pid on the identity servers and deletes it + from the local database. + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[bool]: Returns True if successfully unbound the 3pid on + the identity server, False if identity server doesn't support the + unbind API. + """ + # 'Canonicalise' email addresses as per above if medium == 'email': address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - yield identity_handler.unbind_threepid( + result = yield identity_handler.try_unbind_threepid( user_id, { 'medium': medium, @@ -841,10 +855,10 @@ class AuthHandler(BaseHandler): }, ) - ret = yield self.store.user_delete_threepid( + yield self.store.user_delete_threepid( user_id, medium, address, ) - defer.returnValue(ret) + defer.returnValue(result) def _save_session(self, session): # TODO: Persistent storage diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index b3c5a9ee64..b078df4a76 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler): erase_data (bool): whether to GDPR-erase the user's data Returns: - Deferred + Deferred[bool]: True if identity server supports removing + threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler): # leave the user still active so they can try again. # Ideally we would prevent password resets and then do this in the # background thread. + + # This will be set to false if the identity server doesn't support + # unbinding + identity_server_supports_unbinding = True + threepids = yield self.store.user_get_threepids(user_id) for threepid in threepids: try: - yield self._identity_handler.unbind_threepid( + result = yield self._identity_handler.try_unbind_threepid( user_id, { 'medium': threepid['medium'], 'address': threepid['address'], }, ) + identity_server_supports_unbinding &= result except Exception: # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") @@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler): # parts users from rooms (if it isn't already running) self._start_user_parting() + defer.returnValue(identity_server_supports_unbinding) + def _start_user_parting(self): """ Start the process that goes through the table of users diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 1d36d967c3..a0f5fecc96 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def unbind_threepid(self, mxid, threepid): - """ - Removes a binding from an identity server + def try_unbind_threepid(self, mxid, threepid): + """Removes a binding from an identity server + Args: mxid (str): Matrix user ID of binding to be removed threepid (dict): Dict with medium & address of binding to be removed + Raises: + SynapseError: If we failed to contact the identity server + Returns: - Deferred[bool]: True on success, otherwise False + Deferred[bool]: True on success, otherwise False if the identity + server doesn't support unbinding """ logger.debug("unbinding threepid %r from %s", threepid, mxid) if not self.trusted_id_servers: @@ -175,11 +179,19 @@ class IdentityHandler(BaseHandler): content=content, destination_is=id_server, ) - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + try: + yield self.http_client.post_json_get_json( + url, + content, + headers, + ) + except HttpResponseException as e: + if e.code in (400, 404, 501,): + # The remote server probably doesn't support unbinding (yet) + defer.returnValue(False) + else: + raise SynapseError(502, "Failed to contact identity server") + defer.returnValue(True) @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 80d625eecc..ad536ab570 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): if not is_admin: raise AuthError(403, "You are not a server admin") - yield self._deactivate_account_handler.deactivate_account( + result = yield self._deactivate_account_handler.deactivate_account( target_user_id, erase, ) - defer.returnValue((200, {})) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class ShutdownRoomRestServlet(ClientV1RestServlet): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eeae466d82..372648cafd 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet): yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request), ) - yield self._deactivate_account_handler.deactivate_account( + result = yield self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, ) - defer.returnValue((200, {})) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - yield self.auth_handler.delete_threepid( + ret = yield self.auth_handler.delete_threepid( user_id, body['medium'], body['address'] ) except Exception: @@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet): logger.exception("Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid") - defer.returnValue((200, {})) + if ret: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + defer.returnValue((200, { + "id_server_unbind_result": id_server_unbind_result, + })) class WhoamiRestServlet(RestServlet): -- cgit 1.5.1 From e92fb00f32c63de6ea50ba1cbbadf74060ea143d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 8 Aug 2018 17:54:49 +0100 Subject: sync auth blocking --- synapse/handlers/sync.py | 16 +++++++++++----- tests/handlers/test_sync.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/handlers/test_sync.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..f748d9afb0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -191,6 +191,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() + self.auth = hs.get_auth() # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -198,18 +199,23 @@ class SyncHandler(object): max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. Returns: - A Deferred SyncResult. + Deferred[SyncResult] """ - return self.response_cache.wrap( - sync_config.request_key, - self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + yield self.auth.check_auth_blocking() + + defer.returnValue( + self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, + ) ) @defer.inlineCallbacks diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py new file mode 100644 index 0000000000..3b1b4d4923 --- /dev/null +++ b/tests/handlers/test_sync.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +import tests.unittest +import tests.utils +from tests.utils import setup_test_homeserver +from synapse.handlers.sync import SyncHandler, SyncConfig +from synapse.types import UserID + + +class SyncTestCase(tests.unittest.TestCase): + """ Tests Sync Handler. """ + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.sync_handler = SyncHandler(self.hs) + + @defer.inlineCallbacks + def test_wait_for_sync_for_user_auth_blocking(self): + sync_config = SyncConfig( + user=UserID("@user","server"), + filter_collection=None, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) + res = yield self.sync_handler.wait_for_sync_for_user(sync_config) + print res -- cgit 1.5.1 From bf7598f582cdfbd7db0fed55afce28bcc51a4801 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 9 Aug 2018 10:09:56 +0100 Subject: Log when we 3pid/unbind request fails --- synapse/handlers/identity.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0f5fecc96..5feb3f22a6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -188,8 +188,10 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: if e.code in (400, 404, 501,): # The remote server probably doesn't support unbinding (yet) + logger.warn("Received %d response while unbinding threepid", e.code) defer.returnValue(False) else: + logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(502, "Failed to contact identity server") defer.returnValue(True) -- cgit 1.5.1 From b179537f2a51f4de52e2625939cc32eeba75cd6b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 9 Aug 2018 10:29:48 +0100 Subject: Move clean_room_for_join to master --- synapse/handlers/federation.py | 16 ++++++++++++++-- synapse/replication/http/federation.py | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 37d2307d0a..acabca1d25 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -50,6 +50,7 @@ from synapse.crypto.event_signing import ( ) from synapse.events.validator import EventValidator from synapse.replication.http.federation import ( + ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet @@ -104,6 +105,9 @@ class FederationHandler(BaseHandler): self._notify_user_membership_change = ( ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) ) + self._clean_room_for_join_client = ( + ReplicationCleanRoomRestServlet.make_client(hs) + ) # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -2388,8 +2392,16 @@ class FederationHandler(BaseHandler): ) def _clean_room_for_join(self, room_id): - # TODO move this out to master - return self.store.clean_room_for_join(room_id) + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Args: + room_id (str) + """ + if self.config.worker_app: + return self._clean_room_for_join_client(room_id) + else: + return self.store.clean_room_for_join(room_id) def user_joined_room(self, user, room_id): """Called when a new user has joined the room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 3e6cbbf5a1..7b0b1cd32e 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -256,7 +256,42 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): defer.returnValue((200, result)) +class ReplicationCleanRoomRestServlet(ReplicationEndpoint): + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Request format: + + POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id + + {} + """ + + NAME = "fed_cleanup_room" + PATH_ARGS = ("room_id",) + + def __init__(self, hs): + super(ReplicationCleanRoomRestServlet, self).__init__(hs) + + self.store = hs.get_datastore() + + @staticmethod + def _serialize_payload(room_id, args): + """ + Args: + room_id (str) + """ + return {} + + @defer.inlineCallbacks + def _handle_request(self, request, room_id): + yield self.store.clean_room_for_join(room_id) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) + ReplicationCleanRoomRestServlet(hs).register(http_server) -- cgit 1.5.1 From 69ce057ea613f425d5ef6ace03d0019a8e4fdf49 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 12:26:27 +0100 Subject: block sync if auth checks fail --- synapse/handlers/sync.py | 12 +++++------- tests/handlers/test_sync.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f748d9afb0..776ddca638 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -209,14 +209,12 @@ class SyncHandler(object): Deferred[SyncResult] """ yield self.auth.check_auth_blocking() - - defer.returnValue( - self.response_cache.wrap( - sync_config.request_key, - self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, - ) + res = yield self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, ) + defer.returnValue(res) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 3b1b4d4923..497e4bd933 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,11 +14,14 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.errors import AuthError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.types import UserID + import tests.unittest import tests.utils from tests.utils import setup_test_homeserver -from synapse.handlers.sync import SyncHandler, SyncConfig -from synapse.types import UserID class SyncTestCase(tests.unittest.TestCase): @@ -32,11 +35,15 @@ class SyncTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): sync_config = SyncConfig( - user=UserID("@user","server"), - filter_collection=None, + user=UserID("@user", "server"), + filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, request_key="request_key", device_id="device_id", ) - res = yield self.sync_handler.wait_for_sync_for_user(sync_config) - print res + # Ensure that an exception is not thrown + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.hs.config.hs_disabled = True + + with self.assertRaises(AuthError): + yield self.sync_handler.wait_for_sync_for_user(sync_config) -- cgit 1.5.1 From 09cf13089858902f3cdcb49b9f9bc3d214ba6337 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 17:39:12 +0100 Subject: only block on sync where user is not part of the mau cohort --- synapse/api/auth.py | 13 +++++++++++-- synapse/handlers/sync.py | 7 ++++++- tests/handlers/test_sync.py | 40 +++++++++++++++++++++++++++++++--------- 3 files changed, 48 insertions(+), 12 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..170039fc82 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -775,17 +775,26 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self): + def check_auth_blocking(self, user_id=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag + + Args: + user_id(str): If present, checks for presence against existing MAU cohort """ if self.hs.config.hs_disabled: raise AuthError( 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED ) if self.hs.config.limit_usage_by_mau is True: + # If the user is already part of the MAU cohort + if user_id: + timestamp = yield self.store._user_last_seen_monthly_active(user_id) + if timestamp: + return + # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 776ddca638..d3b26a4106 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -208,7 +208,12 @@ class SyncHandler(object): Returns: Deferred[SyncResult] """ - yield self.auth.check_auth_blocking() + # If the user is not part of the mau group, then check that limits have + # not been exceeded (if not part of the group by this point, almost certain + # auth_blocking will occur) + user_id = sync_config.user.to_string() + yield self.auth.check_auth_blocking(user_id) + res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 497e4bd933..b95a8743a7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from twisted.internet import defer +from synapse.api.errors import AuthError, Codes -from synapse.api.errors import AuthError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.types import UserID @@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase): def setUp(self): self.hs = yield setup_test_homeserver() self.sync_handler = SyncHandler(self.hs) + self.store = self.hs.get_datastore() @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): - sync_config = SyncConfig( - user=UserID("@user", "server"), + + user_id1 = "@user1:server" + user_id2 = "@user2:server" + sync_config = self._generate_sync_config(user_id1) + + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 1 + + # Check that the happy case does not throw errors + yield self.store.upsert_monthly_active_user(user_id1) + yield self.sync_handler.wait_for_sync_for_user(sync_config) + + # Test that global lock works + self.hs.config.hs_disabled = True + with self.assertRaises(AuthError) as e: + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) + + self.hs.config.hs_disabled = False + + sync_config = self._generate_sync_config(user_id2) + + with self.assertRaises(AuthError) as e: + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + + def _generate_sync_config(self, user_id): + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, request_key="request_key", device_id="device_id", ) - # Ensure that an exception is not thrown - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.hs.config.hs_disabled = True - - with self.assertRaises(AuthError): - yield self.sync_handler.wait_for_sync_for_user(sync_config) -- cgit 1.5.1 From b37c4724191995eec4f5bc9d64f176b2a315f777 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 10 Aug 2018 23:50:21 +1000 Subject: Rename async to async_helpers because `async` is a keyword on Python 3.7 (#3678) --- changelog.d/3678.misc | 1 + synapse/app/federation_sender.py | 2 +- synapse/federation/federation_server.py | 8 +- synapse/handlers/device.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/client.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/notifier.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/client/transactions.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/state.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/util/async.py | 415 -------------------------- synapse/util/async_helpers.py | 415 ++++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 2 +- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/snapshot_cache.py | 2 +- synapse/util/logcontext.py | 2 +- tests/storage/test__base.py | 2 +- tests/util/test_linearizer.py | 2 +- tests/util/test_rwlock.py | 2 +- 34 files changed, 450 insertions(+), 449 deletions(-) create mode 100644 changelog.d/3678.misc delete mode 100644 synapse/util/async.py create mode 100644 synapse/util/async_helpers.py (limited to 'synapse/handlers') diff --git a/changelog.d/3678.misc b/changelog.d/3678.misc new file mode 100644 index 0000000000..0d7c8da64a --- /dev/null +++ b/changelog.d/3678.misc @@ -0,0 +1 @@ +Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index e082d45c94..7bbf0ad082 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,7 +40,7 @@ from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b62f687b6..a23136784a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -40,7 +40,7 @@ from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name from synapse.types import get_domain_from_id -from synapse.util import async +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.logutils import log_function @@ -67,8 +67,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler - self._server_linearizer = async.Linearizer("fed_server") - self._transaction_linearizer = async.Linearizer("fed_txn_handler") + self._server_linearizer = Linearizer("fed_server") + self._transaction_linearizer = Linearizer("fed_txn_handler") self.transaction_actions = TransactionActions(self.store) @@ -200,7 +200,7 @@ class FederationServer(FederationBase): event_id, f.getTraceback().rstrip(), ) - yield async.concurrently_execute( + yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2d44f15da3..9e017116a9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import FederationDeniedError from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0dffd44e22..2380d17f4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -52,7 +52,7 @@ from synapse.events.validator import EventValidator from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 40e7580a61..1fb17fd9a5 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bcb093ba3e..01a362360e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.types import RoomAlias, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b2849783ed..a97d43550f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,7 +22,7 @@ from synapse.api.constants import Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from synapse.types import RoomStreamToken -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3732830194..20fc3b0323 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 995460f82a..32108568c6 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from ._base import BaseHandler diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0e16bbe0ee..3526b20d5a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,7 @@ from synapse.api.errors import ( ) from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 828229f5c3..37e41afd61 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.types import ThirdPartyInstanceID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..fb94b5d7d4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -30,7 +30,7 @@ import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.types import RoomID, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..6393a9674b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.types import RoomStreamToken -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/http/client.py b/synapse/http/client.py index 3771e0b3f6..ab4fbf59b2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 762273f59b..44b61e70a4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,7 +43,7 @@ from synapse.api.errors import ( from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint from synapse.util import logcontext -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) diff --git a/synapse/notifier.py b/synapse/notifier.py index e650c3e494..82f391481c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,7 +25,7 @@ from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge from synapse.types import StreamToken -from synapse.util.async import ( +from synapse.util.async_helpers import ( DeferredTimeoutError, ObservableDeferred, add_timeout_to_deferred, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1d14d3639c..8f9a76147f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache from synapse.util.caches.descriptors import cached diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 9d601208fd..bfa6df7b68 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -35,7 +35,7 @@ from synapse.push.presentable_names import ( name_from_member_event, ) from synapse.types import UserID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 00b1b3066e..511e96ab00 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,7 +17,7 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8fb413d825..4c589e05e0 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 27aa0def2f..778ef97337 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,7 +42,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/state.py b/synapse/state.py index e1092b97a9..8b92d4057a 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -28,7 +28,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ce32e8fefd..d4aa192a0a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10dce21cea..9b4e6d6aa8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii diff --git a/synapse/util/async.py b/synapse/util/async.py deleted file mode 100644 index a7094e2fb4..0000000000 --- a/synapse/util/async.py +++ /dev/null @@ -1,415 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import logging -from contextlib import contextmanager - -from six.moves import range - -from twisted.internet import defer -from twisted.internet.defer import CancelledError -from twisted.python import failure - -from synapse.util import Clock, logcontext, unwrapFirstError - -from .logcontext import ( - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) - -logger = logging.getLogger(__name__) - - -class ObservableDeferred(object): - """Wraps a deferred object so that we can add observer deferreds. These - observer deferreds do not affect the callback chain of the original - deferred. - - If consumeErrors is true errors will be captured from the origin deferred. - - Cancelling or otherwise resolving an observer will not affect the original - ObservableDeferred. - - NB that it does not attempt to do anything with logcontexts; in general - you should probably make_deferred_yieldable the deferreds - returned by `observe`, and ensure that the original deferred runs its - callbacks in the sentinel logcontext. - """ - - __slots__ = ["_deferred", "_observers", "_result"] - - def __init__(self, deferred, consumeErrors=False): - object.__setattr__(self, "_deferred", deferred) - object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) - - def callback(r): - object.__setattr__(self, "_result", (True, r)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass - return r - - def errback(f): - object.__setattr__(self, "_result", (False, f)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass - - if consumeErrors: - return None - else: - return f - - deferred.addCallbacks(callback, errback) - - def observe(self): - """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. - """ - if not self._result: - d = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - d.addBoth(remove) - - self._observers.add(d) - return d - else: - success, res = self._result - return res if success else defer.fail(res) - - def observers(self): - return self._observers - - def has_called(self): - return self._result is not None - - def has_succeeded(self): - return self._result is not None and self._result[0] is True - - def get_result(self): - return self._result[1] - - def __getattr__(self, name): - return getattr(self._deferred, name) - - def __setattr__(self, name, value): - setattr(self._deferred, name, value) - - def __repr__(self): - return "" % ( - id(self), self._result, self._deferred, - ) - - -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting - the number of concurrent executions. - - Args: - func (func): Function to execute, should return a deferred. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. - limit (int): Maximum number of conccurent executions. - - Returns: - deferred: Resolved when all function invocations have finished. - """ - it = iter(args) - - @defer.inlineCallbacks - def _concurrently_execute_inner(): - try: - while True: - yield func(next(it)) - except StopIteration: - pass - - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) - - -class Linearizer(object): - """Limits concurrent access to resources based on a key. Useful to ensure - only a few things happen at a time on a given resource. - - Example: - - with (yield limiter.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None, max_count=1, clock=None): - """ - Args: - max_count(int): The maximum number of concurrent accesses - """ - if name is None: - self.name = id(self) - else: - self.name = name - - if not clock: - from twisted.internet import reactor - clock = Clock(reactor) - self._clock = clock - self.max_count = max_count - - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = {} - - @defer.inlineCallbacks - def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) - - # If the number of things executing is greater than the maximum - # then add a deferred to the list of blocked items - # When on of the things currently executing finishes it will callback - # this item so that it can continue executing. - if entry[0] >= self.max_count: - new_defer = defer.Deferred() - entry[1][new_defer] = 1 - - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) - try: - yield make_deferred_yieldable(new_defer) - except Exception as e: - if isinstance(e, CancelledError): - logger.info( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, - ) - else: - logger.warn( - "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, - ) - - # we just have to take ourselves back out of the queue. - del entry[1][new_defer] - raise - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (This needs to happen while we hold the lock, and the context manager's exit - # code must be synchronous, so this is the only sensible place.) - yield self._clock.sleep(0) - - else: - logger.info( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, - ) - entry[0] += 1 - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - - # We've finished executing so check if there are any things - # blocked waiting to execute and start one of them - entry[0] -= 1 - - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) - - # we need to run the next thing in the sentinel context. - with PreserveLoggingContext(): - next_def.callback(None) - elif entry[0] == 0: - # We were the last thing for this key: remove it from the - # map. - del self.key_to_defer[key] - - defer.returnValue(_ctx_manager()) - - -class ReadWriteLock(object): - """A deferred style read write lock. - - Example: - - with (yield read_write_lock.read("test_key")): - # do some work - """ - - # IMPLEMENTATION NOTES - # - # We track the most recent queued reader and writer deferreds (which get - # resolved when they release the lock). - # - # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. - # - # Write: We know its safe to acquire the write lock when both the latest - # writers and readers have been resolved. The new writer replaces the latest - # writer. - - def __init__(self): - # Latest readers queued - self.key_to_current_readers = {} - - # Latest writer queued - self.key_to_current_writer = {} - - @defer.inlineCallbacks - def read(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - curr_readers.add(new_defer) - - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - self.key_to_current_readers.get(key, set()).discard(new_defer) - - defer.returnValue(_ctx_manager()) - - @defer.inlineCallbacks - def write(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) - - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer - - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: - self.key_to_current_writer.pop(key) - - defer.returnValue(_ctx_manager()) - - -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - - -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. - - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. - - Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - - on_timeout_cancel (callable): A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. - - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. - - The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. - """ - timed_out = [False] - - def time_it_out(): - timed_out[0] = True - deferred.cancel() - - delayed_call = reactor.callLater(timeout, time_it_out) - - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) - return value - - deferred.addBoth(convert_cancelled) - - def cancel_timeout(result): - # stop the pending call to cancel the deferred if it's been fired - if delayed_call.active(): - delayed_call.cancel() - return result - - deferred.addBoth(cancel_timeout) - - -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py new file mode 100644 index 0000000000..a7094e2fb4 --- /dev/null +++ b/synapse/util/async_helpers.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import logging +from contextlib import contextmanager + +from six.moves import range + +from twisted.internet import defer +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.util import Clock, logcontext, unwrapFirstError + +from .logcontext import ( + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) + +logger = logging.getLogger(__name__) + + +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. + + Cancelling or otherwise resolving an observer will not affect the original + ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. + """ + + __slots__ = ["_deferred", "_observers", "_result"] + + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", set()) + + def callback(r): + object.__setattr__(self, "_result", (True, r)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().callback(r) + except Exception: + pass + return r + + def errback(f): + object.__setattr__(self, "_result", (False, f)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().errback(f) + except Exception: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) + + def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ + if not self._result: + d = defer.Deferred() + + def remove(r): + self._observers.discard(d) + return r + d.addBoth(remove) + + self._observers.add(d) + return d + else: + success, res = self._result + return res if success else defer.fail(res) + + def observers(self): + return self._observers + + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + + def __getattr__(self, name): + return getattr(self._deferred, name) + + def __setattr__(self, name, value): + setattr(self._deferred, name, value) + + def __repr__(self): + return "" % ( + id(self), self._result, self._deferred, + ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred: Resolved when all function invocations have finished. + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(next(it)) + except StopIteration: + pass + + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) + ], consumeErrors=True)).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few things happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, name=None, max_count=1, clock=None): + """ + Args: + max_count(int): The maximum number of concurrent accesses + """ + if name is None: + self.name = id(self) + else: + self.name = name + + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock + self.max_count = max_count + + # key_to_defer is a map from the key to a 2 element list where + # the first element is the number of things executing, and + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + + # If the number of things executing is greater than the maximum + # then add a deferred to the list of blocked items + # When on of the things currently executing finishes it will callback + # this item so that it can continue executing. + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1][new_defer] = 1 + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + raise + + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + yield self._clock.sleep(0) + + else: + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + logger.info("Releasing linearizer lock %r for key %r", self.name, key) + + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry[0] -= 1 + + if entry[1]: + (next_def, _) = entry[1].popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] + + defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield make_deferred_yieldable(curr_writer) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 861c24809c..187510576a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,7 +25,7 @@ from six import itervalues, string_types from twisted.internet import defer from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a8491b42d5..afb03b2e1b 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache from synapse.util.logcontext import make_deferred_yieldable, run_in_background diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py index d03678b8c8..8318db8d2c 100644 --- a/synapse/util/caches/snapshot_cache.py +++ b/synapse/util/caches/snapshot_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred class SnapshotCache(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 8dcae50b39..07e83fadda 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -526,7 +526,7 @@ _to_ignore = [ "synapse.util.logcontext", "synapse.http.server", "synapse.storage._base", - "synapse.util.async", + "synapse.util.async_helpers", ] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 6d6f00c5c5..52eff2a104 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached from tests import unittest diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4729bd5a0a..7f48a72de8 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -20,7 +20,7 @@ from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError from synapse.util import Clock, logcontext -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from tests import unittest diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 24194e3b25..7cd470be67 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,7 +14,7 @@ # limitations under the License. -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from tests import unittest -- cgit 1.5.1 From 99dd975dae7baaaef2a3b0a92fa51965b121ae34 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 13 Aug 2018 16:47:46 +1000 Subject: Run tests under PostgreSQL (#3423) --- .travis.yml | 10 +++ changelog.d/3423.misc | 1 + synapse/handlers/presence.py | 5 ++ synapse/storage/client_ips.py | 5 ++ tests/__init__.py | 3 + tests/api/test_auth.py | 2 +- tests/api/test_filtering.py | 5 +- tests/crypto/test_keyring.py | 6 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_profile.py | 1 + tests/handlers/test_register.py | 1 + tests/handlers/test_typing.py | 1 + tests/replication/slave/storage/_base.py | 1 + tests/rest/client/v1/test_admin.py | 2 +- tests/rest/client/v1/test_events.py | 1 + tests/rest/client/v1/test_profile.py | 1 + tests/rest/client/v1/test_register.py | 2 +- tests/rest/client/v1/test_rooms.py | 1 + tests/rest/client/v1/test_typing.py | 1 + tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 2 +- tests/server.py | 10 ++- tests/storage/test_appservice.py | 13 ++- tests/storage/test_background_update.py | 4 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_end_to_end_keys.py | 3 +- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_keys.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_presence.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 2 +- tests/storage/test_user_directory.py | 2 +- tests/test_federation.py | 5 +- tests/test_server.py | 2 +- tests/test_visibility.py | 2 +- tests/utils.py | 133 ++++++++++++++++++++++++---- tox.ini | 20 ++++- 49 files changed, 227 insertions(+), 59 deletions(-) create mode 100644 changelog.d/3423.misc (limited to 'synapse/handlers') diff --git a/.travis.yml b/.travis.yml index b34b17af75..318701c9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,9 @@ before_script: - git remote set-branches --add origin develop - git fetch origin develop +services: + - postgresql + matrix: fast_finish: true include: @@ -20,6 +23,9 @@ matrix: - python: 2.7 env: TOX_ENV=py27 + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + - python: 3.6 env: TOX_ENV=py36 @@ -29,6 +35,10 @@ matrix: - python: 3.6 env: TOX_ENV=check-newsfragment + allow_failures: + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + install: - pip install tox diff --git a/changelog.d/3423.misc b/changelog.d/3423.misc new file mode 100644 index 0000000000..51768c6d14 --- /dev/null +++ b/changelog.d/3423.misc @@ -0,0 +1 @@ +The test suite now can run under PostgreSQL. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 20fc3b0323..3671d24f60 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,6 +95,7 @@ class PresenceHandler(object): Args: hs (synapse.server.HomeServer): """ + self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() @@ -230,6 +231,10 @@ class PresenceHandler(object): earlier than they should when synapse is restarted. This affect of this is some spurious presence changes that will self-correct. """ + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 2489527f2c..8fc678fa67 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + def update(): to_update = self._batch_row_update self._batch_row_update = {} diff --git a/tests/__init__.py b/tests/__init__.py index 24006c949e..9d9ca22829 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -15,4 +15,7 @@ from twisted.trial import util +from tests import utils + util.DEFAULT_TIMEOUT_DURATION = 10 +utils.setupdb() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f8e28876bb..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -39,7 +39,7 @@ class AuthTestCase(unittest.TestCase): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.get_datastore = Mock(return_value=self.store) self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1c2d71052c..48b2d3d663 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,7 +46,10 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - handlers=None, http_client=self.mock_http_client, keyring=Mock() + self.addCleanup, + handlers=None, + http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0c6f510d11..8299dc72c8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -58,12 +58,10 @@ class KeyringTestCase(unittest.TestCase): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( - handlers=None, http_client=self.http_client + self.addCleanup, handlers=None, http_client=self.http_client ) keys = self.mock_perspective_server.get_verify_keys() - self.hs.config.perspectives = { - self.mock_perspective_server.server_name: keys - } + self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys} def check_context(self, _, expected): self.assertEquals( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index ede01f8099..56c0f87fb7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -35,7 +35,7 @@ class AuthHandlers(object): class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index d70d645504..56e7acd37c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -34,7 +34,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver() + hs = yield utils.setup_test_homeserver(self.addCleanup) self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 06de9f5eca..ec7355688b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,6 +46,7 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 57ab228455..8dccc6826e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - handlers=None, federation_client=mock.Mock() + self.addCleanup, handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 9268a6fe2b..62dc69003c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,6 +48,7 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, handlers=None, resource_for_federation=Mock(), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index dbec81076f..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -40,6 +40,7 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( + self.addCleanup, handlers=None, http_client=None, expire_access_token=True, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index becfa77bfa..ad58073a14 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -67,6 +67,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.state_handler = Mock() hs = yield setup_test_homeserver( + self.addCleanup, "test", auth=self.auth, clock=self.clock, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c23b6e2cfd..65df116efc 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -54,6 +54,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver( + self.addCleanup, "blue", http_client=None, federation_client=Mock(), diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 67d9ab94e2..1a553fa3f9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -51,7 +51,7 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.config.registration_shared_secret = u"shared" diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0316b74fa1..956f7fc4c4 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -41,6 +41,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 9ba0ffc19f..1eab9c3bdb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase): ) hs = yield setup_test_homeserver( + self.addCleanup, "test", http_client=None, resource_for_client=self.mock_resource, diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6f15d69ecd..4be88b8a39 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -49,7 +49,7 @@ class CreateUserServletTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 00fc796787..9fe0760496 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,6 +50,7 @@ class RoomBase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( + self.addCleanup, "red", http_client=None, clock=self.hs_clock, diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7f1a435e7b..677265edf6 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -44,6 +44,7 @@ class RoomTypingTestCase(RestTestCase): self.auth_user_id = self.user_id hs = yield setup_test_homeserver( + self.addCleanup, "red", clock=self.clock, http_client=None, diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index de33b10a5f..8260c130f8 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -43,7 +43,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9487babac3..b72bd0fb7f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,7 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): login_handler=self.login_handler, ) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index bafc0d1df0..2e1d06c509 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,7 +40,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/server.py b/tests/server.py index 05708be8b9..beb24cf032 100644 --- a/tests/server.py +++ b/tests/server.py @@ -147,12 +147,15 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return d -def setup_test_homeserver(*args, **kwargs): +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(*args, **kwargs).result + d = _sth(cleanup_func, *args, **kwargs).result + + if isinstance(d, Failure): + d.raiseException() # Make the thread pool synchronous. clock = d.get_clock() @@ -189,6 +192,9 @@ def setup_test_homeserver(*args, **kwargs): def start(self): pass + def stop(self): + pass + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def _(res): if isinstance(res, Failure): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fbb25a8844..c893990454 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -108,7 +111,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -392,6 +398,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -409,6 +416,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -432,6 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index b4f6baf441..81403727c5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() # type: synapse.server.HomeServer + hs = yield setup_test_homeserver( + self.addCleanup + ) # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index ea00bbe84c..fa60d949ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -28,7 +28,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver() + self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 63bc42d9e0..aef4dfaf57 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 9a8ba2fcfe..b4510c1c8d 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = DirectoryStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d45c775c2d..8f0aaece40 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 66eb119581..2fdf34fdf6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -22,7 +22,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5e87b4530d..b114c6fb1d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -33,7 +33,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index ad0a55b324..47f4a8ceac 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 22b1072d9f..0a2c859f26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -28,7 +28,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 12c540dfab..b5b58ff660 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -26,7 +26,7 @@ from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) self.store = PresenceStore(None, hs) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 5acbc8be0c..a1f6618bf9 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = ProfileStore(None, hs) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 85ce61e841..c4e9fb72bf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -29,7 +29,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) self.store = hs.get_datastore() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index bd96896bb3..4eda122edc 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -23,7 +23,7 @@ from tests.utils import setup_test_homeserver class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 84d49b55c1..a1ea23b068 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table @@ -57,7 +57,7 @@ class RoomStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = setup_test_homeserver() + hs = setup_test_homeserver(self.addCleanup) # Room events need the full datastore, for persist_event() and # get_room_state() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 0d9908926a..c83ef60062 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -29,7 +29,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) # We can't test the RoomMemberStore on its own without the other event # storage logic diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ed5b41644a..6168c46248 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -33,7 +33,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7a273eab48..b46e0ea7e2 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -29,7 +29,7 @@ BOBBY = "@bobby:a" class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = UserDirectoryStore(None, self.hs) # alice and bob are both in !room_id. bobby is not but shares diff --git a/tests/test_federation.py b/tests/test_federation.py index f40ff29b52..2540604fcc 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -18,7 +18,10 @@ class MessageAcceptTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + http_client=self.http_client, + clock=self.hs_clock, + reactor=self.reactor, ) user_id = UserID("us", "test") diff --git a/tests/test_server.py b/tests/test_server.py index fc396226ea..895e490406 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,7 +16,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor = MemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) def test_handler_for_request(self): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8643d63125..45a78338d6 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 8668b5478f..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import hashlib +import os +import uuid from inspect import getcallargs from mock import Mock, patch @@ -27,23 +30,80 @@ from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import PostgresEngine from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database +from synapse.storage.prepare_database import ( + _get_or_create_schema_state, + _setup_new_database, + prepare_database, +) from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. -# It requires you to have a local postgres database called synapse_test, within -# which ALL TABLES WILL BE DROPPED -USE_POSTGRES_FOR_TESTS = False +USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) + + +def setupdb(): + + # If we're using PostgreSQL, set up the db once + if USE_POSTGRES_FOR_TESTS: + pgconfig = { + "name": "psycopg2", + "args": { + "database": POSTGRES_BASE_DB, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + config = Mock() + config.password_providers = [] + config.database_config = pgconfig + db_engine = create_engine(pgconfig) + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + cur = db_conn.cursor() + _get_or_create_schema_state(cur, db_engine) + _setup_new_database(cur, db_engine) + db_conn.commit() + cur.close() + db_conn.close() + + def _cleanup(): + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + atexit.register(_cleanup) @defer.inlineCallbacks def setup_test_homeserver( - name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs ): - """Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. If no datastore is supplied a - datastore backed by an in-memory sqlite db will be given to the HS. + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. """ if reactor is None: from twisted.internet import reactor @@ -95,9 +155,11 @@ def setup_test_homeserver( kargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + config.database_config = { "name": "psycopg2", - "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, + "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, } else: config.database_config = { @@ -107,6 +169,21 @@ def setup_test_homeserver( db_engine = create_engine(config.database_config) + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if datastore is None and isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + # we need to configure the connection pool to run the on_new_connection # function, so that we can test code that uses custom sqlite functions # (like rank). @@ -125,15 +202,35 @@ def setup_test_homeserver( reactor=reactor, **kargs ) - db_conn = hs.get_db_conn() - # make sure that the database is empty - if isinstance(db_engine, PostgresEngine): - cur = db_conn.cursor() - cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") - rows = cur.fetchall() - for r in rows: - cur.execute("DROP TABLE %s CASCADE" % r[0]) - yield prepare_database(db_conn, db_engine, config) + + # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to + # date db + if not isinstance(db_engine, PostgresEngine): + db_conn = hs.get_db_conn() + yield prepare_database(db_conn, db_engine, config) + db_conn.commit() + db_conn.close() + + else: + # We need to do cleanup on PostgreSQL + def cleanup(): + # Close all the db pools + hs.get_db_pool().close() + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + cur.close() + db_conn.close() + + # Register the cleanup hook + cleanup_func(cleanup) + hs.setup() else: hs = HomeServer( diff --git a/tox.ini b/tox.ini index ed26644bd9..085f438989 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = packaging, py27, py36, pep8, check_isort -[testenv] +[base] deps = coverage Twisted>=15.1 @@ -15,6 +15,15 @@ deps = setenv = PYTHONDONTWRITEBYTECODE = no_byte_code +[testenv] +deps = + {[base]deps} + +setenv = + {[base]setenv} + +passenv = * + commands = /usr/bin/find "{toxinidir}" -name '*.pyc' -delete coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ @@ -46,6 +55,15 @@ commands = # ) usedevelop=true +[testenv:py27-postgres] +usedevelop=true +deps = + {[base]deps} + psycopg2 +setenv = + {[base]setenv} + SYNAPSE_POSTGRES = 1 + [testenv:py36] usedevelop=true commands = -- cgit 1.5.1 From 0d43f991a19840a224d3dac78d79f13d78212ee6 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/api/auth.py | 8 ++++++-- synapse/api/errors.py | 13 +++++++++++-- synapse/config/server.py | 4 ++++ synapse/handlers/register.py | 27 ++++++++++++++------------- tests/api/test_auth.py | 6 +++++- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 + 7 files changed, 45 insertions(+), 22 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..4f028078fa 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,11 +781,15 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED + 403, self.hs.config.hs_disabled_message, + errcode=Codes.HS_DISABLED, + admin_email=self.hs.config.admin_email, ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", + admin_email=self.hs.config.admin_email, + errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc3bed5fcb..d74848159e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,11 +225,20 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + self.admin_email = kwargs.get('admin_email') + self.msg = kwargs.get('msg') + self.errcode = kwargs.get('errcode') + super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + admin_email=self.admin_email, + ) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 3b078d72ca..64a5121a45 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,6 +82,10 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + # Admin email to direct users at should their instance become blocked + # due to resource constraints + self.admin_email = config.get("admin_email", None) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3526b20d5a..ef7222d7b8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield self._check_mau_limits() + + yield self.auth.check_auth_blocking() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() need_register = True try: @@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Do not accept registrations if monthly active user limits exceeded - and limiting is enabled - """ - try: - yield self.auth.check_auth_blocking() - except AuthError as e: - raise RegistrationError(e.code, str(e), e.errcode) + # @defer.inlineCallbacks + # def _s(self): + # """ + # Do not accept registrations if monthly active user limits exceeded + # and limiting is enabled + # """ + # try: + # yield self.auth.check_auth_blocking() + # except AuthError as e: + # raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a65689ba89..e8a1894e65 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d48d40c8dd..35d1bcab3e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 90378326f8..4af81624eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,7 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] + config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.5.1 From ce7de9ae6b74e8e5e89ff442bc29f8cd73328042 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:06:18 +0100 Subject: Revert "support admin_email config and pass through into blocking errors, return AuthError in all cases" This reverts commit 0d43f991a19840a224d3dac78d79f13d78212ee6. --- synapse/api/auth.py | 8 ++------ synapse/api/errors.py | 13 ++----------- synapse/config/server.py | 4 ---- synapse/handlers/register.py | 27 +++++++++++++-------------- tests/api/test_auth.py | 6 +----- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 - 7 files changed, 22 insertions(+), 45 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4f028078fa..9c62ec4374 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,15 +781,11 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, - errcode=Codes.HS_DISABLED, - admin_email=self.hs.config.admin_email, + 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", - admin_email=self.hs.config.admin_email, - errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d74848159e..dc3bed5fcb 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,20 +225,11 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - self.admin_email = kwargs.get('admin_email') - self.msg = kwargs.get('msg') - self.errcode = kwargs.get('errcode') - super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) - - def error_dict(self): - return cs_error( - self.msg, - self.errcode, - admin_email=self.admin_email, - ) + super(AuthError, self).__init__(*args, **kwargs) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 64a5121a45..3b078d72ca 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,10 +82,6 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - # Admin email to direct users at should their instance become blocked - # due to resource constraints - self.admin_email = config.get("admin_email", None) - # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef7222d7b8..3526b20d5a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,8 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -290,7 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -440,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() need_register = True try: @@ -535,13 +534,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - # @defer.inlineCallbacks - # def _s(self): - # """ - # Do not accept registrations if monthly active user limits exceeded - # and limiting is enabled - # """ - # try: - # yield self.auth.check_auth_blocking() - # except AuthError as e: - # raise RegistrationError(e.code, str(e), e.errcode) + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Do not accept registrations if monthly active user limits exceeded + and limiting is enabled + """ + try: + yield self.auth.check_auth_blocking() + except AuthError as e: + raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e8a1894e65..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,11 +455,8 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError) as e: + with self.assertRaises(AuthError): yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) - self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -473,6 +470,5 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 35d1bcab3e..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 4af81624eb..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,7 +139,6 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] - config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.5.1 From f4b49152e27593dd6c863e71479a2ab712c4ada2 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/api/auth.py | 8 ++++++-- synapse/api/errors.py | 13 +++++++++++-- synapse/config/server.py | 4 ++++ synapse/handlers/register.py | 27 ++++++++++++++------------- tests/api/test_auth.py | 6 +++++- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 + 7 files changed, 45 insertions(+), 22 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..4f028078fa 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,11 +781,15 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED + 403, self.hs.config.hs_disabled_message, + errcode=Codes.HS_DISABLED, + admin_email=self.hs.config.admin_email, ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", + admin_email=self.hs.config.admin_email, + errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc3bed5fcb..d74848159e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,11 +225,20 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + self.admin_email = kwargs.get('admin_email') + self.msg = kwargs.get('msg') + self.errcode = kwargs.get('errcode') + super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + admin_email=self.admin_email, + ) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 3b078d72ca..64a5121a45 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,6 +82,10 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + # Admin email to direct users at should their instance become blocked + # due to resource constraints + self.admin_email = config.get("admin_email", None) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3526b20d5a..ef7222d7b8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield self._check_mau_limits() + + yield self.auth.check_auth_blocking() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() need_register = True try: @@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Do not accept registrations if monthly active user limits exceeded - and limiting is enabled - """ - try: - yield self.auth.check_auth_blocking() - except AuthError as e: - raise RegistrationError(e.code, str(e), e.errcode) + # @defer.inlineCallbacks + # def _s(self): + # """ + # Do not accept registrations if monthly active user limits exceeded + # and limiting is enabled + # """ + # try: + # yield self.auth.check_auth_blocking() + # except AuthError as e: + # raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a65689ba89..e8a1894e65 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d48d40c8dd..35d1bcab3e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 90378326f8..4af81624eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,7 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] + config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.5.1 From 99ebaed8e6c5a65872a91f7a50e914ad9e574f7f Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 14:55:55 +0100 Subject: Update register.py remove comments --- synapse/handlers/register.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef7222d7b8..54e3434928 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -535,13 +535,3 @@ class RegistrationHandler(BaseHandler): action="join", ) - # @defer.inlineCallbacks - # def _s(self): - # """ - # Do not accept registrations if monthly active user limits exceeded - # and limiting is enabled - # """ - # try: - # yield self.auth.check_auth_blocking() - # except AuthError as e: - # raise RegistrationError(e.code, str(e), e.errcode) -- cgit 1.5.1 From ed4bc3d2fc7242e27b3cdd36bc6c27c98fac09c8 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 15:04:48 +0100 Subject: fix off by 1s on mau --- synapse/handlers/auth.py | 4 ++-- tests/handlers/test_auth.py | 39 ++++++++++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 14 ++++++++++---- 3 files changed, 50 insertions(+), 7 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7ea8ce9f94..7baaa39447 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -520,7 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - yield self.auth.check_auth_blocking() + yield self.auth.check_auth_blocking(user_id) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -734,7 +734,6 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): - yield self.auth.check_auth_blocking() auth_api = self.hs.get_auth() user_id = None try: @@ -743,6 +742,7 @@ class AuthHandler(BaseHandler): auth_api.validate_macaroon(macaroon, "login", True, user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + yield self.auth.check_auth_blocking(user_id) defer.returnValue(user_id) @defer.inlineCallbacks diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 56c0f87fb7..9ca7b2ee4e 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -124,7 +124,7 @@ class AuthTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_mau_limits_exceeded(self): + def test_mau_limits_exceeded_large(self): self.hs.config.limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -141,6 +141,43 @@ class AuthTestCase(unittest.TestCase): self._get_macaroon().serialize() ) + @defer.inlineCallbacks + def test_mau_limits_parity(self): + self.hs.config.limit_usage_by_mau = True + + # If not in monthly active cohort + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + # If in monthly active cohort + self.hs.get_datastore().user_last_seen_monthly_active = Mock( + return_value=defer.succeed(self.hs.get_clock().time_msec()) + ) + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + yield self.auth_handler.get_access_token_for_user_id('user_a') + self.hs.get_datastore().user_last_seen_monthly_active = Mock( + return_value=defer.succeed(self.hs.get_clock().time_msec()) + ) + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 35d1bcab3e..a821da0750 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart="local_part") + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(RegistrationError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +133,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register_saml2(localpart="local_part") -- cgit 1.5.1 From c74c71128d164ce7aa6b50538ab960ec85f7a8da Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 15:06:24 +0100 Subject: remove blank line --- synapse/handlers/register.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 54e3434928..f03ee1476b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -534,4 +534,3 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) - -- cgit 1.5.1 From 8f9a7eb58de214cd489ba233c381521e9bf79dec Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/handlers/register.py | 1 - tests/handlers/test_register.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 54e3434928..f03ee1476b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -534,4 +534,3 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) - diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a821da0750..6699d25121 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") self.store.get_monthly_active_count = Mock( @@ -133,5 +133,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") -- cgit 1.5.1 From 488ffe6fdb36c7479052096489c683777597c2aa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Aug 2018 14:17:18 +0100 Subject: Use federation handler function rather than duplicate This involves renaming _persist_events to be a public function. --- synapse/handlers/federation.py | 14 +++++------ synapse/replication/http/federation.py | 44 +++------------------------------- 2 files changed, 10 insertions(+), 48 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index acabca1d25..9a37d627ca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1175,7 +1175,7 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - yield self._persist_events([(event, context)]) + yield self.persist_events_and_notify([(event, context)]) defer.returnValue(event) @@ -1206,7 +1206,7 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - yield self._persist_events([(event, context)]) + yield self.persist_events_and_notify([(event, context)]) defer.returnValue(event) @@ -1449,7 +1449,7 @@ class FederationHandler(BaseHandler): event, context ) - yield self._persist_events( + yield self.persist_events_and_notify( [(event, context)], backfilled=backfilled, ) @@ -1487,7 +1487,7 @@ class FederationHandler(BaseHandler): ], consumeErrors=True, )) - yield self._persist_events( + yield self.persist_events_and_notify( [ (ev_info["event"], context) for ev_info, context in zip(event_infos, contexts) @@ -1575,7 +1575,7 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - yield self._persist_events( + yield self.persist_events_and_notify( [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) @@ -1586,7 +1586,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - yield self._persist_events( + yield self.persist_events_and_notify( [(event, new_event_context)], ) @@ -2327,7 +2327,7 @@ class FederationHandler(BaseHandler): raise AuthError(403, "Third party certificate was invalid") @defer.inlineCallbacks - def _persist_events(self, event_and_contexts, backfilled=False): + def persist_events_and_notify(self, event_and_contexts, backfilled=False): """Persists events and tells the notifier/pushers about them, if necessary. diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7b0b1cd32e..2ddd18f73b 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -17,13 +17,10 @@ import logging from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import UserID -from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -55,9 +52,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.clock = hs.get_clock() - self.is_mine_id = hs.is_mine_id - self.notifier = hs.get_notifier() - self.pusher_pool = hs.get_pusherpool() + self.federation_handler = hs.get_handlers().federation_handler @staticmethod @defer.inlineCallbacks @@ -114,45 +109,12 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): len(event_and_contexts), ) - max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled + yield self.federation_handler.persist_events_and_notify( + event_and_contexts, backfilled, ) - if not backfilled: - for event, _ in event_and_contexts: - self._notify_persisted_event(event, max_stream_id) - defer.returnValue((200, {})) - def _notify_persisted_event(self, event, max_stream_id): - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - - # We notify for memberships if its an invite for one of our - # users - if event.internal_metadata.is_outlier(): - if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): - return - - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - elif event.internal_metadata.is_outlier(): - return - - event_stream_id = event.internal_metadata.stream_ordering - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) - - run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id, - ) - class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): """Handles EDUs newly received from federation, including persisting and -- cgit 1.5.1 From 2f78f432c421702e3756d029b3c669db837d8bcd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 15 Aug 2018 17:35:22 +0200 Subject: speed up /members and add at= and membership params (#3568) --- changelog.d/3331.feature | 1 + changelog.d/3567.feature | 1 + changelog.d/3568.feature | 1 + synapse/handlers/message.py | 88 ++++++++++++++++++++++++++++++++++++------ synapse/rest/client/v1/room.py | 32 +++++++++++++-- synapse/storage/events.py | 2 +- synapse/storage/state.py | 66 ++++++++++++++++++++++++++++++- tests/storage/test_state.py | 2 +- 8 files changed, 174 insertions(+), 19 deletions(-) create mode 100644 changelog.d/3331.feature create mode 100644 changelog.d/3567.feature create mode 100644 changelog.d/3568.feature (limited to 'synapse/handlers') diff --git a/changelog.d/3331.feature b/changelog.d/3331.feature new file mode 100644 index 0000000000..e574b9bcc3 --- /dev/null +++ b/changelog.d/3331.feature @@ -0,0 +1 @@ +add support for the include_redundant_members filter param as per MSC1227 diff --git a/changelog.d/3567.feature b/changelog.d/3567.feature new file mode 100644 index 0000000000..c74c1f57a9 --- /dev/null +++ b/changelog.d/3567.feature @@ -0,0 +1 @@ +make the /context API filter & lazy-load aware as per MSC1227 diff --git a/changelog.d/3568.feature b/changelog.d/3568.feature new file mode 100644 index 0000000000..247f02ba4e --- /dev/null +++ b/changelog.d/3568.feature @@ -0,0 +1 @@ +speed up /members API and add `at` and `membership` params as per MSC1227 diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 01a362360e..893c9bcdc4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -25,7 +25,13 @@ from twisted.internet import defer from twisted.internet.defer import succeed from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + NotFoundError, + SynapseError, +) from synapse.api.urls import ConsentURIBuilder from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event @@ -36,6 +42,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -82,28 +89,85 @@ class MessageHandler(object): defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): + def get_state_events( + self, user_id, room_id, types=None, filtered_types=None, + at_token=None, is_guest=False, + ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has - left the room return the state events from when they left. + left the room return the state events from when they left. If an explicit + 'at' parameter is passed, return the state events as of that event, if + visible. Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + at_token(StreamToken|None): the stream token of the at which we are requesting + the stats. If the user is not allowed to view the state as of that + stream token, we raise a 403 SynapseError. If None, returns the current + state based on the current_state_events table. + is_guest(bool): whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] + Raises: + NotFoundError (404) if the at token does not yield an event + + AuthError (403) if the user doesn't have permission to view + members of this room. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + if at_token: + # FIXME this claims to get the state at a stream position, but + # get_recent_events_for_room operates by topo ordering. This therefore + # does not reliably give you the state at the given stream position. + # (https://github.com/matrix-org/synapse/issues/3305) + last_events, _ = yield self.store.get_recent_events_for_room( + room_id, end_token=at_token.room_key, limit=1, + ) - if membership == Membership.JOIN: - room_state = yield self.state.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if not last_events: + raise NotFoundError("Can't find event for token %s" % (at_token, )) + + visible_events = yield filter_events_for_client( + self.store, user_id, last_events, + ) + + event = last_events[0] + if visible_events: + room_state = yield self.store.get_state_for_events( + [event.event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[event.event_id] + else: + raise AuthError( + 403, + "User %s not allowed to view events in room %s at token %s" % ( + user_id, room_id, at_token, + ) + ) + else: + membership, membership_event_id = ( + yield self.auth.check_in_room_or_world_readable( + room_id, user_id, + ) ) - room_state = room_state[membership_event_id] + + if membership == Membership.JOIN: + state_ids = yield self.store.get_filtered_current_state_ids( + room_id, types, filtered_types=filtered_types, + ) + room_state = yield self.store.get_events(state_ids.values()) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fa5989e74e..fcc1091760 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -34,7 +34,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID +from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from .base import ClientV1RestServlet, client_path_patterns @@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) requester = yield self.auth.get_user_by_req(request) - events = yield self.message_handler.get_state_events( + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = StreamToken.from_string(at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. + membership = parse_string(request, "membership") + not_membership = parse_string(request, "not_membership") + + events = yield handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), + at_token=at_token, + types=[(EventTypes.Member, None)], ) chunk = [] for event in events: - if event["type"] != EventTypes.Member: + if ( + (membership and event['content'].get("membership") != membership) or + (not_membership and event['content'].get("membership") == not_membership) + ): continue chunk.append(event) @@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +# deprecated in favour of /members?membership=join? +# except it does custom AS logic and has a simpler return format class JoinedRoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/joined_members$") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 135af54fa9..025a7fb6d9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1911,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore max_depth = max(row[0] for row in rows) if max_depth <= token.topological: - # We need to ensure we don't delete all the events from the datanase + # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) raise SynapseError( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 17b14d464b..754dfa6973 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): _get_current_state_ids_txn, ) + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + Args: + room_id (str) + types (list[(Str, (Str|None))]): List of (type, state_key) tuples + which are used to filter the state fetched. `state_key` may be + None, which matches any `state_key` + filtered_types (list[Str]|None): List of types to apply the above filter to. + Returns: + deferred: dict of (type, state_key) -> event + """ + + include_other_types = False if filtered_types is None else True + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? %s""" + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) + ) + for etype, state_key in types + ] + + if include_other_types: + unique_types = set(filtered_types) + clause_to_args.append( + ( + "AND type <> ? " * len(unique_types), + list(unique_types) + ) + ) + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. + clause_to_args = [("", [])] + for where_clause, where_args in clause_to_args: + args = [room_id] + args.extend(where_args) + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results + + return self.runInteraction( + "get_filtered_current_state_ids", + _get_filtered_current_state_ids_txn, + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - deferred: A list of dicts corresponding to the event_ids given. - The dicts are mappings from (type, state_key) -> state_events + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6168c46248..ebfd969b36 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -148,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filter_types to grab a specific room member + # check we can use filtered_types to grab a specific room member # without filtering out the other event types state = yield self.store.get_state_for_event( e5.event_id, -- cgit 1.5.1 From 3f543dc021dd9456c8ed2da7a1e4769a68c07729 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 16 Aug 2018 10:46:50 +0200 Subject: initial cut at a room summary API (#3574) --- changelog.d/3574.feature | 1 + synapse/handlers/sync.py | 159 ++++++++++++++++++++++++++++++++--- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/_base.py | 7 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 4 +- 6 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 changelog.d/3574.feature (limited to 'synapse/handlers') diff --git a/changelog.d/3574.feature b/changelog.d/3574.feature new file mode 100644 index 0000000000..87ac32df72 --- /dev/null +++ b/changelog.d/3574.feature @@ -0,0 +1 @@ +implement `summary` block in /sync response as per MSC688 diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3b21a04a5d..ac3edf0cc9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "ephemeral", "account_data", "unread_notifications", + "summary", ])): __slots__ = [] @@ -503,10 +504,142 @@ class SyncHandler(object): state = {} defer.returnValue(state) + @defer.inlineCallbacks + def compute_summary(self, room_id, sync_config, batch, state, now_token): + """ Works out a room summary block for this room, summarising the number + of joined members in the room, and providing the 'hero' members if the + room has no name so clients can consistently name rooms. Also adds + state events to 'state' if needed to describe the heroes. + + Args: + room_id(str): + sync_config(synapse.handlers.sync.SyncConfig): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + state(dict): dict of (type, state_key) -> Event as returned by + compute_state_delta + now_token(str): Token of the end of the current batch. + + Returns: + A deferred dict describing the room summary + """ + + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 + last_events, _ = yield self.store.get_recent_event_ids_for_room( + room_id, end_token=now_token.room_key, limit=1, + ) + + if not last_events: + defer.returnValue(None) + return + + last_event = last_events[-1] + state_ids = yield self.store.get_state_ids_for_event( + last_event.event_id, [ + (EventTypes.Member, None), + (EventTypes.Name, ''), + (EventTypes.CanonicalAlias, ''), + ] + ) + + member_ids = { + state_key: event_id + for (t, state_key), event_id in state_ids.iteritems() + if t == EventTypes.Member + } + name_id = state_ids.get((EventTypes.Name, '')) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + + summary = {} + + # FIXME: it feels very heavy to load up every single membership event + # just to calculate the counts. + member_events = yield self.store.get_events(member_ids.values()) + + joined_user_ids = [] + invited_user_ids = [] + + for ev in member_events.values(): + if ev.content.get("membership") == Membership.JOIN: + joined_user_ids.append(ev.state_key) + elif ev.content.get("membership") == Membership.INVITE: + invited_user_ids.append(ev.state_key) + + # TODO: only send these when they change. + summary["m.joined_member_count"] = len(joined_user_ids) + summary["m.invited_member_count"] = len(invited_user_ids) + + if name_id or canonical_alias_id: + defer.returnValue(summary) + + # FIXME: order by stream ordering, not alphabetic + + me = sync_config.user.to_string() + if (joined_user_ids or invited_user_ids): + summary['m.heroes'] = sorted( + [ + user_id + for user_id in (joined_user_ids + invited_user_ids) + if user_id != me + ] + )[0:5] + else: + summary['m.heroes'] = sorted( + [user_id for user_id in member_ids.keys() if user_id != me] + )[0:5] + + if not sync_config.filter_collection.lazy_load_members(): + defer.returnValue(summary) + + # ensure we send membership events for heroes if needed + cache_key = (sync_config.user.to_string(), sync_config.device_id) + cache = self.get_lazy_loaded_members_cache(cache_key) + + # track which members the client should already know about via LL: + # Ones which are already in state... + existing_members = set( + user_id for (typ, user_id) in state.keys() + if typ == EventTypes.Member + ) + + # ...or ones which are in the timeline... + for ev in batch.events: + if ev.type == EventTypes.Member: + existing_members.add(ev.state_key) + + # ...and then ensure any missing ones get included in state. + missing_hero_event_ids = [ + member_ids[hero_id] + for hero_id in summary['m.heroes'] + if ( + cache.get(hero_id) != member_ids[hero_id] and + hero_id not in existing_members + ) + ] + + missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = missing_hero_state.values() + + for s in missing_hero_state: + cache.set(s.state_key, s.event_id) + state[(EventTypes.Member, s.state_key)] = s + + defer.returnValue(summary) + + def get_lazy_loaded_members_cache(self, cache_key): + cache = self.lazy_loaded_members_cache.get(cache_key) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) + self.lazy_loaded_members_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + return cache + @defer.inlineCallbacks def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, full_state): - """ Works out the differnce in state between the start of the timeline + """ Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -520,7 +653,7 @@ class SyncHandler(object): full_state(bool): Whether to force returning the full state. Returns: - A deferred new event dictionary + A deferred dict of (type, state_key) -> Event """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -618,13 +751,7 @@ class SyncHandler(object): if lazy_load_members and not include_redundant_members: cache_key = (sync_config.user.to_string(), sync_config.device_id) - cache = self.lazy_loaded_members_cache.get(cache_key) - if cache is None: - logger.debug("creating LruCache for %r", cache_key) - cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) - self.lazy_loaded_members_cache[cache_key] = cache - else: - logger.debug("found LruCache for %r", cache_key) + cache = self.get_lazy_loaded_members_cache(cache_key) # if it's a new sync sequence, then assume the client has had # amnesia and doesn't want any recent lazy-loaded members @@ -1425,7 +1552,6 @@ class SyncHandler(object): if events == [] and tags is None: return - since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1468,6 +1594,18 @@ class SyncHandler(object): full_state=full_state ) + summary = {} + if ( + sync_config.filter_collection.lazy_load_members() and + ( + any(ev.type == EventTypes.Member for ev in batch.events) or + since_token is None + ) + ): + summary = yield self.compute_summary( + room_id, sync_config, batch, state, now_token + ) + if room_builder.rtype == "joined": unread_notifications = {} room_sync = JoinedSyncResult( @@ -1477,6 +1615,7 @@ class SyncHandler(object): ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + summary=summary, ) if room_sync or always_include: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8aa06faf23..1275baa1ba 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary return result diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 44f37b4c1e..08dffd774f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1150,17 +1150,16 @@ class SQLBaseStore(object): defer.returnValue(retval) def get_user_count_txn(self, txn): - """Get a total number of registerd users in the users list. + """Get a total number of registered users in the users list. Args: txn : Transaction object Returns: - defer.Deferred: resolves to int + int : number of users """ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" txn.execute(sql_count) - count = txn.fetchone()[0] - defer.returnValue(count) + return txn.fetchone()[0] def _simple_search_list(self, table, term, col, retcols, desc="_simple_search_list"): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 754dfa6973..dd03c4168b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -480,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): """ - Get the state dicts corresponding to a list of events + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned @@ -493,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - A deferred dict from event_id -> (type, state_key) -> state_event + A deferred dict from event_id -> (type, state_key) -> event_id """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b9f2b74ac6..4c296d72c0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of events and a token pointing to the start of the returned events. The events returned are in ascending order. @@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. -- cgit 1.5.1 From 762a758feae40ccb5baf0ba4808308e99e73e13f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 16 Aug 2018 15:22:47 +0200 Subject: lazyload aware /messages (#3589) --- changelog.d/3589.feature | 1 + synapse/handlers/pagination.py | 35 ++++++++++++++++++++++++++++++++++- synapse/rest/client/versions.py | 13 ++++++++++++- 3 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 changelog.d/3589.feature (limited to 'synapse/handlers') diff --git a/changelog.d/3589.feature b/changelog.d/3589.feature new file mode 100644 index 0000000000..a8d7124719 --- /dev/null +++ b/changelog.d/3589.feature @@ -0,0 +1 @@ +Add lazy-loading support to /messages as per MSC1227 diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a97d43550f..5170d093e3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -18,7 +18,7 @@ import logging from twisted.internet import defer from twisted.python.failure import Failure -from synapse.api.constants import Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from synapse.types import RoomStreamToken @@ -251,6 +251,33 @@ class PaginationHandler(object): is_peeking=(member_event_id is None), ) + state = None + if event_filter and event_filter.lazy_load_members(): + # TODO: remove redundant members + + types = [ + (EventTypes.Member, state_key) + for state_key in set( + event.sender # FIXME: we also care about invite targets etc. + for event in events + ) + ] + + state_ids = yield self.store.get_state_ids_for_event( + events[0].event_id, types=types, + ) + + if state_ids: + state = yield self.store.get_events(list(state_ids.values())) + + if state: + state = yield filter_events_for_client( + self.store, + user_id, + state.values(), + is_peeking=(member_event_id is None), + ) + time_now = self.clock.time_msec() chunk = { @@ -262,4 +289,10 @@ class PaginationHandler(object): "end": next_token.to_string(), } + if state: + chunk["state"] = [ + serialize_event(e, time_now, as_client_event) + for e in state + ] + defer.returnValue(chunk) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 6ac2987b98..29e62bfcdd 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -27,11 +27,22 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. "r0.0.1", "r0.1.0", "r0.2.0", "r0.3.0", - ] + ], + # as per MSC1497: + "unstable_features": { + "m.lazy_load_members": True, + } }) -- cgit 1.5.1 From 372bf073c1a5ac5e66cca717ce9aa72c43ba9404 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 16 Aug 2018 21:25:16 +0100 Subject: block event creation and room creation on hitting resource limits --- synapse/handlers/message.py | 6 +++++- synapse/handlers/room.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 893c9bcdc4..4d006df63c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -276,10 +276,14 @@ class EventCreationHandler(object): where *hashes* is a map from algorithm to hash. If None, they will be requested from the database. - + Raises: + ResourceLimitError if server is blocked to some resource being + exceeded Returns: Tuple of created event (FrozenEvent), Context """ + yield self.auth.check_auth_blocking(requester.user.to_string()) + builder = self.event_builder_factory.new(event_dict) self.validator.validate_new(builder) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6a17c42238..c3f820b975 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -98,9 +98,13 @@ class RoomCreationHandler(BaseHandler): Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. + ResourceLimitError if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() + self.auth.check_auth_blocking(user_id) + if not self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") -- cgit 1.5.1 From 66f7dc8c87a6c585b4d456c5b2c4999452fc82c2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 17 Aug 2018 00:32:39 +0100 Subject: Fix logcontexts for running pushers First of all, avoid resetting the logcontext before running the pushers, to fix the "Starting db txn 'get_all_updated_receipts' from sentinel context" warning. Instead, give them their own "background process" logcontexts. --- synapse/app/pusher.py | 4 ++-- synapse/handlers/federation.py | 3 +-- synapse/handlers/message.py | 7 ++----- synapse/handlers/receipts.py | 18 ++++++++---------- synapse/push/pusherpool.py | 17 +++++++++++++++-- 5 files changed, 28 insertions(+), 21 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 9295a51d5b..0089a7c64f 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -162,11 +162,11 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( + self.pusher_pool.on_new_notifications( token, token, ) elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( + self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) ) except Exception: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f38b393e4a..3dd107a285 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2386,8 +2386,7 @@ class FederationHandler(BaseHandler): extra_users=extra_users ) - logcontext.run_in_background( - self.pusher_pool.on_new_notifications, + self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 893c9bcdc4..f21d740968 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -774,11 +774,8 @@ class EventCreationHandler(object): event, context=context ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id + self.pusher_pool.on_new_notifications( + event_stream_id, max_stream_id, ) def _notify(): diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index cb905a3903..a6f3181f09 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.types import get_domain_from_id from synapse.util import logcontext -from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler @@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r["room_id"] for r in receipts])) - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) - # Note that the min here shouldn't be relied upon to be accurate. - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + # Note that the min here shouldn't be relied upon to be accurate. + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids, + ) - defer.returnValue(True) + defer.returnValue(True) @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 36bb5bbc65..9f7d5ef217 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push.pusher import PusherFactory from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -122,8 +123,14 @@ class PusherPool: p['app_id'], p['pushkey'], p['user_name'], ) - @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): + run_as_background_process( + "on_new_notifications", + self._on_new_notifications, min_stream_id, max_stream_id, + ) + + @defer.inlineCallbacks + def _on_new_notifications(self, min_stream_id, max_stream_id): try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id @@ -147,8 +154,14 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + run_as_background_process( + "on_new_receipts", + self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, + ) + + @defer.inlineCallbacks + def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive -- cgit 1.5.1 From c334ca67bb89039b3a00b7c9a1ce610e99859653 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 18 Aug 2018 01:08:45 +1000 Subject: Integrate presence from hotfixes (#3694) --- changelog.d/3694.feature | 1 + docs/workers.rst | 8 +++ synapse/app/_base.py | 6 ++- synapse/app/frontend_proxy.py | 39 ++++++++++++++- synapse/app/synchrotron.py | 16 ++++-- synapse/config/server.py | 6 +++ synapse/federation/transaction_queue.py | 4 ++ synapse/handlers/initial_sync.py | 4 ++ synapse/handlers/presence.py | 26 +++++++--- synapse/handlers/sync.py | 3 +- synapse/rest/client/v1/presence.py | 3 +- tests/app/__init__.py | 0 tests/app/test_frontend_proxy.py | 88 +++++++++++++++++++++++++++++++++ tests/rest/client/v1/test_presence.py | 72 +++++++++++++++++++++++++++ tests/rest/client/v2_alpha/test_sync.py | 79 +++++++++++++---------------- tests/unittest.py | 8 ++- tests/utils.py | 7 +-- 17 files changed, 303 insertions(+), 67 deletions(-) create mode 100644 changelog.d/3694.feature create mode 100644 tests/app/__init__.py create mode 100644 tests/app/test_frontend_proxy.py create mode 100644 tests/rest/client/v1/test_presence.py (limited to 'synapse/handlers') diff --git a/changelog.d/3694.feature b/changelog.d/3694.feature new file mode 100644 index 0000000000..916a342ff4 --- /dev/null +++ b/changelog.d/3694.feature @@ -0,0 +1 @@ +Synapse's presence functionality can now be disabled with the "use_presence" configuration option. diff --git a/docs/workers.rst b/docs/workers.rst index ac9efb621f..aec319dd84 100644 --- a/docs/workers.rst +++ b/docs/workers.rst @@ -241,6 +241,14 @@ regular expressions:: ^/_matrix/client/(api/v1|r0|unstable)/keys/upload +If ``use_presence`` is False in the homeserver config, it can also handle REST +endpoints matching the following regular expressions:: + + ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status + +This "stub" presence handler will pass through ``GET`` request but make the +``PUT`` effectively a no-op. + It will proxy any requests it cannot handle to the main synapse instance. It must therefore be configured with the location of the main instance, via the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 391bd14c5c..7c866e246a 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port): logger.info("Metrics now reporting on %s:%d", host, port) -def listen_tcp(bind_addresses, port, factory, backlog=50): +def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): """ Create a TCP socket for a port and several addresses """ @@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50): check_bind_error(e, address, bind_addresses) -def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): +def listen_ssl( + bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 +): """ Create an SSL socket for a port and several addresses """ diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 671fbbcb2a..8d484c1cd4 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_v2_patterns from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.frontend_proxy") +class PresenceStatusStubServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/presence/(?P[^/]*)/status") + + def __init__(self, hs): + super(PresenceStatusStubServlet, self).__init__(hs) + self.http_client = hs.get_simple_http_client() + self.auth = hs.get_auth() + self.main_uri = hs.config.worker_main_http_uri + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } + result = yield self.http_client.get_json( + self.main_uri + request.uri, + headers=headers, + ) + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + yield self.auth.get_user_by_req(request) + defer.returnValue((200, {})) + + class KeyUploadServlet(RestServlet): PATTERNS = client_v2_patterns("/keys/upload(/(?P[^/]+))?$") @@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) KeyUploadServlet(self).register(resource) + + # If presence is disabled, use the stub servlet that does + # not allow sending presence + if not self.config.use_presence: + PresenceStatusStubServlet(self).register(resource) + resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, @@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), + reactor=self.get_reactor() ) logger.info("Synapse client reader now listening on port %d", port) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index e201f18efd..cade09d60e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -114,7 +114,10 @@ class SynchrotronPresence(object): logger.info("Presence process_id is %r", self.process_id) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_user_sync( + user_id, is_syncing, last_sync_ms + ) def mark_as_coming_online(self, user_id): """A user has started syncing. Send a UserSync to the master, unless they @@ -211,10 +214,13 @@ class SynchrotronPresence(object): yield self.notify_from_replication(states, stream_id) def get_currently_syncing_users(self): - return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) - if count > 0 - ] + if self.hs.config.use_presence: + return [ + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + if count > 0 + ] + else: + return set() class SynchrotronTyping(object): diff --git a/synapse/config/server.py b/synapse/config/server.py index a41c48e69c..68a612e594 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -49,6 +49,9 @@ class ServerConfig(Config): # "disable" federation self.send_federation = config.get("send_federation", True) + # Whether to enable user presence. + self.use_presence = config.get("use_presence", True) + # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) @@ -250,6 +253,9 @@ class ServerConfig(Config): # hard limit. soft_file_limit: 0 + # Set to false to disable presence tracking on this homeserver. + use_presence: true + # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f603c8a368..94d7423d01 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -58,6 +58,7 @@ class TransactionQueue(object): """ def __init__(self, hs): + self.hs = hs self.server_name = hs.hostname self.store = hs.get_datastore() @@ -308,6 +309,9 @@ class TransactionQueue(object): Args: states (list(UserPresenceState)) """ + if not self.hs.config.use_presence: + # No-op if presence is disabled. + return # First we queue up the new presence by user ID, so multiple presence # updates in quick successtion are correctly handled diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1fb17fd9a5..e009395207 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): + # If presence is disabled, return an empty list + if not self.hs.config.use_presence: + defer.returnValue([]) + states = yield presence_handler.get_states( [m.user_id for m in room_members], as_event=True, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3671d24f60..ba3856674d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -395,6 +395,10 @@ class PresenceHandler(object): """We've seen the user do something that indicates they're interacting with the app. """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + user_id = user.to_string() bump_active_time_counter.inc() @@ -424,6 +428,11 @@ class PresenceHandler(object): Useful for streams that are not associated with an actual client that is being used by a user. """ + # Override if it should affect the user's presence, if presence is + # disabled. + if not self.hs.config.use_presence: + affect_presence = False + if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 @@ -469,13 +478,16 @@ class PresenceHandler(object): Returns: set(str): A set of user_id strings. """ - syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids + if self.hs.config.use_presence: + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + else: + return set() @defer.inlineCallbacks def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ac3edf0cc9..648debc8aa 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -185,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ class SyncHandler(object): def __init__(self, hs): + self.hs_config = hs.config self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() @@ -860,7 +861,7 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) - if not block_all_presence_data: + if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_users ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a14f0c807e..b5a6d6aebf 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except Exception: raise SynapseError(400, "Unable to parse state") - yield self.presence_handler.set_state(user, state) + if self.hs.config.use_presence: + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) diff --git a/tests/app/__init__.py b/tests/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py new file mode 100644 index 0000000000..76b5090fff --- /dev/null +++ b/tests/app/test_frontend_proxy.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.app.frontend_proxy import FrontendProxyServer + +from tests.unittest import HomeserverTestCase + + +class FrontendProxyTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + http_client=None, homeserverToUse=FrontendProxyServer + ) + + return hs + + def test_listen_http_with_presence_enabled(self): + """ + When presence is on, the stub servlet will not register. + """ + # Presence is on + self.hs.config.use_presence = True + + config = { + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + + # Listen with the config + self.hs._listen_http(config) + + # Grab the resource from the site that was told to listen + self.assertEqual(len(self.reactor.tcpServers), 1) + site = self.reactor.tcpServers[0][1] + self.resource = ( + site.resource.children["_matrix"].children["client"].children["r0"] + ) + + request, channel = self.make_request("PUT", "presence/a/status") + self.render(request) + + # 400 + unrecognised, because nothing is registered + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + def test_listen_http_with_presence_disabled(self): + """ + When presence is on, the stub servlet will register. + """ + # Presence is off + self.hs.config.use_presence = False + + config = { + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + + # Listen with the config + self.hs._listen_http(config) + + # Grab the resource from the site that was told to listen + self.assertEqual(len(self.reactor.tcpServers), 1) + site = self.reactor.tcpServers[0][1] + self.resource = ( + site.resource.children["_matrix"].children["client"].children["r0"] + ) + + request, channel = self.make_request("PUT", "presence/a/status") + self.render(request) + + # 401, because the stub servlet still checks authentication + self.assertEqual(channel.code, 401) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py new file mode 100644 index 0000000000..66c2b68707 --- /dev/null +++ b/tests/rest/client/v1/test_presence.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse.rest.client.v1 import presence +from synapse.types import UserID + +from tests import unittest + + +class PresenceTestCase(unittest.HomeserverTestCase): + """ Tests presence REST API. """ + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [presence.register_servlets] + + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + "red", http_client=None, federation_client=Mock() + ) + + hs.presence_handler = Mock() + + return hs + + def test_put_presence(self): + """ + PUT to the status endpoint with use_presence enabled will call + set_state on the presence handler. + """ + self.hs.config.use_presence = True + + body = {"presence": "here", "status_msg": "beep boop"} + request, channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + + def test_put_presence_disabled(self): + """ + PUT to the status endpoint with use_presence disbled will NOT call + set_state on the presence handler. + """ + self.hs.config.use_presence = False + + body = {"presence": "here", "status_msg": "beep boop"} + request, channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 9f3d8bd1db..560b1fba96 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -13,71 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.types -from synapse.http.server import JsonResource +from mock import Mock + from synapse.rest.client.v2_alpha import sync -from synapse.types import UserID -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock as MemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) - -PATH_PREFIX = "/_matrix/client/v2_alpha" -class FilterTestCase(unittest.TestCase): +class FilterTestCase(unittest.HomeserverTestCase): - USER_ID = "@apple:test" - TO_REGISTER = [sync] + user_id = "@apple:test" + servlets = [sync.register_servlets] - def setUp(self): - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) + def make_homeserver(self, reactor, clock): - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock + hs = self.setup_test_homeserver( + "red", http_client=None, federation_client=Mock() ) + return hs - self.auth = self.hs.get_auth() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None - ) - - self.auth.get_user_by_access_token = get_user_by_access_token - self.auth.get_user_by_req = get_user_by_req + def test_sync_argless(self): + request, channel = self.make_request("GET", "/sync") + self.render(request) - self.store = self.hs.get_datastore() - self.filtering = self.hs.get_filtering() - self.resource = JsonResource(self.hs) + self.assertEqual(channel.code, 200) + self.assertTrue( + set( + [ + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + ] + ).issubset(set(channel.json_body.keys())) + ) - for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.resource) + def test_sync_presence_disabled(self): + """ + When presence is disabled, the key does not appear in /sync. + """ + self.hs.config.use_presence = False - def test_sync_argless(self): - request, channel = make_request("GET", "/_matrix/client/r0/sync") - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", "/sync") + self.render(request) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertTrue( set( [ "next_batch", "rooms", - "presence", "account_data", "to_device", "device_lists", diff --git a/tests/unittest.py b/tests/unittest.py index e6afe3b96d..d852e2465a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,8 @@ import logging from mock import Mock +from canonicaljson import json + import twisted import twisted.logger from twisted.trial import unittest @@ -241,11 +243,15 @@ class HomeserverTestCase(TestCase): method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). - content (bytes): The body of the request. + content (bytes or dict): The body of the request. JSON-encoded, if + a dict. Returns: A synapse.http.site.SynapseRequest. """ + if isinstance(content, dict): + content = json.dumps(content).encode('utf8') + return make_request(method, path, content) def render(self, request): diff --git a/tests/utils.py b/tests/utils.py index 6f8b1de3e7..bb0fc74054 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -93,7 +93,8 @@ def setupdb(): @defer.inlineCallbacks def setup_test_homeserver( - cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, + homeserverToUse=HomeServer, **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -192,7 +193,7 @@ def setup_test_homeserver( config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection if datastore is None: - hs = HomeServer( + hs = homeserverToUse( name, config=config, db_config=config.database_config, @@ -235,7 +236,7 @@ def setup_test_homeserver( hs.setup() else: - hs = HomeServer( + hs = homeserverToUse( name, db_pool=None, datastore=datastore, -- cgit 1.5.1