Diffstat (limited to 'synapse/handlers')
-rw-r--r--   synapse/handlers/_base.py        |   8
-rw-r--r--   synapse/handlers/appservice.py   |  20
-rw-r--r--   synapse/handlers/auth.py         | 328
-rw-r--r--   synapse/handlers/device.py       |   2
-rw-r--r--   synapse/handlers/directory.py    |  11
-rw-r--r--   synapse/handlers/e2e_keys.py     | 132
-rw-r--r--   synapse/handlers/federation.py   |  19
-rw-r--r--   synapse/handlers/initial_sync.py | 443
-rw-r--r--   synapse/handlers/message.py      | 381
-rw-r--r--   synapse/handlers/presence.py     |   4
-rw-r--r--   synapse/handlers/profile.py      |   8
-rw-r--r--   synapse/handlers/register.py     |  11
-rw-r--r--   synapse/handlers/room.py         | 161
-rw-r--r--   synapse/handlers/room_list.py    | 403
-rw-r--r--   synapse/handlers/sync.py         |   2
-rw-r--r--   synapse/handlers/typing.py       | 177
16 files changed, 1232 insertions, 878 deletions
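The largest change in this set, in auth.py below, drops the built-in LDAP support in favour of pluggable password providers: each configured module is instantiated as module(config=config, account_handler=account_handler) and consulted via check_password() before the local password database is tried. As a rough sketch of a provider conforming to that interface — inferred from the hooks visible in this diff (check_password, plus the _AccountHandler proxy's check_user_exists and register), with a purely hypothetical static user map standing in for real provider config:

from twisted.internet import defer


class ExamplePasswordProvider(object):
    """Illustrative password provider matching the interface below.

    Instantiated by AuthHandler as
    module(config=config, account_handler=account_handler);
    check_password() must return a Deferred firing a truthy value
    to accept the login.
    """

    def __init__(self, config, account_handler):
        self.account_handler = account_handler
        # Hypothetical config shape: a static {user_id: password} dict.
        self.users = config.get("users", {})

    @defer.inlineCallbacks
    def check_password(self, user_id, password):
        if self.users.get(user_id) != password:
            defer.returnValue(False)

        # Lazily provision a local account on first successful login,
        # via the _AccountHandler proxy methods added in this diff.
        exists = yield self.account_handler.check_user_exists(user_id)
        if not exists:
            localpart = user_id.split(":", 1)[0][1:]  # '@alice:hs' -> 'alice'
            yield self.account_handler.register(localpart=localpart)

        defer.returnValue(True)

Returning False (rather than raising) lets the new AuthHandler._check_password() fall through to _check_local_password(), matching the provider loop introduced in the diff.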
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e58735294e..4981643166 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -55,8 +55,14 @@ class BaseHandler(object): def ratelimit(self, requester): time_now = self.clock.time() + user_id = requester.user.to_string() + + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders + allowed, time_allowed = self.ratelimiter.send_message( - requester.user.to_string(), time_now, + user_id, time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 88fa0bb2e4..05af54d31b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -59,7 +59,7 @@ class ApplicationServicesHandler(object): Args: current_id(int): The current maximum ID. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() if not services or not self.notify_appservices: return @@ -142,7 +142,7 @@ class ApplicationServicesHandler(object): association can be found. """ room_alias_str = room_alias.to_string() - services = yield self.store.get_app_services() + services = self.store.get_app_services() alias_query_services = [ s for s in services if ( s.is_interested_in_alias(room_alias_str) @@ -177,7 +177,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def get_3pe_protocols(self, only_protocol=None): - services = yield self.store.get_app_services() + services = self.store.get_app_services() protocols = {} # Collect up all the individual protocol responses out of the ASes @@ -224,7 +224,7 @@ class ApplicationServicesHandler(object): list<ApplicationService>: A list of services interested in this event based on the service regex. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( yield s.is_interested(event, self.store) @@ -232,23 +232,21 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) - @defer.inlineCallbacks def _get_services_for_user(self, user_id): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( s.is_interested_in_user(user_id) ) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) - @defer.inlineCallbacks def _get_services_for_3pn(self, protocol): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if s.is_interested_in_protocol(protocol) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): @@ -264,7 +262,7 @@ class ApplicationServicesHandler(object): return # user not found; could be the AS though, so check. 
- services = yield self.store.get_app_services() + services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] defer.returnValue(len(service_list) == 0) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6986930c0d..dc0fe60e1b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -20,7 +20,6 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor -from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -29,12 +28,6 @@ import bcrypt import pymacaroons import simplejson -try: - import ldap3 -except ImportError: - ldap3 = None - pass - import synapse.util.stringutils as stringutils @@ -58,23 +51,15 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} - self.INVALID_TOKEN_HTTP_STATUS = 401 - self.ldap_enabled = hs.config.ldap_enabled - if self.ldap_enabled: - if not ldap3: - raise RuntimeError( - 'Missing ldap3 library. This is required for LDAP Authentication.' - ) - self.ldap_mode = hs.config.ldap_mode - self.ldap_uri = hs.config.ldap_uri - self.ldap_start_tls = hs.config.ldap_start_tls - self.ldap_base = hs.config.ldap_base - self.ldap_attributes = hs.config.ldap_attributes - if self.ldap_mode == LDAPMode.SEARCH: - self.ldap_bind_dn = hs.config.ldap_bind_dn - self.ldap_bind_password = hs.config.ldap_bind_password - self.ldap_filter = hs.config.ldap_filter + account_handler = _AccountHandler( + hs, check_user_exists=self.check_user_exists + ) + + self.password_providers = [ + module(config=config, account_handler=account_handler) + for module, config in hs.config.password_providers + ] self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() @@ -148,13 +133,30 @@ class AuthHandler(BaseHandler): creds = session['creds'] # check auth type currently being presented + errordict = {} if 'type' in authdict: - if authdict['type'] not in self.checkers: + login_type = authdict['type'] + if login_type not in self.checkers: raise LoginError(400, "", Codes.UNRECOGNIZED) - result = yield self.checkers[authdict['type']](authdict, clientip) - if result: - creds[authdict['type']] = result - self._save_session(session) + try: + result = yield self.checkers[login_type](authdict, clientip) + if result: + creds[login_type] = result + self._save_session(session) + except LoginError, e: + if login_type == LoginType.EMAIL_IDENTITY: + # riot used to have a bug where it would request a new + # validation token (thus sending a new email) each time it + # got a 401 with a 'flows' field. + # (https://github.com/vector-im/vector-web/issues/2447). + # + # Grandfather in the old behaviour for now to avoid + # breaking old riot deployments. + raise e + + # this step failed. Merge the error dict into the response + # so that the client can have another go. 
+ errordict = e.error_dict() for f in flows: if len(set(f) - set(creds.keys())) == 0: @@ -163,6 +165,7 @@ class AuthHandler(BaseHandler): ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() + ret.update(errordict) defer.returnValue((False, ret, clientdict, session['id'])) @defer.inlineCallbacks @@ -430,37 +433,40 @@ class AuthHandler(BaseHandler): defer.Deferred: (str) canonical_user_id, or None if zero or multiple matches """ - try: - res = yield self._find_user_id_and_pwd_hash(user_id) + res = yield self._find_user_id_and_pwd_hash(user_id) + if res is not None: defer.returnValue(res[0]) - except LoginError: - defer.returnValue(None) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case - insensitively, but will throw if there are multiple inexact matches. + insensitively, but will return None if there are multiple inexact + matches. Returns: tuple: A 2-tuple of `(canonical_user_id, password_hash)` + None: if there is not exactly one match """ user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + + result = None if not user_infos: logger.warn("Attempted to login as %s but they do not exist", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - if len(user_infos) > 1: - if user_id not in user_infos: - logger.warn( - "Attempted to login as %s but it matches more than one user " - "inexactly: %r", - user_id, user_infos.keys() - ) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue((user_id, user_infos[user_id])) + elif len(user_infos) == 1: + # a single match (possibly not exact) + result = user_infos.popitem() + elif user_id in user_infos: + # multiple matches, but one is exact + result = (user_id, user_infos[user_id]) else: - defer.returnValue(user_infos.popitem()) + # multiple matches, none of them exact + logger.warn( + "Attempted to login as %s but it matches more than one user " + "inexactly: %r", + user_id, user_infos.keys() + ) + defer.returnValue(result) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -474,202 +480,49 @@ class AuthHandler(BaseHandler): Returns: (str) the canonical_user_id Raises: - LoginError if the password was incorrect + LoginError if login fails """ - valid_ldap = yield self._check_ldap_password(user_id, password) - if valid_ldap: - defer.returnValue(user_id) - - result = yield self._check_local_password(user_id, password) - defer.returnValue(result) + for provider in self.password_providers: + is_valid = yield provider.check_password(user_id, password) + if is_valid: + defer.returnValue(user_id) + + canonical_user_id = yield self._check_local_password(user_id, password) + + if canonical_user_id: + defer.returnValue(canonical_user_id) + + # unknown username or invalid password. We raise a 403 here, but note + # that if we're doing user-interactive login, it turns all LoginErrors + # into a 401 anyway. + raise LoginError( + 403, "Invalid password", + errcode=Codes.FORBIDDEN + ) @defer.inlineCallbacks def _check_local_password(self, user_id, password): """Authenticate a user against the local password database. - user_id is checked case insensitively, but will throw if there are + user_id is checked case insensitively, but will return None if there are multiple inexact matches. 
Args: user_id (str): complete @user:id Returns: - (str) the canonical_user_id - Raises: - LoginError if the password was incorrect + (str) the canonical_user_id, or None if unknown user / bad password """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = yield self._find_user_id_and_pwd_hash(user_id) + if not lookupres: + defer.returnValue(None) + (user_id, password_hash) = lookupres result = self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(None) defer.returnValue(user_id) @defer.inlineCallbacks - def _check_ldap_password(self, user_id, password): - """ Attempt to authenticate a user against an LDAP Server - and register an account if none exists. - - Returns: - True if authentication against LDAP was successful - """ - - if not ldap3 or not self.ldap_enabled: - defer.returnValue(False) - - if self.ldap_mode not in LDAPMode.LIST: - raise RuntimeError( - 'Invalid ldap mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) - - try: - server = ldap3.Server(self.ldap_uri) - logger.debug( - "Attempting ldap connection with %s", - self.ldap_uri - ) - - localpart = UserID.from_string(user_id).localpart - if self.ldap_mode == LDAPMode.SIMPLE: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base - ) - conn = ldap3.Connection(server, bind_dn, password) - logger.debug( - "Established ldap connection in simple mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in simple mode through StartTLS: %s", - conn - ) - - conn.bind() - - elif self.ldap_mode == LDAPMode.SEARCH: - # connect with preconfigured credentials and search for local user - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password - ) - logger.debug( - "Established ldap connection in search mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in search mode through StartTLS: %s", - conn - ) - - conn.bind() - - # find matching dn - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - if self.ldap_filter: - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug("ldap search filter: %s", query) - result = conn.search(self.ldap_base, query) - - if result and len(conn.response) == 1: - # found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('ldap search found dn: %s', user_dn) - - # unbind and reconnect, rebind with found dn - conn.unbind() - conn = ldap3.Connection( - server, - user_dn, - password, - auto_bind=True - ) - else: - # found 0 or > 1 results, abort! 
- logger.warn( - "ldap search returned unexpected (%d!=1) amount of results", - len(conn.response) - ) - defer.returnValue(False) - - logger.info( - "User authenticated against ldap server: %s", - conn - ) - - # check for existing account, if none exists, create one - if not (yield self.check_user_exists(user_id)): - # query user metadata for account creation - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - - if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: - query = "(&{filter}{user_filter})".format( - filter=query, - user_filter=self.ldap_filter - ) - logger.debug("ldap registration filter: %s", query) - - result = conn.search( - search_base=self.ldap_base, - search_filter=query, - attributes=[ - self.ldap_attributes['name'], - self.ldap_attributes['mail'] - ] - ) - - if len(conn.response) == 1: - attrs = conn.response[0]['attributes'] - mail = attrs[self.ldap_attributes['mail']][0] - name = attrs[self.ldap_attributes['name']][0] - - # create account - registration_handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield registration_handler.register(localpart=localpart) - ) - - # TODO: bind email, set displayname with data from ldap directory - - logger.info( - "ldap registration successful: %d: %s (%s, %)", - user_id, - localpart, - name, - mail - ) - else: - logger.warn( - "ldap registration failed: unexpected (%d!=1) amount of results", - len(conn.response) - ) - defer.returnValue(False) - - defer.returnValue(True) - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during ldap authentication: %s", e) - defer.returnValue(False) - - @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token, @@ -806,3 +659,30 @@ class AuthHandler(BaseHandler): stored_hash.encode('utf-8')) == stored_hash else: return False + + +class _AccountHandler(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, check_user_exists): + self.hs = hs + + self._check_user_exists = check_user_exists + + def check_user_exists(self, user_id): + """Check if user exissts. 
+ + Returns: + Deferred(bool) + """ + return self._check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8d630c6b1a..aa68755936 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,7 +58,7 @@ class DeviceHandler(BaseHandler): attempts = 0 while attempts < 5: try: - device_id = stringutils.random_string_with_symbols(16) + device_id = stringutils.random_string(10).upper() yield self.store.store_device( user_id=user_id, device_id=device_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14352985e2..c00274afc3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) - @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have # non-exclusive locks on the alias (or there are no interested services) - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] @@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - defer.returnValue(True) - return + return defer.succeed(True) elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - defer.returnValue(False) - return + return defer.succeed(False) # either no interested services, or no service with an exclusive lock - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _user_can_delete_alias(self, alias, user_id): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5bfd700931..fd11935b40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,14 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json +import ujson as json import logging +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException from synapse.types import get_domain_from_id from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination logger = logging.getLogger(__name__) @@ -29,7 +31,9 @@ class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() + self.device_handler = hs.get_device_handler() self.is_mine_id = hs.is_mine_id + self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -85,27 +89,37 @@ class E2eKeysHandler(object): def do_remote_query(destination): destination_query = remote_queries[destination] try: - remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout + limiter = yield get_retry_limiter( + destination, self.clock, self.store ) + with limiter: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + for user_id, keys in remote_result["device_keys"].items(): if user_id in destination_query: results[user_id] = keys + except CodeMessageException as e: failures[destination] = { "status": e.code, "message": e.message } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) for destination in remote_queries ])) - defer.returnValue((200, { + defer.returnValue({ "device_keys": results, "failures": failures, - })) + }) @defer.inlineCallbacks def query_local_devices(self, query): @@ -159,3 +173,107 @@ class E2eKeysHandler(object): device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) defer.returnValue({"device_keys": res}) + + @defer.inlineCallbacks + def claim_one_time_keys(self, query, timeout): + local_query = [] + remote_queries = {} + + for user_id, device_keys in query.get("one_time_keys", {}).items(): + if self.is_mine_id(user_id): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + + results = yield self.store.claim_e2e_one_time_keys(local_query) + + json_result = {} + failures = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + 
failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "one_time_keys": json_result, + "failures": failures + }) + + @defer.inlineCallbacks + def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() + + # TODO: Validate the JSON to make sure it has the right keys. + device_keys = keys.get("device_keys", None) + if device_keys: + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now + ) + # TODO: Sign the JSON with the server key + yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, + encode_canonical_json(device_keys) + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + logger.info( + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now + ) + key_list = [] + for key_id, key_json in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, encode_canonical_json(key_json) + )) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, key_list + ) + + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) + + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + + defer.returnValue({"one_time_key_counts": result}) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8a1038c44a..2d801bad47 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1585,10 +1585,12 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state + context.current_state_ids = dict(context.current_state_ids) context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() if k != event_key }) + context.prev_state_ids = dict(context.prev_state_ids) context.prev_state_ids.update({ k: a.event_id for k, a in auth_events.items() }) @@ -1670,10 +1672,12 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. + context.current_state_ids = dict(context.current_state_ids) context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() if k != event_key }) + context.prev_state_ids = dict(context.prev_state_ids) context.prev_state_ids.update({ k: a.event_id for k, a in auth_events.items() }) @@ -1918,15 +1922,18 @@ class FederationHandler(BaseHandler): original_invite = yield self.store.get_event( original_invite_id, allow_none=True ) - if not original_invite: + if original_invite: + display_name = original_invite.content["display_name"] + event_dict["content"]["third_party_invite"]["display_name"] = display_name + else: logger.info( - "Could not find invite event for third_party_invite - " - "discarding: %s" % (event_dict,) + "Could not find invite event for third_party_invite: %r", + event_dict ) - return + # We don't discard here as this is not the appropriate place to do + # auth checks. 
If we need the invite and don't have it then the + # auth check code will explode appropriately. - display_name = original_invite.content["display_name"] - event_dict["content"]["third_party_invite"]["display_name"] = display_name builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) message_handler = self.hs.get_handlers().message_handler diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py new file mode 100644 index 0000000000..fbfa5a0281 --- /dev/null +++ b/synapse/handlers/initial_sync.py @@ -0,0 +1,443 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes +from synapse.events.utils import serialize_event +from synapse.events.validator import EventValidator +from synapse.streams.config import PaginationConfig +from synapse.types import ( + UserID, StreamToken, +) +from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute +from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.visibility import filter_events_for_client + +from ._base import BaseHandler + +import logging + + +logger = logging.getLogger(__name__) + + +class InitialSyncHandler(BaseHandler): + def __init__(self, hs): + super(InitialSyncHandler, self).__init__(hs) + self.hs = hs + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.snapshot_cache = SnapshotCache() + + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + """Retrieve a snapshot of all rooms the user is invited or has joined. + + This snapshot may include messages for all rooms where the user is + joined, depending on the pagination config. + + Args: + user_id (str): The ID of the user making the request. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages *PER ROOM* to return. + as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left + Returns: + A list of dicts with "room_id" and "membership" keys for all rooms + the user is currently invited or joined in on. Rooms where the user + is joined on, may return a "messages" key with messages, depending + on the specified PaginationConfig. 
+ """ + key = ( + user_id, + pagin_config.from_token, + pagin_config.to_token, + pagin_config.direction, + pagin_config.limit, + as_client_event, + include_archived, + ) + now_ms = self.clock.time_msec() + result = self.snapshot_cache.get(now_ms, key) + if result is not None: + return result + + return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + )) + + @defer.inlineCallbacks + def _snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, membership_list=memberships + ) + + user = UserID.from_string(user_id) + + rooms_ret = [] + + now_token = yield self.hs.get_event_sources().get_current_token() + + presence_stream = self.hs.get_event_sources().sources["presence"] + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user, pagination_config.get_source_config("presence"), None + ) + + receipt_stream = self.hs.get_event_sources().sources["receipt"] + receipt, _ = yield receipt_stream.get_pagination_rows( + user, pagination_config.get_source_config("receipt"), None + ) + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + + public_room_ids = yield self.store.get_public_room_ids() + + limit = pagin_config.limit + if limit is None: + limit = 10 + + @defer.inlineCallbacks + def handle_room(event): + d = { + "room_id": event.room_id, + "membership": event.membership, + "visibility": ( + "public" if event.room_id in public_room_ids + else "private" + ), + } + + if event.membership == Membership.INVITE: + time_now = self.clock.time_msec() + d["inviter"] = event.sender + + invite_event = yield self.store.get_event(event.event_id) + d["invite"] = serialize_event(invite_event, time_now, as_client_event) + + rooms_ret.append(d) + + if event.membership not in (Membership.JOIN, Membership.LEAVE): + return + + try: + if event.membership == Membership.JOIN: + room_end_token = now_token.room_key + deferred_room_state = self.state_handler.get_current_state( + event.room_id + ) + elif event.membership == Membership.LEAVE: + room_end_token = "s%d" % (event.stream_ordering,) + deferred_room_state = self.store.get_state_for_events( + [event.event_id], None + ) + deferred_room_state.addCallback( + lambda states: states[event.event_id] + ) + + (messages, token), current_state = yield preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() + + d["messages"] = { + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], + "start": start_token.to_string(), + "end": end_token.to_string(), + } + + d["state"] = [ + serialize_event(c, time_now, as_client_event) + for c in current_state.values() + ] + + account_data_events = [] + tags = 
tags_by_room.get(event.room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events + except: + logger.exception("Failed to get snapshot") + + yield concurrently_execute(handle_room, room_list, 10) + + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + ret = { + "rooms": rooms_ret, + "presence": presence, + "account_data": account_data_events, + "receipts": receipt, + "end": now_token.to_string(), + } + + defer.returnValue(ret) + + @defer.inlineCallbacks + def room_initial_sync(self, requester, room_id, pagin_config=None): + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + requester(Requester): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.streams.config.PaginationConfig): + The pagination config used to determine how many messages to + return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON serialisable dict with the snapshot of the room. + """ + + user_id = requester.user.to_string() + + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, + ) + is_peeking = member_event_id is None + + if membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, membership, is_peeking + ) + elif membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ) + + account_data_events = [] + tags = yield self.store.get_tags_for_room(user_id, room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events + + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + membership, member_event_id, is_peeking): + room_state = yield self.store.get_state_for_events( + [member_event_id], None + ) + + room_state = room_state[member_event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking + ) + + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": 
start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + membership, is_peeking): + current_state = yield self.state.get_current_state( + room_id=room_id, + ) + + # TODO: These concurrently + time_now = self.clock.time_msec() + state = [ + serialize_event(x, time_now) + for x in current_state.values() + ] + + now_token = yield self.hs.get_event_sources().get_current_token() + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + room_members = [ + m for m in current_state.values() + if m.type == EventTypes.Member + and m.content["membership"] == Membership.JOIN + ] + + presence_handler = self.hs.get_presence_handler() + + @defer.inlineCallbacks + def get_presence(): + states = yield presence_handler.get_states( + [m.user_id for m in room_members], + as_event=True, + ) + + defer.returnValue(states) + + @defer.inlineCallbacks + def get_receipts(): + receipts_handler = self.hs.get_handlers().receipts_handler + receipts = yield receipts_handler.get_receipts_for_room( + room_id, + now_token.receipt_key + ) + defer.returnValue(receipts) + + presence, receipts, (messages, token) = yield defer.gatherResults( + [ + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( + room_id, + limit=limit, + end_token=now_token.room_key, + ) + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking, + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + ret = { + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": state, + "presence": presence, + "receipts": receipts, + } + if not is_peeking: + ret["membership"] = membership + + defer.returnValue(ret) + + @defer.inlineCallbacks + def _check_in_room_or_world_readable(self, room_id, user_id): + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. 
+ member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + return + except AuthError: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 178209a209..30ea9630f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.push.action_generator import ActionGenerator -from synapse.streams.config import PaginationConfig from synapse.types import ( - UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id + UserID, RoomAlias, RoomStreamToken, get_domain_from_id ) -from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock -from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -49,7 +46,6 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() self.pagination_lock = ReadWriteLock() @@ -392,377 +388,6 @@ class MessageHandler(BaseHandler): [serialize_event(c, now) for c in room_state.values()] ) - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - """Retrieve a snapshot of all rooms the user is invited or has joined. - - This snapshot may include messages for all rooms where the user is - joined, depending on the pagination config. - - Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left - Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. 
- """ - key = ( - user_id, - pagin_config.from_token, - pagin_config.to_token, - pagin_config.direction, - pagin_config.limit, - as_client_event, - include_archived, - ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) - - @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - - memberships = [Membership.INVITE, Membership.JOIN] - if include_archived: - memberships.append(Membership.LEAVE) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, membership_list=memberships - ) - - user = UserID.from_string(user_id) - - rooms_ret = [] - - now_token = yield self.hs.get_event_sources().get_current_token() - - presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None - ) - - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None - ) - - tags_by_room = yield self.store.get_tags_for_user(user_id) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) - ) - - public_room_ids = yield self.store.get_public_room_ids() - - limit = pagin_config.limit - if limit is None: - limit = 10 - - @defer.inlineCallbacks - def handle_room(event): - d = { - "room_id": event.room_id, - "membership": event.membership, - "visibility": ( - "public" if event.room_id in public_room_ids - else "private" - ), - } - - if event.membership == Membership.INVITE: - time_now = self.clock.time_msec() - d["inviter"] = event.sender - - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) - - rooms_ret.append(d) - - if event.membership not in (Membership.JOIN, Membership.LEAVE): - return - - try: - if event.membership == Membership.JOIN: - room_end_token = now_token.room_key - deferred_room_state = self.state_handler.get_current_state( - event.room_id - ) - elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) - deferred_room_state = self.store.get_state_for_events( - [event.event_id], None - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] - ) - - (messages, token), current_state = yield preserve_context_over_deferred( - defer.gatherResults( - [ - preserve_fn(self.store.get_recent_events_for_room)( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] - ) - ).addErrback(unwrapFirstError) - - messages = yield filter_events_for_client( - self.store, user_id, messages - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - time_now = self.clock.time_msec() - - d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], - "start": start_token.to_string(), - "end": end_token.to_string(), - } - - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() - ] - - account_data_events = [] - tags = 
tags_by_room.get(event.room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(event.room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - d["account_data"] = account_data_events - except: - logger.exception("Failed to get snapshot") - - yield concurrently_execute(handle_room, room_list, 10) - - account_data_events = [] - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - ret = { - "rooms": rooms_ret, - "presence": presence, - "account_data": account_data_events, - "receipts": receipt, - "end": now_token.to_string(), - } - - defer.returnValue(ret) - - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): - """Capture the a snapshot of a room. If user is currently a member of - the room this will be what is currently in the room. If the user left - the room this will be what was in the room when they left. - - Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. - Raises: - AuthError if the user wasn't in the room. - Returns: - A JSON serialisable dict with the snapshot of the room. - """ - - user_id = requester.user.to_string() - - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, - ) - is_peeking = member_event_id is None - - if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking - ) - elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking - ) - - account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - result["account_data"] = account_data_events - - defer.returnValue(result) - - @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], None - ) - - room_state = room_state[member_event_id] - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) - - messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token - ) - - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking - ) - - start_token = StreamToken.START.copy_and_replace("room_key", token[0]) - end_token = StreamToken.START.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": 
start_token.to_string(), - "end": end_token.to_string(), - }, - "state": [serialize_event(s, time_now) for s in room_state.values()], - "presence": [], - "receipts": [], - }) - - @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) - - # TODO: These concurrently - time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] - - now_token = yield self.hs.get_event_sources().get_current_token() - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - room_members = [ - m for m in current_state.values() - if m.type == EventTypes.Member - and m.content["membership"] == Membership.JOIN - ] - - presence_handler = self.hs.get_presence_handler() - - @defer.inlineCallbacks - def get_presence(): - states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, - ) - - defer.returnValue(states) - - @defer.inlineCallbacks - def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( - room_id, - now_token.receipt_key - ) - defer.returnValue(receipts) - - presence, receipts, (messages, token) = yield defer.gatherResults( - [ - preserve_fn(get_presence)(), - preserve_fn(get_receipts)(), - preserve_fn(self.store.get_recent_events_for_room)( - room_id, - limit=limit, - end_token=now_token.room_key, - ) - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - ret = { - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": state, - "presence": presence, - "receipts": receipts, - } - if not is_peeking: - ret["membership"] = membership - - defer.returnValue(ret) - @measure_func("_create_new_client_event") @defer.inlineCallbacks def _create_new_client_event(self, builder, prev_event_ids=None): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index a949e39bda..b047ae2250 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -217,7 +217,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ logger.info( - "Performing _on_shutdown. Persiting %d unpersisted changes", + "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) ) @@ -234,7 +234,7 @@ class PresenceHandler(object): may stack up and slow down shutdown times. """ logger.info( - "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", + "Performing _persist_unpersisted_changes. 
Persisting %d unpersisted changes", len(self.unpersisted_users_changes) ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d9ac09078d..87f74dfb8e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname): + def set_displayname(self, target_user, requester, new_displayname, by_admin=False): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index dd75c4fecf..7e119f13b1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,6 @@ import urllib from twisted.internet import defer -import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) @@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler): def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() - service = yield self.store.get_app_service_by_token(as_token) + service = self.store.get_app_service_by_token(as_token) if not service: raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): @@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield identity_handler.bind_threepid(c, user_id) - @defer.inlineCallbacks def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_user(user_id) @@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. 
@@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler - requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, requester, displayname + user, requester, displayname, by_admin=True, ) defer.returnValue((user_id, token)) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d40ada60c1..a7f533f7be 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,12 +20,10 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, JoinRules, RoomCreationPreset, Membership, + EventTypes, JoinRules, RoomCreationPreset ) from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils -from synapse.util.async import concurrently_execute -from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from collections import OrderedDict @@ -36,8 +34,6 @@ import string logger = logging.getLogger(__name__) -REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 - id_server_scheme = "https://" @@ -348,159 +344,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomListHandler(BaseHandler): - def __init__(self, hs): - super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache(hs) - self.remote_list_request_cache = ResponseCache(hs) - self.remote_list_cache = {} - self.fetch_looping_call = hs.get_clock().looping_call( - self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL - ) - self.fetch_all_remote_lists() - - def get_local_public_room_list(self): - result = self.response_cache.get(()) - if not result: - result = self.response_cache.set((), self._get_public_room_list()) - return result - - @defer.inlineCallbacks - def _get_public_room_list(self): - room_ids = yield self.store.get_public_room_ids() - - results = [] - - @defer.inlineCallbacks - def handle_room(room_id): - current_state = yield self.state_handler.get_current_state(room_id) - - # Double check that this is actually a public room. 
- join_rules_event = current_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - join_rule = join_rules_event.content.get("join_rule", None) - if join_rule and join_rule != JoinRules.PUBLIC: - defer.returnValue(None) - - result = {"room_id": room_id} - - num_joined_users = len([ - 1 for _, event in current_state.items() - if event.type == EventTypes.Member - and event.membership == Membership.JOIN - ]) - if num_joined_users == 0: - return - - result["num_joined_members"] = num_joined_users - - aliases = yield self.store.get_aliases_for_room(room_id) - if aliases: - result["aliases"] = aliases - - name_event = yield current_state.get((EventTypes.Name, "")) - if name_event: - name = name_event.content.get("name", None) - if name: - result["name"] = name - - topic_event = current_state.get((EventTypes.Topic, "")) - if topic_event: - topic = topic_event.content.get("topic", None) - if topic: - result["topic"] = topic - - canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) - if canonical_event: - canonical_alias = canonical_event.content.get("alias", None) - if canonical_alias: - result["canonical_alias"] = canonical_alias - - visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) - visibility = None - if visibility_event: - visibility = visibility_event.content.get("history_visibility", None) - result["world_readable"] = visibility == "world_readable" - - guest_event = current_state.get((EventTypes.GuestAccess, "")) - guest = None - if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" - - avatar_event = current_state.get(("m.room.avatar", "")) - if avatar_event: - avatar_url = avatar_event.content.get("url", None) - if avatar_url: - result["avatar_url"] = avatar_url - - results.append(result) - - yield concurrently_execute(handle_room, room_ids, 10) - - # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": results}) - - @defer.inlineCallbacks - def fetch_all_remote_lists(self): - deferred = self.hs.get_replication_layer().get_public_rooms( - self.hs.config.secondary_directory_servers - ) - self.remote_list_request_cache.set((), deferred) - self.remote_list_cache = yield deferred - - @defer.inlineCallbacks - def get_remote_public_room_list(self, server_name): - res = yield self.hs.get_replication_layer().get_public_rooms( - [server_name] - ) - - if server_name not in res: - raise SynapseError(404, "Server not found") - defer.returnValue(res[server_name]) - - @defer.inlineCallbacks - def get_aggregated_public_room_list(self): - """ - Get the public room list from this server and the servers - specified in the secondary_directory_servers config option. - XXX: Pagination... - """ - # We return the results from out cache which is updated by a looping call, - # unless we're missing a cache entry, in which case wait for the result - # of the fetch if there's one in progress. If not, omit that server. - wait = False - for s in self.hs.config.secondary_directory_servers: - if s not in self.remote_list_cache: - logger.warn("No cached room list from %s: waiting for fetch", s) - wait = True - break - - if wait and self.remote_list_request_cache.get(()): - yield self.remote_list_request_cache.get(()) - - public_rooms = yield self.get_local_public_room_list() - - # keep track of which room IDs we've seen so we can de-dup - room_ids = set() - - # tag all the ones in our list with our server name. 
- # Also add them to the de-duping set
- for room in public_rooms['chunk']:
- room["server_name"] = self.hs.hostname
- room_ids.add(room["room_id"])
-
- # Now add the results from federation
- for server_name, server_result in self.remote_list_cache.items():
- for room in server_result["chunk"]:
- if room["room_id"] not in room_ids:
- room["server_name"] = server_name
- public_rooms["chunk"].append(room)
- room_ids.add(room["room_id"])
-
- defer.returnValue(public_rooms)
-
-
 class RoomContextHandler(BaseHandler):
 @defer.inlineCallbacks
 def get_event_context(self, user, room_id, event_id, limit, is_guest):
@@ -594,7 +437,7 @@ class RoomEventSource(object):
 logger.warn("Stream has topological part!!!! %r", from_key)
 from_key = "s%s" % (from_token.stream,)
 
- app_service = yield self.store.get_app_service_by_user_id(
+ app_service = self.store.get_app_service_by_user_id(
 user.to_string()
 )
 if app_service:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
new file mode 100644
index 0000000000..b04aea0110
--- /dev/null
+++ b/synapse/handlers/room_list.py
@@ -0,0 +1,403 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 - 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+from synapse.api.constants import (
+ EventTypes, JoinRules,
+)
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache
+
+from collections import namedtuple
+from unpaddedbase64 import encode_base64, decode_base64
+
+import logging
+import msgpack
+
+logger = logging.getLogger(__name__)
+
+REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
+
+
+class RoomListHandler(BaseHandler):
+ def __init__(self, hs):
+ super(RoomListHandler, self).__init__(hs)
+ self.response_cache = ResponseCache(hs)
+ self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+
+ def get_local_public_room_list(self, limit=None, since_token=None,
+ search_filter=None):
+ if search_filter:
+ # We explicitly don't bother caching searches.
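+ # (Each distinct search term would need its own cache entry, so the
+ # hit rate would be too low to be worth the memory.)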
+ return self._get_public_room_list(limit, since_token, search_filter)
+
+ result = self.response_cache.get((limit, since_token))
+ if not result:
+ result = self.response_cache.set(
+ (limit, since_token),
+ self._get_public_room_list(limit, since_token)
+ )
+ return result
+
+ @defer.inlineCallbacks
+ def _get_public_room_list(self, limit=None, since_token=None,
+ search_filter=None):
+ if since_token and since_token != "END":
+ since_token = RoomListNextBatch.from_token(since_token)
+ else:
+ since_token = None
+
+ rooms_to_order_value = {}
+ rooms_to_num_joined = {}
+ rooms_to_latest_event_ids = {}
+
+ newly_visible = []
+ newly_unpublished = []
+ if since_token:
+ stream_token = since_token.stream_ordering
+ current_public_id = yield self.store.get_current_public_room_stream_id()
+ public_room_stream_id = since_token.public_room_stream_id
+ newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
+ public_room_stream_id, current_public_id
+ )
+ else:
+ stream_token = yield self.store.get_room_max_stream_ordering()
+ public_room_stream_id = yield self.store.get_current_public_room_stream_id()
+
+ room_ids = yield self.store.get_public_room_ids_at_stream_id(
+ public_room_stream_id
+ )
+
+ # We want to return rooms in a particular order: the number of joined
+ # users. We then arbitrarily use the room_id as a tie breaker.
+
+ @defer.inlineCallbacks
+ def get_order_for_room(room_id):
+ latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_forward_extremeties_for_room(
+ room_id, stream_token
+ )
+ rooms_to_latest_event_ids[room_id] = latest_event_ids
+
+ if not latest_event_ids:
+ return
+
+ joined_users = yield self.state_handler.get_current_user_in_room(
+ room_id, latest_event_ids,
+ )
+ num_joined_users = len(joined_users)
+ rooms_to_num_joined[room_id] = num_joined_users
+
+ if num_joined_users == 0:
+ return
+
+ # We want larger rooms to be first, hence negating num_joined_users
+ rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+
+ yield concurrently_execute(get_order_for_room, room_ids, 10)
+
+ sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
+ sorted_rooms = [room_id for room_id, _ in sorted_entries]
+
+ # `sorted_rooms` should now be a list of all public room ids that is
+ # stable across pagination. Therefore, we can use indices into this
+ # list as our pagination tokens.
+
+ # Filter out rooms that we don't want to return
+ rooms_to_scan = [
+ r for r in sorted_rooms
+ if r not in newly_unpublished and rooms_to_num_joined[r] > 0
+ ]
+
+ total_room_count = len(rooms_to_scan)
+
+ if since_token:
+ # Filter out rooms we've already returned previously.
+ # `since_token.current_limit` is the index of the last room we
+ # sent down, so we exclude it and everything before it when
+ # paginating forwards, or it and everything after it when
+ # paginating backwards.
+ if since_token.direction_is_forward:
+ rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
+ else:
+ rooms_to_scan = rooms_to_scan[:since_token.current_limit]
+ rooms_to_scan.reverse()
+
+ # Actually generate the entries. _generate_room_entry will append to
+ # chunk but will stop if len(chunk) > limit
+ chunk = []
+ if limit and not search_filter:
+ step = limit + 1
+ for i in xrange(0, len(rooms_to_scan), step):
+ # We iterate here because in the vast majority of cases we'll
+ # stop at the first iteration, but occasionally
+ # _generate_room_entry won't append to the chunk and so we
+ # need to loop again.
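+ # (e.g. when every room in the batch turns out not to be public
+ # after all, so none of them produce an entry)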
+ # We don't want to scan over the entire range either as that
+ # would potentially waste a lot of work.
+ yield concurrently_execute(
+ lambda r: self._generate_room_entry(
+ r, rooms_to_num_joined[r],
+ chunk, limit, search_filter
+ ),
+ rooms_to_scan[i:i + step], 10
+ )
+ if len(chunk) >= limit + 1:
+ break
+ else:
+ yield concurrently_execute(
+ lambda r: self._generate_room_entry(
+ r, rooms_to_num_joined[r],
+ chunk, limit, search_filter
+ ),
+ rooms_to_scan, 5
+ )
+
+ chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
+
+ # Work out the new limit of the batch for pagination, or None if we
+ # know there are no more results that would be returned.
+ # i.e., [since_token.current_limit..new_limit] is the batch of rooms
+ # we've returned (or the reverse if we paginated backwards)
+ # We tried to pull out limit + 1 rooms above, so if we have <= limit
+ # then we know there are no more results to return
+ new_limit = None
+ if chunk and (not limit or len(chunk) > limit):
+
+ if not since_token or since_token.direction_is_forward:
+ if limit:
+ chunk = chunk[:limit]
+ last_room_id = chunk[-1]["room_id"]
+ else:
+ if limit:
+ chunk = chunk[-limit:]
+ last_room_id = chunk[0]["room_id"]
+
+ new_limit = sorted_rooms.index(last_room_id)
+
+ results = {
+ "chunk": chunk,
+ "total_room_count_estimate": total_room_count,
+ }
+
+ if since_token:
+ results["new_rooms"] = bool(newly_visible)
+
+ if not since_token or since_token.direction_is_forward:
+ if new_limit is not None:
+ results["next_batch"] = RoomListNextBatch(
+ stream_ordering=stream_token,
+ public_room_stream_id=public_room_stream_id,
+ current_limit=new_limit,
+ direction_is_forward=True,
+ ).to_token()
+
+ if since_token:
+ results["prev_batch"] = since_token.copy_and_replace(
+ direction_is_forward=False,
+ current_limit=since_token.current_limit + 1,
+ ).to_token()
+ else:
+ if new_limit is not None:
+ results["prev_batch"] = RoomListNextBatch(
+ stream_ordering=stream_token,
+ public_room_stream_id=public_room_stream_id,
+ current_limit=new_limit,
+ direction_is_forward=False,
+ ).to_token()
+
+ if since_token:
+ results["next_batch"] = since_token.copy_and_replace(
+ direction_is_forward=True,
+ current_limit=since_token.current_limit - 1,
+ ).to_token()
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
+ search_filter):
+ if limit and len(chunk) > limit + 1:
+ # We've already got enough, so let's just drop it.
+ return
+
+ result = {
+ "room_id": room_id,
+ "num_joined_members": num_joined_users,
+ }
+
+ current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+
+ event_map = yield self.store.get_events([
+ event_id for key, event_id in current_state_ids.items()
+ if key[0] in (
+ EventTypes.JoinRules,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.GuestAccess,
+ "m.room.avatar",
+ )
+ ])
+
+ current_state = {
+ (ev.type, ev.state_key): ev
+ for ev in event_map.values()
+ }
+
+ # Double check that this is actually a public room.
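+ # (being listed in the public room directory is independent of the
+ # room's join rule, so a listed room may still not be joinable)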
+ join_rules_event = current_state.get((EventTypes.JoinRules, ""))
+ if join_rules_event:
+ join_rule = join_rules_event.content.get("join_rule", None)
+ if join_rule and join_rule != JoinRules.PUBLIC:
+ defer.returnValue(None)
+
+ aliases = yield self.store.get_aliases_for_room(room_id)
+ if aliases:
+ result["aliases"] = aliases
+
+ name_event = current_state.get((EventTypes.Name, ""))
+ if name_event:
+ name = name_event.content.get("name", None)
+ if name:
+ result["name"] = name
+
+ topic_event = current_state.get((EventTypes.Topic, ""))
+ if topic_event:
+ topic = topic_event.content.get("topic", None)
+ if topic:
+ result["topic"] = topic
+
+ canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
+ if canonical_event:
+ canonical_alias = canonical_event.content.get("alias", None)
+ if canonical_alias:
+ result["canonical_alias"] = canonical_alias
+
+ visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
+ visibility = None
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", None)
+ result["world_readable"] = visibility == "world_readable"
+
+ guest_event = current_state.get((EventTypes.GuestAccess, ""))
+ guest = None
+ if guest_event:
+ guest = guest_event.content.get("guest_access", None)
+ result["guest_can_join"] = guest == "can_join"
+
+ avatar_event = current_state.get(("m.room.avatar", ""))
+ if avatar_event:
+ avatar_url = avatar_event.content.get("url", None)
+ if avatar_url:
+ result["avatar_url"] = avatar_url
+
+ if _matches_room_entry(result, search_filter):
+ chunk.append(result)
+
+ @defer.inlineCallbacks
+ def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
+ search_filter=None):
+ if search_filter:
+ # We currently don't support searching across federation, so we have
+ # to do it manually without pagination
+ limit = None
+ since_token = None
+
+ res = yield self._get_remote_list_cached(
+ server_name, limit=limit, since_token=since_token,
+ )
+
+ if search_filter:
+ res = {"chunk": [
+ entry
+ for entry in list(res.get("chunk", []))
+ if _matches_room_entry(entry, search_filter)
+ ]}
+
+ defer.returnValue(res)
+
+ def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
+ search_filter=None):
+ repl_layer = self.hs.get_replication_layer()
+ if search_filter:
+ # We can't cache when asking for search
+ return repl_layer.get_public_rooms(
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ )
+
+ result = self.remote_response_cache.get((server_name, limit, since_token))
+ if not result:
+ result = self.remote_response_cache.set(
+ (server_name, limit, since_token),
+ repl_layer.get_public_rooms(
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ )
+ )
+ return result
+
+
+class RoomListNextBatch(namedtuple("RoomListNextBatch", (
+ "stream_ordering", # stream_ordering of the first public room list
+ "public_room_stream_id", # public room stream id for first public room list
+ "current_limit", # The index of the last room we returned
+ "direction_is_forward", # True if this is a next_batch, False if prev_batch
+))):
+
+ KEY_DICT = {
+ "stream_ordering": "s",
+ "public_room_stream_id": "p",
+ "current_limit": "n",
+ "direction_is_forward": "d",
+ }
+
+ REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
+
+ @classmethod
+ def from_token(cls, token):
+ return RoomListNextBatch(**{
+ cls.REVERSE_KEY_DICT[key]: val
+ for key, val in msgpack.loads(decode_base64(token)).items()
+ })
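+
+ # Round trip: from_token(batch.to_token()) reconstructs an equal
+ # batch, since to_token() packs the shortened-key dict with msgpack
+ # and unpadded base64, and from_token() reverses both steps.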
+
+ def to_token(self):
+ return encode_base64(msgpack.dumps({
+ self.KEY_DICT[key]: val
+ for key, val in self._asdict().items()
+ }))
+
+ def copy_and_replace(self, **kwds):
+ return self._replace(
+ **kwds
+ )
+
+
+def _matches_room_entry(room_entry, search_filter):
+ if search_filter and search_filter.get("generic_search_term", None):
+ generic_search_term = search_filter["generic_search_term"].upper()
+ if generic_search_term in room_entry.get("name", "").upper():
+ return True
+ elif generic_search_term in room_entry.get("topic", "").upper():
+ return True
+ elif generic_search_term in room_entry.get("canonical_alias", "").upper():
+ return True
+ else:
+ return True
+
+ return False
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b5962f4f5a..1f910ff814 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -788,7 +788,7 @@ class SyncHandler(object):
 
 assert since_token
 
- app_service = yield self.store.get_app_service_by_user_id(user_id)
+ app_service = self.store.get_app_service_by_user_id(user_id)
 if app_service:
 rooms = yield self.store.get_app_service_rooms(app_service)
 joined_room_ids = set(r.room_id for r in rooms)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0548b81c34..08313417b2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,10 +16,9 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import Measure
+from synapse.util.wheel_timer import WheelTimer
 from synapse.types import UserID, get_domain_from_id
 
 import logging
@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
 
 RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
 
+# How often we expect remote servers to resend us typing notifications.
+FEDERATION_TIMEOUT = 60 * 1000
+
+# How often to resend typing across federation.
+FEDERATION_PING_INTERVAL = 40 * 1000
+
+
 class TypingHandler(object):
 def __init__(self, hs):
 self.store = hs.get_datastore()
@@ -44,7 +50,10 @@ class TypingHandler(object):
 self.notifier = hs.get_notifier()
 self.state = hs.get_state_handler()
 
+ self.hs = hs
+
 self.clock = hs.get_clock()
+ self.wheel_timer = WheelTimer(bucket_size=5000)
 
 self.federation = hs.get_replication_layer()
 
@@ -53,7 +62,7 @@ class TypingHandler(object):
 hs.get_distributor().observe("user_left_room", self.user_left_room)
 
 self._member_typing_until = {} # clock time we expect to stop
- self._member_typing_timer = {} # deferreds to manage the above
+ self._member_last_federation_poke = {}
 
 # map room IDs to serial numbers
 self._room_serials = {}
@@ -61,12 +70,41 @@ class TypingHandler(object):
 # map room IDs to sets of users currently typing
 self._room_typing = {}
 
- def tearDown(self):
- """Cancels all the pending timers.
- Normally this shouldn't be needed, but it's required from unit tests - to avoid a "Reactor was unclean" warning.""" - for t in self._member_typing_timer.values(): - self.clock.cancel_call_later(t) + self.clock.looping_call( + self._handle_timeouts, + 5000, + ) + + def _handle_timeouts(self): + logger.info("Checking for typing timeouts") + + now = self.clock.time_msec() + + members = set(self.wheel_timer.fetch(now)) + + for member in members: + if not self.is_typing(member): + # Nothing to do if they're no longer typing + continue + + until = self._member_typing_until.get(member, None) + if not until or until < now: + logger.info("Timing out typing for: %s", member.user_id) + preserve_fn(self._stopped_typing)(member) + continue + + # Check if we need to resend a keep alive over federation for this + # user. + if self.hs.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now: + preserve_fn(self._push_remote)( + member=member, + typing=True + ) + + def is_typing(self, member): + return member.user_id in self._room_typing.get(member.room_id, []) @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): @@ -85,23 +123,17 @@ class TypingHandler(object): "%s has started typing in %s", target_user_id, room_id ) - until = self.clock.time_msec() + timeout member = RoomMember(room_id=room_id, user_id=target_user_id) - was_present = member in self._member_typing_until - - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) + was_present = member.user_id in self._room_typing.get(room_id, set()) - def _cb(): - logger.debug( - "%s has timed out in %s", target_user.to_string(), room_id - ) - self._stopped_typing(member) + now = self.clock.time_msec() + self._member_typing_until[member] = now + timeout - self._member_typing_until[member] = until - self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000.0, _cb + self.wheel_timer.insert( + now=now, + obj=member, + then=now + timeout, ) if was_present: @@ -109,8 +141,7 @@ class TypingHandler(object): defer.returnValue(None) yield self._push_update( - room_id=room_id, - user_id=target_user_id, + member=member, typing=True, ) @@ -133,10 +164,6 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=target_user_id) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] - yield self._stopped_typing(member) @defer.inlineCallbacks @@ -148,57 +175,61 @@ class TypingHandler(object): @defer.inlineCallbacks def _stopped_typing(self, member): - if member not in self._member_typing_until: + if member.user_id not in self._room_typing.get(member.room_id, set()): # No point defer.returnValue(None) + self._member_typing_until.pop(member, None) + self._member_last_federation_poke.pop(member, None) + yield self._push_update( - room_id=member.room_id, - user_id=member.user_id, + member=member, typing=False, ) - del self._member_typing_until[member] - - if member in self._member_typing_timer: - # Don't cancel it - either it already expired, or the real - # stopped_typing() will cancel it - del self._member_typing_timer[member] + @defer.inlineCallbacks + def _push_update(self, member, typing): + if self.hs.is_mine_id(member.user_id): + # Only send updates for changes to our own users. 
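+ # Remote users' typing notifications reach us via _recv_edu, so
+ # pushing them back out over federation would only echo them around.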
+ yield self._push_remote(member, typing)
+
+ self._push_update_local(
+ member=member,
+ typing=typing
+ )
 
 @defer.inlineCallbacks
- def _push_update(self, room_id, user_id, typing):
- users = yield self.state.get_current_user_in_room(room_id)
- domains = set(get_domain_from_id(u) for u in users)
+ def _push_remote(self, member, typing):
+ users = yield self.state.get_current_user_in_room(member.room_id)
+
+ now = self.clock.time_msec()
+ self._member_last_federation_poke[member] = now
+
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
 
- deferreds = []
- for domain in domains:
- if domain == self.server_name:
- preserve_fn(self._push_update_local)(
- room_id=room_id,
- user_id=user_id,
- typing=typing
- )
- else:
- deferreds.append(preserve_fn(self.federation.send_edu)(
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
 destination=domain,
 edu_type="m.typing",
 content={
- "room_id": room_id,
- "user_id": user_id,
+ "room_id": member.room_id,
+ "user_id": member.user_id,
 "typing": typing,
 },
- key=(room_id, user_id),
- ))
-
- yield preserve_context_over_deferred(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
+ key=member,
+ )
 
 @defer.inlineCallbacks
 def _recv_edu(self, origin, content):
 room_id = content["room_id"]
 user_id = content["user_id"]
 
+ member = RoomMember(user_id=user_id, room_id=room_id)
+
 # Check that the string is a valid user id
 user = UserID.from_string(user_id)
 
@@ -213,26 +244,32 @@ class TypingHandler(object):
 domains = set(get_domain_from_id(u) for u in users)
 
 if self.server_name in domains:
+ logger.info("Got typing update from %s: %r", user_id, content)
+ now = self.clock.time_msec()
+ self._member_typing_until[member] = now + FEDERATION_TIMEOUT
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_TIMEOUT,
+ )
 self._push_update_local(
- room_id=room_id,
- user_id=user_id,
+ member=member,
 typing=content["typing"]
 )
 
- def _push_update_local(self, room_id, user_id, typing):
- room_set = self._room_typing.setdefault(room_id, set())
+ def _push_update_local(self, member, typing):
+ room_set = self._room_typing.setdefault(member.room_id, set())
 if typing:
- room_set.add(user_id)
+ room_set.add(member.user_id)
 else:
- room_set.discard(user_id)
+ room_set.discard(member.user_id)
 
 self._latest_room_serial += 1
- self._room_serials[room_id] = self._latest_room_serial
+ self._room_serials[member.room_id] = self._latest_room_serial
 
- with PreserveLoggingContext():
- self.notifier.on_new_event(
- "typing_key", self._latest_room_serial, rooms=[room_id]
- )
+ self.notifier.on_new_event(
+ "typing_key", self._latest_room_serial, rooms=[member.room_id]
+ )
 
 def get_all_typing_updates(self, last_id, current_id):
 # TODO: Work out a way to do this without scanning the entire state.
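
For reference: the pagination tokens introduced in room_list.py are just a
dict with single-character keys, packed with msgpack and wrapped in unpadded
base64 so they can travel as an opaque query parameter. A minimal standalone
sketch of the same round trip (illustrative names, not part of the patch;
it assumes only the msgpack and unpaddedbase64 APIs already imported above):

    from collections import namedtuple

    import msgpack
    from unpaddedbase64 import decode_base64, encode_base64

    Batch = namedtuple("Batch", ("stream_ordering", "current_limit"))
    KEY_DICT = {"stream_ordering": "s", "current_limit": "n"}
    REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

    def to_token(batch):
        # Shorten the field names before packing to keep the token small.
        return encode_base64(msgpack.dumps({
            KEY_DICT[key]: val for key, val in batch._asdict().items()
        }))

    def from_token(token):
        # Reverse both steps and re-expand the field names.
        return Batch(**{
            REVERSE_KEY_DICT[key]: val
            for key, val in msgpack.loads(decode_base64(token)).items()
        })

    batch = Batch(stream_ordering=100, current_limit=20)
    assert from_token(to_token(batch)) == batch

The typing.py rewrite follows a similar batching idea: instead of one
reactor timer per typing user, a single WheelTimer is drained by one looping
call every five seconds, which bounds the number of pending timers no matter
how many users are typing at once.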