diff options
56 files changed, 734 insertions, 344 deletions
diff --git a/changelog.d/3694.feature b/changelog.d/3694.feature new file mode 100644 index 0000000000..916a342ff4 --- /dev/null +++ b/changelog.d/3694.feature @@ -0,0 +1 @@ +Synapse's presence functionality can now be disabled with the "use_presence" configuration option. diff --git a/changelog.d/3700.bugfix b/changelog.d/3700.bugfix new file mode 100644 index 0000000000..492cce1dfc --- /dev/null +++ b/changelog.d/3700.bugfix @@ -0,0 +1 @@ +Improve HTTP request logging to include all requests \ No newline at end of file diff --git a/changelog.d/3701.bugfix b/changelog.d/3701.bugfix new file mode 100644 index 0000000000..c22de34537 --- /dev/null +++ b/changelog.d/3701.bugfix @@ -0,0 +1 @@ +Avoid timing out requests while we are streaming back the response diff --git a/changelog.d/3703.removal b/changelog.d/3703.removal new file mode 100644 index 0000000000..d13ce5adb4 --- /dev/null +++ b/changelog.d/3703.removal @@ -0,0 +1 @@ +The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst). diff --git a/changelog.d/3707.misc b/changelog.d/3707.misc new file mode 100644 index 0000000000..8123ca6543 --- /dev/null +++ b/changelog.d/3707.misc @@ -0,0 +1 @@ +add new error type ResourceLimit diff --git a/changelog.d/3708.feature b/changelog.d/3708.feature new file mode 100644 index 0000000000..2f146ba62b --- /dev/null +++ b/changelog.d/3708.feature @@ -0,0 +1 @@ +For resource limit blocked users, prevent writing into rooms diff --git a/changelog.d/3709.misc b/changelog.d/3709.misc new file mode 100644 index 0000000000..bbda357d44 --- /dev/null +++ b/changelog.d/3709.misc @@ -0,0 +1 @@ +Logcontexts for replication command handlers diff --git a/changelog.d/3710.bugfix b/changelog.d/3710.bugfix new file mode 100644 index 0000000000..c28e177667 --- /dev/null +++ b/changelog.d/3710.bugfix @@ -0,0 +1 @@ +Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning \ No newline at end of file diff --git a/changelog.d/3712.misc b/changelog.d/3712.misc new file mode 100644 index 0000000000..30f8c2af21 --- /dev/null +++ b/changelog.d/3712.misc @@ -0,0 +1 @@ +Update admin register API documentation to reference a real user ID. diff --git a/docs/admin_api/register_api.rst b/docs/admin_api/register_api.rst index 209cd140fd..16d65c86b3 100644 --- a/docs/admin_api/register_api.rst +++ b/docs/admin_api/register_api.rst @@ -33,7 +33,7 @@ As an example:: < { "access_token": "token_here", - "user_id": "@pepper_roni@test", + "user_id": "@pepper_roni:localhost", "home_server": "test", "device_id": "device_id_here" } diff --git a/docs/workers.rst b/docs/workers.rst index ac9efb621f..aec319dd84 100644 --- a/docs/workers.rst +++ b/docs/workers.rst @@ -241,6 +241,14 @@ regular expressions:: ^/_matrix/client/(api/v1|r0|unstable)/keys/upload +If ``use_presence`` is False in the homeserver config, it can also handle REST +endpoints matching the following regular expressions:: + + ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status + +This "stub" presence handler will pass through ``GET`` request but make the +``PUT`` effectively a no-op. + It will proxy any requests it cannot handle to the main synapse instance. It must therefore be configured with the location of the main instance, via the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3b2a2ab77a..022211e34e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,7 @@ from twisted.internet import defer import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -784,10 +784,11 @@ class Auth(object): MAU cohort """ if self.hs.config.hs_disabled: - raise AuthError( + raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEED, admin_uri=self.hs.config.admin_uri, + limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: # If the user is already part of the MAU cohort @@ -798,8 +799,10 @@ class Auth(object): # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: - raise AuthError( - 403, "Monthly Active User Limits AU Limit Exceeded", + raise ResourceLimitError( + 403, "Monthly Active User Limit Exceeded", + admin_uri=self.hs.config.admin_uri, - errcode=Codes.RESOURCE_LIMIT_EXCEED + errcode=Codes.RESOURCE_LIMIT_EXCEED, + limit_type="monthly_active_user" ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 08f0cb5554..e26001ab12 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -224,15 +224,34 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None): + + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.FORBIDDEN + super(AuthError, self).__init__(*args, **kwargs) + + +class ResourceLimitError(SynapseError): + """ + Any error raised when there is a problem with resource usage. + For instance, the monthly active user limit for the server has been exceeded + """ + def __init__( + self, code, msg, + errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_uri=None, + limit_type=None, + ): self.admin_uri = admin_uri - super(AuthError, self).__init__(code, msg, errcode=errcode) + self.limit_type = limit_type + super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) def error_dict(self): return cs_error( self.msg, self.errcode, admin_uri=self.admin_uri, + limit_type=self.limit_type ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 391bd14c5c..7c866e246a 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port): logger.info("Metrics now reporting on %s:%d", host, port) -def listen_tcp(bind_addresses, port, factory, backlog=50): +def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): """ Create a TCP socket for a port and several addresses """ @@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50): check_bind_error(e, address, bind_addresses) -def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50): +def listen_ssl( + bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 +): """ Create an SSL socket for a port and several addresses """ diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 9a37384fb7..3348a8ec6d 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 7a4310ca18..d59007099b 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(FederationSenderReplicationHandler, self).on_rdata( + yield super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 671fbbcb2a..8d484c1cd4 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_v2_patterns from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.frontend_proxy") +class PresenceStatusStubServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") + + def __init__(self, hs): + super(PresenceStatusStubServlet, self).__init__(hs) + self.http_client = hs.get_simple_http_client() + self.auth = hs.get_auth() + self.main_uri = hs.config.worker_main_http_uri + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = { + "Authorization": auth_headers, + } + result = yield self.http_client.get_json( + self.main_uri + request.uri, + headers=headers, + ) + defer.returnValue((200, result)) + + @defer.inlineCallbacks + def on_PUT(self, request, user_id): + yield self.auth.get_user_by_req(request) + defer.returnValue((200, {})) + + class KeyUploadServlet(RestServlet): PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") @@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) KeyUploadServlet(self).register(resource) + + # If presence is disabled, use the stub servlet that does + # not allow sending presence + if not self.config.use_presence: + PresenceStatusStubServlet(self).register(resource) + resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, @@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), + reactor=self.get_reactor() ) logger.info("Synapse client reader now listening on port %d", port) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a98bb506e5..005921dcf7 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -525,6 +525,7 @@ def run(hs): clock.looping_call( hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 ) + hs.get_datastore().reap_monthly_active_users() @defer.inlineCallbacks def generate_monthly_active_users(): diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 9295a51d5b..a4fc7e91fa 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks @@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( + self.pusher_pool.on_new_notifications( token, token, ) elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( + self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) ) except Exception: diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index e201f18efd..27e1998660 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -114,7 +114,10 @@ class SynchrotronPresence(object): logger.info("Presence process_id is %r", self.process_id) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_user_sync( + user_id, is_syncing, last_sync_ms + ) def mark_as_coming_online(self, user_id): """A user has started syncing. Send a UserSync to the master, unless they @@ -211,10 +214,13 @@ class SynchrotronPresence(object): yield self.notify_from_replication(states, stream_id) def get_currently_syncing_users(self): - return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) - if count > 0 - ] + if self.hs.config.use_presence: + return [ + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + if count > 0 + ] + else: + return set() class SynchrotronTyping(object): @@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index cb78de8834..1388a42b59 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() + @defer.inlineCallbacks def on_rdata(self, stream_name, token, rows): - super(UserDirectoryReplicationHandler, self).on_rdata( + yield super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == "current_state_deltas": diff --git a/synapse/config/server.py b/synapse/config/server.py index 2190f3210a..68a612e594 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -49,6 +49,9 @@ class ServerConfig(Config): # "disable" federation self.send_federation = config.get("send_federation", True) + # Whether to enable user presence. + self.use_presence = config.get("use_presence", True) + # Whether to update the user directory or not. This should be set to # false only if we are updating the user directory in a worker self.update_user_directory = config.get("update_user_directory", True) @@ -81,6 +84,7 @@ class ServerConfig(Config): # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints @@ -249,6 +253,9 @@ class ServerConfig(Config): # hard limit. soft_file_limit: 0 + # Set to false to disable presence tracking on this homeserver. + use_presence: true + # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] @@ -340,6 +347,32 @@ class ServerConfig(Config): # - port: 9000 # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + + + # Homeserver blocking + # + # How to reach the server admin, used in ResourceLimitError + # admin_uri: 'mailto:admin@server.com' + # + # Global block config + # + # hs_disabled: False + # hs_disabled_message: 'Human readable reason for why the HS is blocked' + # hs_disabled_limit_type: 'error code(str), to help clients decode reason' + # + # Monthly Active User Blocking + # + # Enables monthly active user checking + # limit_usage_by_mau: False + # max_mau_value: 50 + # + # Sometimes the server admin will want to ensure certain accounts are + # never blocked by mau checking. These accounts are specified here. + # + # mau_limit_reserved_threepids: + # - medium: 'email' + # address: 'reserved_user@example.com' + """ % locals() def read_arguments(self, args): diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f603c8a368..94d7423d01 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -58,6 +58,7 @@ class TransactionQueue(object): """ def __init__(self, hs): + self.hs = hs self.server_name = hs.hostname self.store = hs.get_datastore() @@ -308,6 +309,9 @@ class TransactionQueue(object): Args: states (list(UserPresenceState)) """ + if not self.hs.config.use_presence: + # No-op if presence is disabled. + return # First we queue up the new presence by user ID, so multiple presence # updates in quick successtion are correctly handled diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f38b393e4a..3dd107a285 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2386,8 +2386,7 @@ class FederationHandler(BaseHandler): extra_users=extra_users ) - logcontext.run_in_background( - self.pusher_pool.on_new_notifications, + self.pusher_pool.on_new_notifications( event_stream_id, max_stream_id, ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1fb17fd9a5..e009395207 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): + # If presence is disabled, return an empty list + if not self.hs.config.use_presence: + defer.returnValue([]) + states = yield presence_handler.get_states( [m.user_id for m in room_members], as_event=True, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 893c9bcdc4..e484061cc0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -276,10 +276,14 @@ class EventCreationHandler(object): where *hashes* is a map from algorithm to hash. If None, they will be requested from the database. - + Raises: + ResourceLimitError if server is blocked to some resource being + exceeded Returns: Tuple of created event (FrozenEvent), Context """ + yield self.auth.check_auth_blocking(requester.user.to_string()) + builder = self.event_builder_factory.new(event_dict) self.validator.validate_new(builder) @@ -774,11 +778,8 @@ class EventCreationHandler(object): event, context=context ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id + self.pusher_pool.on_new_notifications( + event_stream_id, max_stream_id, ) def _notify(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3671d24f60..ba3856674d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -395,6 +395,10 @@ class PresenceHandler(object): """We've seen the user do something that indicates they're interacting with the app. """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + user_id = user.to_string() bump_active_time_counter.inc() @@ -424,6 +428,11 @@ class PresenceHandler(object): Useful for streams that are not associated with an actual client that is being used by a user. """ + # Override if it should affect the user's presence, if presence is + # disabled. + if not self.hs.config.use_presence: + affect_presence = False + if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 @@ -469,13 +478,16 @@ class PresenceHandler(object): Returns: set(str): A set of user_id strings. """ - syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids + if self.hs.config.use_presence: + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + else: + return set() @defer.inlineCallbacks def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index cb905a3903..a6f3181f09 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.types import get_domain_from_id from synapse.util import logcontext -from synapse.util.logcontext import PreserveLoggingContext from ._base import BaseHandler @@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r["room_id"] for r in receipts])) - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) - # Note that the min here shouldn't be relied upon to be accurate. - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + # Note that the min here shouldn't be relied upon to be accurate. + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids, + ) - defer.returnValue(True) + defer.returnValue(True) @logcontext.preserve_fn # caller should not yield on this @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6a17c42238..c3f820b975 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -98,9 +98,13 @@ class RoomCreationHandler(BaseHandler): Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. + ResourceLimitError if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() + self.auth.check_auth_blocking(user_id) + if not self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ac3edf0cc9..648debc8aa 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -185,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ class SyncHandler(object): def __init__(self, hs): + self.hs_config = hs.config self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() @@ -860,7 +861,7 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) - if not block_all_presence_data: + if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_users ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 6dacb31037..2d5c23e673 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso from twisted.internet import defer from twisted.python import failure -from twisted.web import resource, server +from twisted.web import resource from twisted.web.server import NOT_DONE_YET +from twisted.web.static import NoRangeStaticProducer from twisted.web.util import redirectTo import synapse.events @@ -37,10 +38,13 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.request_metrics import requests_counter from synapse.util.caches import intern_dict -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.metrics import Measure +from synapse.util.logcontext import preserve_fn + +if PY3: + from io import BytesIO +else: + from cStringIO import StringIO as BytesIO logger = logging.getLogger(__name__) @@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> def wrap_json_request_handler(h): """Wraps a request handler method with exception handling. - Also adds logging as per wrap_request_handler_with_logging. + Also does the wrapping with request.processing as per wrap_async_request_handler. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. The handler must return a deferred. If the deferred succeeds we assume that a response has been sent. If the deferred fails with a SynapseError we use @@ -108,24 +111,23 @@ def wrap_json_request_handler(h): pretty_print=_request_user_agent_is_curl(request), ) - return wrap_request_handler_with_logging(wrapped_request_handler) + return wrap_async_request_handler(wrapped_request_handler) def wrap_html_request_handler(h): """Wraps a request handler method with exception handling. - Also adds logging as per wrap_request_handler_with_logging. + Also does the wrapping with request.processing as per wrap_async_request_handler. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. """ def wrapped_request_handler(self, request): d = defer.maybeDeferred(h, self, request) d.addErrback(_return_html_error, request) return d - return wrap_request_handler_with_logging(wrapped_request_handler) + return wrap_async_request_handler(wrapped_request_handler) def _return_html_error(f, request): @@ -170,46 +172,26 @@ def _return_html_error(f, request): finish_request(request) -def wrap_request_handler_with_logging(h): - """Wraps a request handler to provide logging and metrics +def wrap_async_request_handler(h): + """Wraps an async request handler so that it calls request.processing. + + This helps ensure that work done by the request handler after the request is completed + is correctly recorded against the request metrics/logs. The handler method must have a signature of "handle_foo(self, request)", - where "self" must have a "clock" attribute (and "request" must be a - SynapseRequest). + where "request" must be a SynapseRequest. - As well as calling `request.processing` (which will log the response and - duration for this request), the wrapped request handler will insert the - request id into the logging context. + The handler may return a deferred, in which case the completion of the request isn't + logged until the deferred completes. """ @defer.inlineCallbacks - def wrapped_request_handler(self, request): - """ - Args: - self: - request (synapse.http.site.SynapseRequest): - """ + def wrapped_async_request_handler(self, request): + with request.processing(): + yield h(self, request) - request_id = request.get_request_id() - with LoggingContext(request_id) as request_context: - request_context.request = request_id - with Measure(self.clock, "wrapped_request_handler"): - # we start the request metrics timer here with an initial stab - # at the servlet name. For most requests that name will be - # JsonResource (or a subclass), and JsonResource._async_render - # will update it once it picks a servlet. - servlet_name = self.__class__.__name__ - with request.processing(servlet_name): - with PreserveLoggingContext(request_context): - d = defer.maybeDeferred(h, self, request) - - # record the arrival of the request *after* - # dispatching to the handler, so that the handler - # can update the servlet name in the request - # metrics - requests_counter.labels(request.method, - request.request_metrics.name).inc() - yield d - return wrapped_request_handler + # we need to preserve_fn here, because the synchronous render method won't yield for + # us (obviously) + return preserve_fn(wrapped_async_request_handler) class HttpServer(object): @@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource): """ This gets called by twisted every time someone sends us a request. """ self._async_render(request) - return server.NOT_DONE_YET + return NOT_DONE_YET @wrap_json_request_handler @defer.inlineCallbacks @@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False, return if pretty_print: - json_bytes = (encode_pretty_printed_json(json_object) + "\n" - ).encode("utf-8") + json_bytes = encode_pretty_printed_json(json_object) + b"\n" else: if canonical_json or synapse.events.USE_FROZEN_DICTS: # canonicaljson already encodes to bytes @@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, if send_cors: set_cors_headers(request) - request.write(json_bytes) - finish_request(request) + # todo: we can almost certainly avoid this copy and encode the json straight into + # the bytesIO, but it would involve faffing around with string->bytes wrappers. + bytes_io = BytesIO(json_bytes) + + producer = NoRangeStaticProducer(request, bytes_io) + producer.start() return NOT_DONE_YET diff --git a/synapse/http/site.py b/synapse/http/site.py index 5fd30a4c2c..f5a8f78406 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import logging import time @@ -19,8 +18,8 @@ import time from twisted.web.server import Request, Site from synapse.http import redact_uri -from synapse.http.request_metrics import RequestMetrics -from synapse.util.logcontext import ContextResourceUsage, LoggingContext +from synapse.http.request_metrics import RequestMetrics, requests_counter +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) @@ -34,25 +33,43 @@ class SynapseRequest(Request): It extends twisted's twisted.web.server.Request, and adds: * Unique request ID + * A log context associated with the request * Redaction of access_token query-params in __repr__ * Logging at start and end * Metrics to record CPU, wallclock and DB time by endpoint. - It provides a method `processing` which should be called by the Resource - which is handling the request, and returns a context manager. + It also provides a method `processing`, which returns a context manager. If this + method is called, the request won't be logged until the context manager is closed; + this is useful for asynchronous request handlers which may go on processing the + request even after the client has disconnected. + Attributes: + logcontext(LoggingContext) : the log context for this request """ def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 + # we can't yet create the logcontext, as we don't know the method. + self.logcontext = None + global _next_request_seq self.request_seq = _next_request_seq _next_request_seq += 1 + # whether an asynchronous request handler has called processing() + self._is_processing = False + + # the time when the asynchronous request handler completed its processing + self._processing_finished_time = None + + # what time we finished sending the response to the client (or the connection + # dropped) + self.finish_time = None + def __repr__(self): # We overwrite this so that we don't log ``access_token`` return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( @@ -74,11 +91,116 @@ class SynapseRequest(Request): return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def render(self, resrc): + # this is called once a Resource has been found to serve the request; in our + # case the Resource in question will normally be a JsonResource. + + # create a LogContext for this request + request_id = self.get_request_id() + logcontext = self.logcontext = LoggingContext(request_id) + logcontext.request = request_id + # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) - return Request.render(self, resrc) + + with PreserveLoggingContext(self.logcontext): + # we start the request metrics timer here with an initial stab + # at the servlet name. For most requests that name will be + # JsonResource (or a subclass), and JsonResource._async_render + # will update it once it picks a servlet. + servlet_name = resrc.__class__.__name__ + self._started_processing(servlet_name) + + Request.render(self, resrc) + + # record the arrival of the request *after* + # dispatching to the handler, so that the handler + # can update the servlet name in the request + # metrics + requests_counter.labels(self.method, + self.request_metrics.name).inc() + + @contextlib.contextmanager + def processing(self): + """Record the fact that we are processing this request. + + Returns a context manager; the correct way to use this is: + + @defer.inlineCallbacks + def handle_request(request): + with request.processing("FooServlet"): + yield really_handle_the_request() + + Once the context manager is closed, the completion of the request will be logged, + and the various metrics will be updated. + """ + if self._is_processing: + raise RuntimeError("Request is already processing") + self._is_processing = True + + try: + yield + except Exception: + # this should already have been caught, and sent back to the client as a 500. + logger.exception("Asynchronous messge handler raised an uncaught exception") + finally: + # the request handler has finished its work and either sent the whole response + # back, or handed over responsibility to a Producer. + + self._processing_finished_time = time.time() + self._is_processing = False + + # if we've already sent the response, log it now; otherwise, we wait for the + # response to be sent. + if self.finish_time is not None: + self._finished_processing() + + def finish(self): + """Called when all response data has been written to this Request. + + Overrides twisted.web.server.Request.finish to record the finish time and do + logging. + """ + self.finish_time = time.time() + Request.finish(self) + if not self._is_processing: + with PreserveLoggingContext(self.logcontext): + self._finished_processing() + + def connectionLost(self, reason): + """Called when the client connection is closed before the response is written. + + Overrides twisted.web.server.Request.connectionLost to record the finish time and + do logging. + """ + self.finish_time = time.time() + Request.connectionLost(self, reason) + + # we only get here if the connection to the client drops before we send + # the response. + # + # It's useful to log it here so that we can get an idea of when + # the client disconnects. + with PreserveLoggingContext(self.logcontext): + logger.warn( + "Error processing request: %s %s", reason.type, reason.value, + ) + + if not self._is_processing: + self._finished_processing() def _started_processing(self, servlet_name): + """Record the fact that we are processing this request. + + This will log the request's arrival. Once the request completes, + be sure to call finished_processing. + + Args: + servlet_name (str): the name of the servlet which will be + processing this request. This is used in the metrics. + + It is possible to update this afterwards by updating + self.request_metrics.name. + """ self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( @@ -94,13 +216,21 @@ class SynapseRequest(Request): ) def _finished_processing(self): - try: - context = LoggingContext.current_context() - usage = context.get_resource_usage() - except Exception: - usage = ContextResourceUsage() + """Log the completion of this request and update the metrics + """ + + usage = self.logcontext.get_resource_usage() + + if self._processing_finished_time is None: + # we completed the request without anything calling processing() + self._processing_finished_time = time.time() - end_time = time.time() + # the time between receiving the request and the request handler finishing + processing_time = self._processing_finished_time - self.start_time + + # the time between the request handler finishing and the response being sent + # to the client (nb may be negative) + response_send_time = self.finish_time - self._processing_finished_time # need to decode as it could be raw utf-8 bytes # from a IDN servname in an auth header @@ -116,22 +246,31 @@ class SynapseRequest(Request): user_agent = self.get_user_agent() if user_agent is not None: user_agent = user_agent.decode("utf-8", "replace") + else: + user_agent = "-" + + code = str(self.code) + if not self.finished: + # we didn't send the full response before we gave up (presumably because + # the connection dropped) + code += "!" self.site.access_logger.info( "%s - %s - {%s}" - " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" + " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", self.getClientIP(), self.site.site_tag, authenticated_entity, - end_time - self.start_time, + processing_time, + response_send_time, usage.ru_utime, usage.ru_stime, usage.db_sched_duration_sec, usage.db_txn_duration_sec, int(usage.db_txn_count), self.sentLength, - self.code, + code, self.method, self.get_redacted_uri(), self.clientproto, @@ -140,38 +279,10 @@ class SynapseRequest(Request): ) try: - self.request_metrics.stop(end_time, self) + self.request_metrics.stop(self.finish_time, self) except Exception as e: logger.warn("Failed to stop metrics: %r", e) - @contextlib.contextmanager - def processing(self, servlet_name): - """Record the fact that we are processing this request. - - Returns a context manager; the correct way to use this is: - - @defer.inlineCallbacks - def handle_request(request): - with request.processing("FooServlet"): - yield really_handle_the_request() - - This will log the request's arrival. Once the context manager is - closed, the completion of the request will be logged, and the various - metrics will be updated. - - Args: - servlet_name (str): the name of the servlet which will be - processing this request. This is used in the metrics. - - It is possible to update this afterwards by updating - self.request_metrics.servlet_name. - """ - # TODO: we should probably just move this into render() and finish(), - # to save having to call a separate method. - self._started_processing(servlet_name) - yield - self._finished_processing() - class XForwardedForRequest(SynapseRequest): def __init__(self, *args, **kw): diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 36bb5bbc65..9f7d5ef217 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push.pusher import PusherFactory from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -122,8 +123,14 @@ class PusherPool: p['app_id'], p['pushkey'], p['user_name'], ) - @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): + run_as_background_process( + "on_new_notifications", + self._on_new_notifications, min_stream_id, max_stream_id, + ) + + @defer.inlineCallbacks + def _on_new_notifications(self, min_stream_id, max_stream_id): try: users_affected = yield self.store.get_push_action_users_in_range( min_stream_id, max_stream_id @@ -147,8 +154,14 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + run_as_background_process( + "on_new_receipts", + self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids, + ) + + @defer.inlineCallbacks + def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 970e94313e..cbe9645817 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -107,7 +107,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ logger.info("Received rdata %s -> %s", stream_name, token) - self.store.process_replication_rows(stream_name, token, rows) + return self.store.process_replication_rows(stream_name, token, rows) def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes @@ -115,7 +115,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ - self.store.process_replication_rows(stream_name, token, []) + return self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f3908df642..327556f6a1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -59,6 +59,12 @@ class Command(object): """ return self.data + def get_logcontext_id(self): + """Get a suitable string for the logcontext when processing this command""" + + # by default, we just use the command name. + return self.NAME + class ServerCommand(Command): """Sent by the server on new connection and includes the server_name. @@ -116,6 +122,9 @@ class RdataCommand(Command): _json_encoder.encode(self.row), )) + def get_logcontext_id(self): + return "RDATA-" + self.stream_name + class PositionCommand(Command): """Sent by the client to tell the client the stream postition without @@ -190,6 +199,9 @@ class ReplicateCommand(Command): def to_line(self): return " ".join((self.stream_name, str(self.token),)) + def get_logcontext_id(self): + return "REPLICATE-" + self.stream_name + class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dec5ac0913..74e892c104 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from .commands import ( @@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_<CMD_NAME> function try: - getattr(self, "on_%s" % (cmd_name,))(cmd) + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), + getattr(self, "on_%s" % (cmd_name,)), + cmd, + ) except Exception: logger.exception("[%s] Failed to handle line: %r", self.id(), line) @@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.name = cmd.data def on_USER_SYNC(self, cmd): - self.streamer.on_user_sync( + return self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, ) @@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in iterkeys(self.streamer.streams_by_name): - self.subscribe_to_stream(stream, token) + deferreds = [ + run_in_background( + self.subscribe_to_stream, + stream, token, + ) + for stream in iterkeys(self.streamer.streams_by_name) + ] + + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) else: - self.subscribe_to_stream(stream_name, token) + return self.subscribe_to_stream(stream_name, token) def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) + return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + return self.streamer.on_remove_pusher( + cmd.app_id, cmd.push_key, cmd.user_id, + ) def on_INVALIDATE_CACHE(self, cmd): - self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): - self.streamer.on_user_ip( + return self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.last_seen, ) @@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - - self.handler.on_rdata(stream_name, cmd.token, rows) + return self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): - self.handler.on_position(cmd.stream_name, cmd.token) + return self.handler.on_position(cmd.stream_name, cmd.token) def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + return self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a14f0c807e..b5a6d6aebf 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except Exception: raise SynapseError(400, "Unable to parse state") - yield self.presence_handler.set_state(user, state) + if self.hs.config.use_presence: + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py index 3439c3c6d4..5e99cffbcb 100644 --- a/synapse/rest/client/v1_only/register.py +++ b/synapse/rest/client/v1_only/register.py @@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet): login_type = register_json["type"] is_application_server = login_type == LoginType.APPLICATION_SERVICE - is_using_shared_secret = login_type == LoginType.SHARED_SECRET - can_register = ( self.enable_registration or is_application_server - or is_using_shared_secret ) if not can_register: raise SynapseError(403, "Registration has been disabled") @@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet): LoginType.PASSWORD: self._do_password, LoginType.EMAIL_IDENTITY: self._do_email_identity, LoginType.APPLICATION_SERVICE: self._do_app_service, - LoginType.SHARED_SECRET: self._do_shared_secret, } session_info = self._get_session_info(request, session) @@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, }) - @defer.inlineCallbacks - def _do_shared_secret(self, request, register_json, session): - assert_params_in_dict(register_json, ["mac", "user", "password"]) - - if not self.hs.config.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") - - user = register_json["user"].encode("utf-8") - password = register_json["password"].encode("utf-8") - admin = register_json.get("admin", None) - - # Its important to check as we use null bytes as HMAC field separators - if b"\x00" in user: - raise SynapseError(400, "Invalid user") - if b"\x00" in password: - raise SynapseError(400, "Invalid password") - - # str() because otherwise hmac complains that 'unicode' does not - # have the buffer interface - got_mac = str(register_json["mac"]) - - want_mac = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), - digestmod=sha1, - ) - want_mac.update(user) - want_mac.update(b"\x00") - want_mac.update(password) - want_mac.update(b"\x00") - want_mac.update(b"admin" if admin else b"notadmin") - want_mac = want_mac.hexdigest() - - if compare_digest(want_mac, got_mac): - handler = self.handlers.registration_handler - user_id, token = yield handler.register( - localpart=user.lower(), - password=password, - admin=bool(admin), - ) - self._remove_session(session) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - }) - else: - raise SynapseError( - 403, "HMAC incorrect", - ) - class CreateUserRestServlet(ClientV1RestServlet): """Handles user creation via a server-to-server interface diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7e417f811e..06f9a75a97 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -96,7 +96,10 @@ class MonthlyActiveUsersStore(SQLBaseStore): # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - query_args = [self.hs.config.max_mau_value] + safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) + # Must be greater than zero for postgres + safe_guard = safe_guard if safe_guard > 0 else 0 + query_args = [safe_guard] base_sql = """ DELETE FROM monthly_active_users diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 32a2b5fc3d..022d81ce3e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,7 +21,7 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.types import UserID from tests import unittest @@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) @@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase): def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) diff --git a/tests/app/__init__.py b/tests/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/app/__init__.py diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py new file mode 100644 index 0000000000..76b5090fff --- /dev/null +++ b/tests/app/test_frontend_proxy.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.app.frontend_proxy import FrontendProxyServer + +from tests.unittest import HomeserverTestCase + + +class FrontendProxyTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + http_client=None, homeserverToUse=FrontendProxyServer + ) + + return hs + + def test_listen_http_with_presence_enabled(self): + """ + When presence is on, the stub servlet will not register. + """ + # Presence is on + self.hs.config.use_presence = True + + config = { + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + + # Listen with the config + self.hs._listen_http(config) + + # Grab the resource from the site that was told to listen + self.assertEqual(len(self.reactor.tcpServers), 1) + site = self.reactor.tcpServers[0][1] + self.resource = ( + site.resource.children["_matrix"].children["client"].children["r0"] + ) + + request, channel = self.make_request("PUT", "presence/a/status") + self.render(request) + + # 400 + unrecognised, because nothing is registered + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + def test_listen_http_with_presence_disabled(self): + """ + When presence is on, the stub servlet will register. + """ + # Presence is off + self.hs.config.use_presence = False + + config = { + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + + # Listen with the config + self.hs._listen_http(config) + + # Grab the resource from the site that was told to listen + self.assertEqual(len(self.reactor.tcpServers), 1) + site = self.reactor.tcpServers[0][1] + self.resource = ( + site.resource.children["_matrix"].children["client"].children["r0"] + ) + + request, channel = self.make_request("PUT", "presence/a/status") + self.render(request) + + # 401, because the stub servlet still checks authentication + self.assertEqual(channel.code, 401) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 3046bd6093..1e39fe0ec2 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer import synapse import synapse.api.errors -from synapse.api.errors import AuthError +from synapse.api.errors import ResourceLimitError from synapse.handlers.auth import AuthHandler from tests import unittest @@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.large_number_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.get_access_token_for_user_id('user_a') self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) @@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.get_access_token_for_user_id('user_a') self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7154816a34..7b4ade3dfb 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.get_or_create_user("requester", 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register(localpart="local_part") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 33d861bd64..a01ab471f5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,7 +14,7 @@ # limitations under the License. from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.types import UserID @@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase): # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) @@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase): sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py new file mode 100644 index 0000000000..66c2b68707 --- /dev/null +++ b/tests/rest/client/v1/test_presence.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse.rest.client.v1 import presence +from synapse.types import UserID + +from tests import unittest + + +class PresenceTestCase(unittest.HomeserverTestCase): + """ Tests presence REST API. """ + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [presence.register_servlets] + + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + "red", http_client=None, federation_client=Mock() + ) + + hs.presence_handler = Mock() + + return hs + + def test_put_presence(self): + """ + PUT to the status endpoint with use_presence enabled will call + set_state on the presence handler. + """ + self.hs.config.use_presence = True + + body = {"presence": "here", "status_msg": "beep boop"} + request, channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + + def test_put_presence_disabled(self): + """ + PUT to the status endpoint with use_presence disbled will NOT call + set_state on the presence handler. + """ + self.hs.config.use_presence = False + + body = {"presence": "here", "status_msg": "beep boop"} + request, channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 4be88b8a39..6b7ff813d5 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -25,7 +25,7 @@ from synapse.rest.client.v1_only.register import register_servlets from synapse.util import Clock from tests import unittest -from tests.server import make_request, setup_test_homeserver +from tests.server import make_request, render, setup_test_homeserver class CreateUserServletTestCase(unittest.TestCase): @@ -77,10 +77,7 @@ class CreateUserServletTestCase(unittest.TestCase): ) request, channel = make_request(b"POST", url, request_data) - request.render(res) - - # Advance the clock because it waits - self.clock.advance(1) + render(request, res, self.clock) self.assertEquals(channel.result["code"], b"200") diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 9f862f9dfa..40dc4ea256 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -23,7 +23,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from tests import unittest -from tests.server import make_request, wait_until_result +from tests.server import make_request, render class RestTestCase(unittest.TestCase): @@ -171,8 +171,7 @@ class RestHelper(object): request, channel = make_request( "POST", path, json.dumps(content).encode('utf8') ) - request.render(self.resource) - wait_until_result(self.hs.get_reactor(), channel) + render(request, self.resource, self.hs.get_reactor()) assert channel.result["code"] == b"200", channel.result self.auth_user_id = temp_id @@ -220,8 +219,7 @@ class RestHelper(object): request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) - request.render(self.resource) - wait_until_result(self.hs.get_reactor(), channel) + render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 8260c130f8..6a886ee3b8 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -24,8 +24,8 @@ from tests import unittest from tests.server import ( ThreadedMemoryReactorClock as MemoryReactorClock, make_request, + render, setup_test_homeserver, - wait_until_result, ) PATH_PREFIX = "/_matrix/client/v2_alpha" @@ -76,8 +76,7 @@ class FilterTestCase(unittest.TestCase): "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) @@ -91,8 +90,7 @@ class FilterTestCase(unittest.TestCase): "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"403") self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) @@ -105,8 +103,7 @@ class FilterTestCase(unittest.TestCase): "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") @@ -121,8 +118,7 @@ class FilterTestCase(unittest.TestCase): request, channel = make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"200") self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) @@ -131,8 +127,7 @@ class FilterTestCase(unittest.TestCase): request, channel = make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"400") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -143,8 +138,7 @@ class FilterTestCase(unittest.TestCase): request, channel = make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"400") @@ -153,7 +147,6 @@ class FilterTestCase(unittest.TestCase): request, channel = make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b72bd0fb7f..1c128e81f5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -11,7 +11,7 @@ from synapse.rest.client.v2_alpha.register import register_servlets from synapse.util import Clock from tests import unittest -from tests.server import make_request, setup_test_homeserver, wait_until_result +from tests.server import make_request, render, setup_test_homeserver class RegisterRestServletTestCase(unittest.TestCase): @@ -72,8 +72,7 @@ class RegisterRestServletTestCase(unittest.TestCase): request, channel = make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { @@ -89,16 +88,14 @@ class RegisterRestServletTestCase(unittest.TestCase): request, channel = make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) request, channel = make_request(b"POST", self.url, request_data) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") @@ -106,8 +103,7 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) request, channel = make_request(b"POST", self.url, request_data) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid username") @@ -126,8 +122,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.device_handler.check_device_registered = Mock(return_value=device_id) request, channel = make_request(b"POST", self.url, request_data) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) det_data = { "user_id": user_id, @@ -149,8 +144,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.register = Mock(return_value=("@user:id", "t")) request, channel = make_request(b"POST", self.url, request_data) - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") @@ -162,8 +156,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.register = Mock(return_value=(user_id, None)) request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) det_data = { "user_id": user_id, @@ -177,8 +170,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.config.allow_guest_access = False request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") - request.render(self.resource) - wait_until_result(self.clock, channel) + render(request, self.resource, self.clock) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 2e1d06c509..560b1fba96 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -13,72 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import synapse.types -from synapse.http.server import JsonResource +from mock import Mock + from synapse.rest.client.v2_alpha import sync -from synapse.types import UserID -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock as MemoryReactorClock, - make_request, - setup_test_homeserver, - wait_until_result, -) - -PATH_PREFIX = "/_matrix/client/v2_alpha" -class FilterTestCase(unittest.TestCase): +class FilterTestCase(unittest.HomeserverTestCase): - USER_ID = "@apple:test" - TO_REGISTER = [sync] + user_id = "@apple:test" + servlets = [sync.register_servlets] - def setUp(self): - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) + def make_homeserver(self, reactor, clock): - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock + hs = self.setup_test_homeserver( + "red", http_client=None, federation_client=Mock() ) + return hs - self.auth = self.hs.get_auth() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None - ) - - self.auth.get_user_by_access_token = get_user_by_access_token - self.auth.get_user_by_req = get_user_by_req + def test_sync_argless(self): + request, channel = self.make_request("GET", "/sync") + self.render(request) - self.store = self.hs.get_datastore() - self.filtering = self.hs.get_filtering() - self.resource = JsonResource(self.hs) + self.assertEqual(channel.code, 200) + self.assertTrue( + set( + [ + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + ] + ).issubset(set(channel.json_body.keys())) + ) - for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.resource) + def test_sync_presence_disabled(self): + """ + When presence is disabled, the key does not appear in /sync. + """ + self.hs.config.use_presence = False - def test_sync_argless(self): - request, channel = make_request("GET", "/_matrix/client/r0/sync") - request.render(self.resource) - wait_until_result(self.clock, channel) + request, channel = self.make_request("GET", "/sync") + self.render(request) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertTrue( set( [ "next_batch", "rooms", - "presence", "account_data", "to_device", "device_lists", diff --git a/tests/server.py b/tests/server.py index beb24cf032..c63b2c3100 100644 --- a/tests/server.py +++ b/tests/server.py @@ -24,6 +24,7 @@ class FakeChannel(object): """ result = attr.ib(default=attr.Factory(dict)) + _producer = None @property def json_body(self): @@ -49,6 +50,15 @@ class FakeChannel(object): self.result["body"] += content + def registerProducer(self, producer, streaming): + self._producer = producer + + def unregisterProducer(self): + if self._producer is None: + return + + self._producer = None + def requestDone(self, _self): self.result["done"] = True @@ -111,14 +121,19 @@ def make_request(method, path, content=b""): return req, channel -def wait_until_result(clock, channel, timeout=100): +def wait_until_result(clock, request, timeout=100): """ - Wait until the channel has a result. + Wait until the request is finished. """ clock.run() x = 0 - while not channel.result: + while not request.finished: + + # If there's a producer, tell it to resume producing so we get content + if request._channel._producer: + request._channel._producer.resumeProducing() + x += 1 if x > timeout: @@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100): def render(request, resource, clock): request.render(resource) - wait_until_result(clock, request._channel) + wait_until_result(clock, request) class ThreadedMemoryReactorClock(MemoryReactorClock): diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 511acbde9b..f2ed866ae7 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -75,6 +75,19 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): active_count = yield self.store.get_monthly_active_count() self.assertEquals(active_count, user_num) + # Test that regalar users are removed from the db + ru_count = 2 + yield self.store.upsert_monthly_active_user("@ru1:server") + yield self.store.upsert_monthly_active_user("@ru2:server") + active_count = yield self.store.get_monthly_active_count() + + self.assertEqual(active_count, user_num + ru_count) + self.hs.config.max_mau_value = user_num + yield self.store.reap_monthly_active_users() + + active_count = yield self.store.get_monthly_active_count() + self.assertEquals(active_count, user_num) + @defer.inlineCallbacks def test_can_insert_and_count_mau(self): count = yield self.store.get_monthly_active_count() diff --git a/tests/test_server.py b/tests/test_server.py index 895e490406..ef74544e93 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,7 +8,7 @@ from synapse.http.server import JsonResource from synapse.util import Clock from tests import unittest -from tests.server import make_request, setup_test_homeserver +from tests.server import make_request, render, setup_test_homeserver class JsonResourceTests(unittest.TestCase): @@ -37,7 +37,7 @@ class JsonResourceTests(unittest.TestCase): ) request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") - request.render(res) + render(request, res, self.reactor) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) @@ -55,7 +55,7 @@ class JsonResourceTests(unittest.TestCase): res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) request, channel = make_request(b"GET", b"/_matrix/foo") - request.render(res) + render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'500') @@ -78,13 +78,8 @@ class JsonResourceTests(unittest.TestCase): res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) request, channel = make_request(b"GET", b"/_matrix/foo") - request.render(res) + render(request, res, self.reactor) - # No error has been raised yet - self.assertTrue("code" not in channel.result) - - # Advance time, now there's an error - self.reactor.advance(1) self.assertEqual(channel.result["code"], b'500') def test_callback_synapseerror(self): @@ -100,7 +95,7 @@ class JsonResourceTests(unittest.TestCase): res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) request, channel = make_request(b"GET", b"/_matrix/foo") - request.render(res) + render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -121,7 +116,7 @@ class JsonResourceTests(unittest.TestCase): res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) request, channel = make_request(b"GET", b"/_matrix/foobar") - request.render(res) + render(request, res, self.reactor) self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.json_body["error"], "Unrecognized request") diff --git a/tests/unittest.py b/tests/unittest.py index e6afe3b96d..d852e2465a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,8 @@ import logging from mock import Mock +from canonicaljson import json + import twisted import twisted.logger from twisted.trial import unittest @@ -241,11 +243,15 @@ class HomeserverTestCase(TestCase): method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). - content (bytes): The body of the request. + content (bytes or dict): The body of the request. JSON-encoded, if + a dict. Returns: A synapse.http.site.SynapseRequest. """ + if isinstance(content, dict): + content = json.dumps(content).encode('utf8') + return make_request(method, path, content) def render(self, request): diff --git a/tests/utils.py b/tests/utils.py index 52326d4f67..bb0fc74054 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -93,7 +93,8 @@ def setupdb(): @defer.inlineCallbacks def setup_test_homeserver( - cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, + homeserverToUse=HomeServer, **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -137,6 +138,7 @@ def setup_test_homeserver( config.limit_usage_by_mau = False config.hs_disabled = False config.hs_disabled_message = "" + config.hs_disabled_limit_type = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] config.admin_uri = None @@ -191,7 +193,7 @@ def setup_test_homeserver( config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection if datastore is None: - hs = HomeServer( + hs = homeserverToUse( name, config=config, db_config=config.database_config, @@ -234,7 +236,7 @@ def setup_test_homeserver( hs.setup() else: - hs = HomeServer( + hs = homeserverToUse( name, db_pool=None, datastore=datastore, |