diff options
28 files changed, 599 insertions, 331 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 23be6c8efa..12abd6cfc4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,86 @@ +Changes in synapse v0.18.1 (2016-10-0) +====================================== + +No changes since v0.18.1-rc1 + + +Changes in synapse v0.18.1-rc1 (2016-09-30) +=========================================== + +Features: + +* Add total_room_count_estimate to ``/publicRooms`` (PR #1133) + + +Changes: + +* Time out typing over federation (PR #1140) +* Restructure LDAP authentication (PR #1153) + + +Bug fixes: + +* Fix 3pid invites when server is already in the room (PR #1136) +* Fix upgrading with SQLite taking lots of CPU for a few days + after upgrade (PR #1144) +* Fix upgrading from very old database versions (PR #1145) +* Fix port script to work with recently added tables (PR #1146) + + +Changes in synapse v0.18.0 (2016-09-19) +======================================= + +The release includes major changes to the state storage database schemas, which +significantly reduce database size. Synapse will attempt to upgrade the current +data in the background. Servers with large SQLite database may experience +degradation of performance while this upgrade is in progress, therefore you may +want to consider migrating to using Postgres before upgrading very large SQLite +databases + + +Changes: + +* Make public room search case insensitive (PR #1127) + + +Bug fixes: + +* Fix and clean up publicRooms pagination (PR #1129) + + +Changes in synapse v0.18.0-rc1 (2016-09-16) +=========================================== + +Features: + +* Add ``only=highlight`` on ``/notifications`` (PR #1081) +* Add server param to /publicRooms (PR #1082) +* Allow clients to ask for the whole of a single state event (PR #1094) +* Add is_direct param to /createRoom (PR #1108) +* Add pagination support to publicRooms (PR #1121) +* Add very basic filter API to /publicRooms (PR #1126) +* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104, + #1111) + + +Changes: + +* Move to storing state_groups_state as deltas, greatly reducing DB size (PR + #1065) +* Reduce amount of state pulled out of the DB during common requests (PR #1069) +* Allow PDF to be rendered from media repo (PR #1071) +* Reindex state_groups_state after pruning (PR #1085) +* Clobber EDUs in send queue (PR #1095) +* Conform better to the CAS protocol specification (PR #1100) +* Limit how often we ask for keys from dead servers (PR #1114) + + +Bug fixes: + +* Fix /notifications API when used with ``from`` param (PR #1080) +* Fix backfill when cannot find an event. (PR #1107) + + Changes in synapse v0.17.3 (2016-09-09) ======================================= diff --git a/res/templates/notif_mail.html b/res/templates/notif_mail.html index 535bea764d..fcdb3109fe 100644 --- a/res/templates/notif_mail.html +++ b/res/templates/notif_mail.html @@ -18,7 +18,9 @@ <div class="summarytext">{{ summary_text }}</div> </td> <td class="logo"> - {% if app_name == "Vector" %} + {% if app_name == "Riot" %} + <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> + {% elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> {% else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 66c61b0198..2cb2eab68b 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -39,6 +39,7 @@ BOOLEAN_COLUMNS = { "event_edges": ["is_state"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], + "public_room_list_stream": ["visibility"], } @@ -71,6 +72,14 @@ APPEND_ONLY_TABLES = [ "event_to_state_groups", "rejections", "event_search", + "presence_stream", + "push_rules_stream", + "current_state_resets", + "ex_outlier_stream", + "cache_invalidation_stream", + "public_room_list_stream", + "state_group_edges", + "stream_ordering_to_exterm", ] diff --git a/synapse/__init__.py b/synapse/__init__.py index b778cd65c9..6dbe8fc7e7 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.17.3" +__version__ = "0.18.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e75fd518be..1b3b55d517 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -653,7 +653,7 @@ class Auth(object): @defer.inlineCallbacks def _get_appservice_user_id(self, request): - app_service = yield self.store.get_app_service_by_token( + app_service = self.store.get_app_service_by_token( get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) @@ -855,13 +855,12 @@ class Auth(object): } defer.returnValue(user_info) - @defer.inlineCallbacks def get_appservice_by_req(self, request): try: token = get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) - service = yield self.store.get_app_service_by_token(token) + service = self.store.get_app_service_by_token(token) if not service: logger.warn("Unrecognised appservice access token: %s" % (token,)) raise AuthError( @@ -870,7 +869,7 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) request.authenticated_entity = service.sender - defer.returnValue(service) + return defer.succeed(service) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." @@ -1002,16 +1001,6 @@ class Auth(object): 403, "You are not allowed to set others state" ) - else: - sender_domain = UserID.from_string( - event.user_id - ).domain - - if sender_domain != event.state_key: - raise AuthError( - 403, - "You are not allowed to set others state" - ) return True diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 06d0320b1a..94e76b1978 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -136,9 +136,7 @@ class FederationClient(FederationBase): sent_edus_counter.inc() - # TODO, add errback, etc. self._transaction_queue.enqueue_edu(edu, key=key) - return defer.succeed(None) @log_function def send_device_messages(self, destination): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e58735294e..4981643166 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -55,8 +55,14 @@ class BaseHandler(object): def ratelimit(self, requester): time_now = self.clock.time() + user_id = requester.user.to_string() + + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders + allowed, time_allowed = self.ratelimiter.send_message( - requester.user.to_string(), time_now, + user_id, time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 88fa0bb2e4..05af54d31b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -59,7 +59,7 @@ class ApplicationServicesHandler(object): Args: current_id(int): The current maximum ID. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() if not services or not self.notify_appservices: return @@ -142,7 +142,7 @@ class ApplicationServicesHandler(object): association can be found. """ room_alias_str = room_alias.to_string() - services = yield self.store.get_app_services() + services = self.store.get_app_services() alias_query_services = [ s for s in services if ( s.is_interested_in_alias(room_alias_str) @@ -177,7 +177,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def get_3pe_protocols(self, only_protocol=None): - services = yield self.store.get_app_services() + services = self.store.get_app_services() protocols = {} # Collect up all the individual protocol responses out of the ASes @@ -224,7 +224,7 @@ class ApplicationServicesHandler(object): list<ApplicationService>: A list of services interested in this event based on the service regex. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( yield s.is_interested(event, self.store) @@ -232,23 +232,21 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) - @defer.inlineCallbacks def _get_services_for_user(self, user_id): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( s.is_interested_in_user(user_id) ) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) - @defer.inlineCallbacks def _get_services_for_3pn(self, protocol): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if s.is_interested_in_protocol(protocol) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): @@ -264,7 +262,7 @@ class ApplicationServicesHandler(object): return # user not found; could be the AS though, so check. - services = yield self.store.get_app_services() + services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] defer.returnValue(len(service_list) == 0) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6986930c0d..6b8de1e7cf 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -31,6 +31,7 @@ import simplejson try: import ldap3 + import ldap3.core.exceptions except ImportError: ldap3 = None pass @@ -58,7 +59,6 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} - self.INVALID_TOKEN_HTTP_STATUS = 401 self.ldap_enabled = hs.config.ldap_enabled if self.ldap_enabled: @@ -148,13 +148,30 @@ class AuthHandler(BaseHandler): creds = session['creds'] # check auth type currently being presented + errordict = {} if 'type' in authdict: - if authdict['type'] not in self.checkers: + login_type = authdict['type'] + if login_type not in self.checkers: raise LoginError(400, "", Codes.UNRECOGNIZED) - result = yield self.checkers[authdict['type']](authdict, clientip) - if result: - creds[authdict['type']] = result - self._save_session(session) + try: + result = yield self.checkers[login_type](authdict, clientip) + if result: + creds[login_type] = result + self._save_session(session) + except LoginError, e: + if login_type == LoginType.EMAIL_IDENTITY: + # riot used to have a bug where it would request a new + # validation token (thus sending a new email) each time it + # got a 401 with a 'flows' field. + # (https://github.com/vector-im/vector-web/issues/2447). + # + # Grandfather in the old behaviour for now to avoid + # breaking old riot deployments. + raise e + + # this step failed. Merge the error dict into the response + # so that the client can have another go. + errordict = e.error_dict() for f in flows: if len(set(f) - set(creds.keys())) == 0: @@ -163,6 +180,7 @@ class AuthHandler(BaseHandler): ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() + ret.update(errordict) defer.returnValue((False, ret, clientdict, session['id'])) @defer.inlineCallbacks @@ -430,37 +448,40 @@ class AuthHandler(BaseHandler): defer.Deferred: (str) canonical_user_id, or None if zero or multiple matches """ - try: - res = yield self._find_user_id_and_pwd_hash(user_id) + res = yield self._find_user_id_and_pwd_hash(user_id) + if res is not None: defer.returnValue(res[0]) - except LoginError: - defer.returnValue(None) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case - insensitively, but will throw if there are multiple inexact matches. + insensitively, but will return None if there are multiple inexact + matches. Returns: tuple: A 2-tuple of `(canonical_user_id, password_hash)` + None: if there is not exactly one match """ user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + + result = None if not user_infos: logger.warn("Attempted to login as %s but they do not exist", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - if len(user_infos) > 1: - if user_id not in user_infos: - logger.warn( - "Attempted to login as %s but it matches more than one user " - "inexactly: %r", - user_id, user_infos.keys() - ) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue((user_id, user_infos[user_id])) + elif len(user_infos) == 1: + # a single match (possibly not exact) + result = user_infos.popitem() + elif user_id in user_infos: + # multiple matches, but one is exact + result = (user_id, user_infos[user_id]) else: - defer.returnValue(user_infos.popitem()) + # multiple matches, none of them exact + logger.warn( + "Attempted to login as %s but it matches more than one user " + "inexactly: %r", + user_id, user_infos.keys() + ) + defer.returnValue(result) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -474,36 +495,185 @@ class AuthHandler(BaseHandler): Returns: (str) the canonical_user_id Raises: - LoginError if the password was incorrect + LoginError if login fails """ valid_ldap = yield self._check_ldap_password(user_id, password) if valid_ldap: defer.returnValue(user_id) - result = yield self._check_local_password(user_id, password) - defer.returnValue(result) + canonical_user_id = yield self._check_local_password(user_id, password) + + if canonical_user_id: + defer.returnValue(canonical_user_id) + + # unknown username or invalid password. We raise a 403 here, but note + # that if we're doing user-interactive login, it turns all LoginErrors + # into a 401 anyway. + raise LoginError( + 403, "Invalid password", + errcode=Codes.FORBIDDEN + ) @defer.inlineCallbacks def _check_local_password(self, user_id, password): """Authenticate a user against the local password database. - user_id is checked case insensitively, but will throw if there are + user_id is checked case insensitively, but will return None if there are multiple inexact matches. Args: user_id (str): complete @user:id Returns: - (str) the canonical_user_id - Raises: - LoginError if the password was incorrect + (str) the canonical_user_id, or None if unknown user / bad password """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = yield self._find_user_id_and_pwd_hash(user_id) + if not lookupres: + defer.returnValue(None) + (user_id, password_hash) = lookupres result = self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(None) defer.returnValue(user_id) + def _ldap_simple_bind(self, server, localpart, password): + """ Attempt a simple bind with the credentials + given by the user against the LDAP server. + + Returns True, LDAP3Connection + if the bind was successful + Returns False, None + if an error occured + """ + + try: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established LDAP connection in simple bind mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in simple bind mode through StartTLS: %s", + conn + ) + + if conn.bind(): + # GOOD: bind okay + logger.debug("LDAP Bind successful in simple bind mode.") + return True, conn + + # BAD: bind failed + logger.info( + "Binding against LDAP failed for '%s' failed: %s", + localpart, conn.result['description'] + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + + def _ldap_authenticated_search(self, server, localpart, password): + """ Attempt to login with the preconfigured bind_dn + and then continue searching and filtering within + the base_dn + + Returns (True, LDAP3Connection) + if a single matching DN within the base was found + that matched the filter expression, and with which + a successful bind was achieved + + The LDAP3Connection returned is the instance that was used to + verify the password not the one using the configured bind_dn. + Returns (False, None) + if an error occured + """ + + try: + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established LDAP connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in search mode through StartTLS: %s", + conn + ) + + if not conn.bind(): + logger.warn( + "Binding against LDAP with `bind_dn` failed: %s", + conn.result['description'] + ) + conn.unbind() + return False, None + + # construct search_filter like (uid=localpart) + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + # combine with the AND expression + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug( + "LDAP search filter: %s", + query + ) + conn.search( + search_base=self.ldap_base, + search_filter=query + ) + + if len(conn.response) == 1: + # GOOD: found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('LDAP search found dn: %s', user_dn) + + # unbind and simple bind with user_dn to verify the password + # Note: do not use rebind(), for some reason it did not verify + # the password for me! + conn.unbind() + return self._ldap_simple_bind(server, localpart, password) + else: + # BAD: found 0 or > 1 results, abort! + if len(conn.response) == 0: + logger.info( + "LDAP search returned no results for '%s'", + localpart + ) + else: + logger.info( + "LDAP search returned too many (%s) results for '%s'", + len(conn.response), localpart + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): """ Attempt to authenticate a user against an LDAP Server @@ -516,106 +686,62 @@ class AuthHandler(BaseHandler): if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - if self.ldap_mode not in LDAPMode.LIST: - raise RuntimeError( - 'Invalid ldap mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) + localpart = UserID.from_string(user_id).localpart try: server = ldap3.Server(self.ldap_uri) logger.debug( - "Attempting ldap connection with %s", + "Attempting LDAP connection with %s", self.ldap_uri ) - localpart = UserID.from_string(user_id).localpart if self.ldap_mode == LDAPMode.SIMPLE: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base + result, conn = self._ldap_simple_bind( + server=server, localpart=localpart, password=password ) - conn = ldap3.Connection(server, bind_dn, password) logger.debug( - "Established ldap connection in simple mode: %s", + 'LDAP authentication method simple bind returned: %s (conn: %s)', + result, conn ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in simple mode through StartTLS: %s", - conn - ) - - conn.bind() - + if not result: + defer.returnValue(False) elif self.ldap_mode == LDAPMode.SEARCH: - # connect with preconfigured credentials and search for local user - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password + result, conn = self._ldap_authenticated_search( + server=server, localpart=localpart, password=password ) logger.debug( - "Established ldap connection in search mode: %s", + 'LDAP auth method authenticated search returned: %s (conn: %s)', + result, conn ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in search mode through StartTLS: %s", - conn + if not result: + defer.returnValue(False) + else: + raise RuntimeError( + 'Invalid LDAP mode specified: {mode}'.format( + mode=self.ldap_mode ) - - conn.bind() - - # find matching dn - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart ) - if self.ldap_filter: - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug("ldap search filter: %s", query) - result = conn.search(self.ldap_base, query) - - if result and len(conn.response) == 1: - # found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('ldap search found dn: %s', user_dn) - - # unbind and reconnect, rebind with found dn - conn.unbind() - conn = ldap3.Connection( - server, - user_dn, - password, - auto_bind=True - ) - else: - # found 0 or > 1 results, abort! - logger.warn( - "ldap search returned unexpected (%d!=1) amount of results", - len(conn.response) - ) - defer.returnValue(False) - logger.info( - "User authenticated against ldap server: %s", - conn - ) + try: + logger.info( + "User authenticated against LDAP server: %s", + conn + ) + except NameError: + logger.warn("Authentication method yielded no LDAP connection, aborting!") + defer.returnValue(False) + + # check if user with user_id exists + if (yield self.check_user_exists(user_id)): + # exists, authentication complete + conn.unbind() + defer.returnValue(True) - # check for existing account, if none exists, create one - if not (yield self.check_user_exists(user_id)): - # query user metadata for account creation + else: + # does not exist, fetch metadata for account creation from + # existing ldap connection query = "({prop}={value})".format( prop=self.ldap_attributes['uid'], value=localpart @@ -626,9 +752,12 @@ class AuthHandler(BaseHandler): filter=query, user_filter=self.ldap_filter ) - logger.debug("ldap registration filter: %s", query) + logger.debug( + "ldap registration filter: %s", + query + ) - result = conn.search( + conn.search( search_base=self.ldap_base, search_filter=query, attributes=[ @@ -651,20 +780,27 @@ class AuthHandler(BaseHandler): # TODO: bind email, set displayname with data from ldap directory logger.info( - "ldap registration successful: %d: %s (%s, %)", + "Registration based on LDAP data was successful: %d: %s (%s, %)", user_id, localpart, name, mail ) + + defer.returnValue(True) else: - logger.warn( - "ldap registration failed: unexpected (%d!=1) amount of results", - len(conn.response) - ) + if len(conn.response) == 0: + logger.warn("LDAP registration failed, no result.") + else: + logger.warn( + "LDAP registration failed, too many results (%s)", + len(conn.response) + ) + defer.returnValue(False) - defer.returnValue(True) + defer.returnValue(False) + except ldap3.core.exceptions.LDAPException as e: logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14352985e2..c00274afc3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) - @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have # non-exclusive locks on the alias (or there are no interested services) - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] @@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - defer.returnValue(True) - return + return defer.succeed(True) elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - defer.returnValue(False) - return + return defer.succeed(False) # either no interested services, or no service with an exclusive lock - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _user_can_delete_alias(self, alias, user_id): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d9ac09078d..87f74dfb8e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname): + def set_displayname(self, target_user, requester, new_displayname, by_admin=False): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index dd75c4fecf..7e119f13b1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,6 @@ import urllib from twisted.internet import defer -import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) @@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler): def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() - service = yield self.store.get_app_service_by_token(as_token) + service = self.store.get_app_service_by_token(as_token) if not service: raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): @@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield identity_handler.bind_threepid(c, user_id) - @defer.inlineCallbacks def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_user(user_id) @@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler - requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, requester, displayname + user, requester, displayname, by_admin=True, ) defer.returnValue((user_id, token)) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cbd26f8f95..a7f533f7be 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -437,7 +437,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = yield self.store.get_app_service_by_user_id( + app_service = self.store.get_app_service_by_user_id( user.to_string() ) if app_service: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b5962f4f5a..1f910ff814 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -788,7 +788,7 @@ class SyncHandler(object): assert since_token - app_service = yield self.store.get_app_service_by_user_id(user_id) + app_service = self.store.get_app_service_by_user_id(user_id) if app_service: rooms = yield self.store.get_app_service_rooms(app_service) joined_room_ids = set(r.room_id for r in rooms) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0548b81c34..08313417b2 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,10 +16,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, -) +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import Measure +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID, get_domain_from_id import logging @@ -35,6 +34,13 @@ logger = logging.getLogger(__name__) RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 60 * 1000 + +# How often to resend typing across federation. +FEDERATION_PING_INTERVAL = 40 * 1000 + + class TypingHandler(object): def __init__(self, hs): self.store = hs.get_datastore() @@ -44,7 +50,10 @@ class TypingHandler(object): self.notifier = hs.get_notifier() self.state = hs.get_state_handler() + self.hs = hs + self.clock = hs.get_clock() + self.wheel_timer = WheelTimer(bucket_size=5000) self.federation = hs.get_replication_layer() @@ -53,7 +62,7 @@ class TypingHandler(object): hs.get_distributor().observe("user_left_room", self.user_left_room) self._member_typing_until = {} # clock time we expect to stop - self._member_typing_timer = {} # deferreds to manage theabove + self._member_last_federation_poke = {} # map room IDs to serial numbers self._room_serials = {} @@ -61,12 +70,41 @@ class TypingHandler(object): # map room IDs to sets of users currently typing self._room_typing = {} - def tearDown(self): - """Cancels all the pending timers. - Normally this shouldn't be needed, but it's required from unit tests - to avoid a "Reactor was unclean" warning.""" - for t in self._member_typing_timer.values(): - self.clock.cancel_call_later(t) + self.clock.looping_call( + self._handle_timeouts, + 5000, + ) + + def _handle_timeouts(self): + logger.info("Checking for typing timeouts") + + now = self.clock.time_msec() + + members = set(self.wheel_timer.fetch(now)) + + for member in members: + if not self.is_typing(member): + # Nothing to do if they're no longer typing + continue + + until = self._member_typing_until.get(member, None) + if not until or until < now: + logger.info("Timing out typing for: %s", member.user_id) + preserve_fn(self._stopped_typing)(member) + continue + + # Check if we need to resend a keep alive over federation for this + # user. + if self.hs.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now: + preserve_fn(self._push_remote)( + member=member, + typing=True + ) + + def is_typing(self, member): + return member.user_id in self._room_typing.get(member.room_id, []) @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): @@ -85,23 +123,17 @@ class TypingHandler(object): "%s has started typing in %s", target_user_id, room_id ) - until = self.clock.time_msec() + timeout member = RoomMember(room_id=room_id, user_id=target_user_id) - was_present = member in self._member_typing_until - - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) + was_present = member.user_id in self._room_typing.get(room_id, set()) - def _cb(): - logger.debug( - "%s has timed out in %s", target_user.to_string(), room_id - ) - self._stopped_typing(member) + now = self.clock.time_msec() + self._member_typing_until[member] = now + timeout - self._member_typing_until[member] = until - self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000.0, _cb + self.wheel_timer.insert( + now=now, + obj=member, + then=now + timeout, ) if was_present: @@ -109,8 +141,7 @@ class TypingHandler(object): defer.returnValue(None) yield self._push_update( - room_id=room_id, - user_id=target_user_id, + member=member, typing=True, ) @@ -133,10 +164,6 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=target_user_id) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] - yield self._stopped_typing(member) @defer.inlineCallbacks @@ -148,57 +175,61 @@ class TypingHandler(object): @defer.inlineCallbacks def _stopped_typing(self, member): - if member not in self._member_typing_until: + if member.user_id not in self._room_typing.get(member.room_id, set()): # No point defer.returnValue(None) + self._member_typing_until.pop(member, None) + self._member_last_federation_poke.pop(member, None) + yield self._push_update( - room_id=member.room_id, - user_id=member.user_id, + member=member, typing=False, ) - del self._member_typing_until[member] - - if member in self._member_typing_timer: - # Don't cancel it - either it already expired, or the real - # stopped_typing() will cancel it - del self._member_typing_timer[member] + @defer.inlineCallbacks + def _push_update(self, member, typing): + if self.hs.is_mine_id(member.user_id): + # Only send updates for changes to our own users. + yield self._push_remote(member, typing) + + self._push_update_local( + member=member, + typing=typing + ) @defer.inlineCallbacks - def _push_update(self, room_id, user_id, typing): - users = yield self.state.get_current_user_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + def _push_remote(self, member, typing): + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - deferreds = [] - for domain in domains: - if domain == self.server_name: - preserve_fn(self._push_update_local)( - room_id=room_id, - user_id=user_id, - typing=typing - ) - else: - deferreds.append(preserve_fn(self.federation.send_edu)( + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( destination=domain, edu_type="m.typing", content={ - "room_id": room_id, - "user_id": user_id, + "room_id": member.room_id, + "user_id": member.user_id, "typing": typing, }, - key=(room_id, user_id), - )) - - yield preserve_context_over_deferred( - defer.DeferredList(deferreds, consumeErrors=True) - ) + key=member, + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] user_id = content["user_id"] + member = RoomMember(user_id=user_id, room_id=room_id) + # Check that the string is a valid user id user = UserID.from_string(user_id) @@ -213,26 +244,32 @@ class TypingHandler(object): domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: + logger.info("Got typing update from %s: %r", user_id, content) + now = self.clock.time_msec() + self._member_typing_until[member] = now + FEDERATION_TIMEOUT + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_TIMEOUT, + ) self._push_update_local( - room_id=room_id, - user_id=user_id, + member=member, typing=content["typing"] ) - def _push_update_local(self, room_id, user_id, typing): - room_set = self._room_typing.setdefault(room_id, set()) + def _push_update_local(self, member, typing): + room_set = self._room_typing.setdefault(member.room_id, set()) if typing: - room_set.add(user_id) + room_set.add(member.user_id) else: - room_set.discard(user_id) + room_set.discard(member.user_id) self._latest_room_serial += 1 - self._room_serials[room_id] = self._latest_room_serial + self._room_serials[member.room_id] = self._latest_room_serial - with PreserveLoggingContext(): - self.notifier.on_new_event( - "typing_key", self._latest_room_serial, rooms=[room_id] - ) + self.notifier.on_new_event( + "typing_key", self._latest_room_serial, rooms=[member.room_id] + ) def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 3046da7aec..b5a76fefac 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request +from synapse.types import create_requester from synapse.util.async import run_on_reactor @@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet): user_json = parse_json_object_from_request(request) access_token = get_access_token_from_request(request) - app_service = yield self.store.get_app_service_by_token( + app_service = self.store.get_app_service_by_token( access_token ) if not app_service: raise SynapseError(403, "Invalid application service token.") - logger.debug("creating user: %s", user_json) + requester = create_requester(app_service.sender) - response = yield self._do_create(user_json) + logger.debug("creating user: %s", user_json) + response = yield self._do_create(requester, user_json) defer.returnValue((200, response)) @@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet): return 403, {} @defer.inlineCallbacks - def _do_create(self, user_json): + def _do_create(self, requester, user_json): yield run_on_reactor() if "localpart" not in user_json: @@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet): handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( + requester=requester, localpart=localpart, displayname=displayname, duration_in_ms=(duration_seconds * 1000), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 20889e4af0..010fbc7c32 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet): yield self.presence_handler.bump_presence_active_time(requester.user) + # Limit timeout to stop people from setting silly typing timeouts. + timeout = min(content.get("timeout", 30000), 120000) + if content["typing"]: yield self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, - timeout=content.get("timeout", 30000), + timeout=timeout, ) else: yield self.typing_handler.stopped_typing( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 58d3cad6a1..8e5577148f 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> <link rel="stylesheet" href="/_matrix/static/client/register/style.css"> <script> -if (window.onAuthDone != undefined) { +if (window.onAuthDone) { window.onAuthDone(); +} else if (window.opener && window.opener.postMessage) { + window.opener.postMessage("authDone", "*"); } </script> </head> diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index a854a87eab..3d5994a580 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore): ) def get_app_services(self): - return defer.succeed(self.services_cache) + return self.services_cache def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. @@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.sender == user_id: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_by_token(self, token): """Get the application service with the given appservice token. @@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.token == token: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_rooms(self, service): """Get a list of RoomsForUser for this application service. @@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore): ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore - as_list = yield self.get_app_services() + as_list = self.get_app_services() services = [] for res in results: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6dc46fa50f..6cf9d1176d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1355,39 +1355,53 @@ class EventsStore(SQLBaseStore): min_stream_id = rows[-1][0] event_ids = [row[1] for row in rows] - events = self._get_events_txn(txn, event_ids) + rows_to_update = [] - rows = [] - for event in events: - try: - event_id = event.event_id - origin_server_ts = event.origin_server_ts - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue + chunks = [ + event_ids[i:i + 100] + for i in xrange(0, len(event_ids), 100) + ] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) - rows.append((origin_server_ts, event_id)) + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) sql = ( "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" ) - for index in range(0, len(rows), INSERT_CLUMP_SIZE): - clump = rows[index:index + INSERT_CLUMP_SIZE] + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index:index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows) + "rows_inserted": rows_inserted + len(rows_to_update) } self._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) - return len(rows) + return len(rows_to_update) result = yield self.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7eb342674c..49abf0ac74 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -307,6 +307,9 @@ class StateStore(SQLBaseStore): def _get_state_groups_from_groups_txn(self, txn, groups, types=None): results = {group: {} for group in groups} + if types is not None: + types = list(set(types)) # deduplicate types list + if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in @@ -375,10 +378,35 @@ class StateStore(SQLBaseStore): # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: - group_tree = [group] next_group = group while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indicies (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + if types: + args.extend(i for typ in types for i in typ) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? %s" % (where_clause,), + args + ) + rows = txn.fetchall() + results[group].update({ + (typ, state_key): event_id + for typ, state_key, event_id in rows + if (typ, state_key) not in results[group] + }) + + # If the lengths match then we must have all the types, + # so no need to go walk further down the tree. + if types is not None and len(results[group]) == len(types): + break + next_group = self._simple_select_one_onecol_txn( txn, table="state_group_edges", @@ -386,28 +414,6 @@ class StateStore(SQLBaseStore): retcol="prev_state_group", allow_none=True, ) - if next_group: - group_tree.append(next_group) - - sql = (""" - SELECT type, state_key, event_id FROM state_groups_state - INNER JOIN ( - SELECT type, state_key, max(state_group) as state_group - FROM state_groups_state - WHERE state_group IN (%s) %s - GROUP BY type, state_key - ) USING (type, state_key, state_group); - """) % (",".join("?" for _ in group_tree), where_clause,) - - args = list(group_tree) - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] return results diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a7de3c7c17..9c9d144690 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from twisted.internet import defer from .. import unittest from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver @@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase): local_part = "someone" display_name = "someone" user_id = "@someone:test" + requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - local_part, display_name, duration_ms) + requester, local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase): local_part = "frank" display_name = "Frank" user_id = "@frank:test" + requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - local_part, display_name, duration_ms) + requester, local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ea1f0f7c33..c3108f5181 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase): from synapse.handlers.typing import RoomMember member = RoomMember(self.room_id, self.u_apple.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._member_typing_timer[member] = ( - self.clock.call_later(1002, lambda: 0) - ) - self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),)) + self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) self.assertEquals(self.event_source.get_current_key(), 0) @@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase): }, }]) - self.clock.advance_time(11) + self.clock.advance_time(16) self.on_new_event.assert_has_calls([ call('typing_key', 2, rooms=[self.room_id]), diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 4a898a034f..44ba9ff58f 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase): ) self.request.args = {} - self.appservice = None - self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: defer.succeed(self.appservice)) - ) + self.registration_handler = Mock() - self.auth_result = (False, None, None, None) - self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None) + self.appservice = Mock(sender="@as:test") + self.datastore = Mock( + get_app_service_by_token=Mock(return_value=self.appservice) ) - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - # do the dance to hook it up to the hs global - self.handlers = Mock( - auth_handler=self.auth_handler, + # do the dance to hook things up to the hs global + handlers = Mock( registration_handler=self.registration_handler, - identity_handler=self.identity_handler, - login_handler=self.login_handler ) self.hs = Mock() - self.hs.hostname = "supergbig~testing~thing.com" - self.hs.get_auth = Mock(return_value=self.auth) - self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.config.enable_registration = True - # init the thing we're testing + self.hs.hostname = "superbig~testing~thing.com" + self.hs.get_datastore = Mock(return_value=self.datastore) + self.hs.get_handlers = Mock(return_value=handlers) self.servlet = CreateUserRestServlet(self.hs) @defer.inlineCallbacks diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 467f253ef6..a269e6f56e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase): # Need another user to make notifications actually work yield self.join(self.room_id, user="@jim:red") - def tearDown(self): - self.hs.get_typing_handler().tearDown() - @defer.inlineCallbacks def test_set_typing(self): (code, _) = yield self.mock_resource.trigger( @@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(self.event_source.get_current_key(), 1) - self.clock.advance_time(31) + self.clock.advance_time(36) self.assertEquals(self.event_source.get_current_key(), 2) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8ac56a1fb2..e9cb416e4b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: defer.succeed(self.appservice)) + side_effect=lambda x: self.appservice) ) self.auth_result = (False, None, None, None) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3e2862daae..f3df8302da 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -71,14 +71,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - @defer.inlineCallbacks def test_retrieve_unknown_service_token(self): - service = yield self.store.get_app_service_by_token("invalid_token") + service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - @defer.inlineCallbacks def test_retrieval_of_service(self): - stored_service = yield self.store.get_app_service_by_token( + stored_service = self.store.get_app_service_by_token( self.as_token ) self.assertEquals(stored_service.token, self.as_token) @@ -97,9 +95,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): [] ) - @defer.inlineCallbacks def test_retrieval_of_all_services(self): - services = yield self.store.get_app_services() + services = self.store.get_app_services() self.assertEquals(len(services), 3) diff --git a/tests/utils.py b/tests/utils.py index 915b934e94..92d470cb48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -220,6 +220,7 @@ class MockClock(object): # list of lists of [absolute_time, callback, expired] in no particular # order self.timers = [] + self.loopers = [] def time(self): return self.now @@ -240,7 +241,7 @@ class MockClock(object): return t def looping_call(self, function, interval): - pass + self.loopers.append([function, interval / 1000., self.now]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -269,6 +270,12 @@ class MockClock(object): else: self.timers.append(t) + for looped in self.loopers: + func, interval, last = looped + if last + interval < self.now: + func() + looped[2] = self.now + def advance_time_msec(self, ms): self.advance_time(ms / 1000.) |