diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/__init__.py | 33 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 19 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 15 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 415 | ||||
-rw-r--r-- | synapse/handlers/device.py | 181 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 3 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 139 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 189 | ||||
-rw-r--r-- | synapse/handlers/identity.py | 41 | ||||
-rw-r--r-- | synapse/handlers/message.py | 98 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 172 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 19 | ||||
-rw-r--r-- | synapse/handlers/register.py | 78 | ||||
-rw-r--r-- | synapse/handlers/room.py | 93 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 20 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 1154 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 67 |
17 files changed, 1792 insertions, 944 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 9442ae6f1d..1a50a2ec98 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice.scheduler import AppServiceScheduler -from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomContextHandler, ) from .room_member import RoomMemberHandler from .message import MessageHandler @@ -26,8 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .appservice import ApplicationServicesHandler -from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler from .search import SearchHandler @@ -35,10 +31,21 @@ from .search import SearchHandler class Handlers(object): - """ A collection of all the event handlers. + """ Deprecated. A collection of handlers. - There's no need to lazily create these; we'll just make them all eagerly - at construction time. + At some point most of the classes whose name ended "Handler" were + accessed through this class. + + However this makes it painful to unit test the handlers and to run cut + down versions of synapse that only use specific handlers because using a + single handler required creating all of the handlers. So some of the + handlers have been lifted out of the Handlers object and are now accessed + directly through the homeserver object itself. + + Any new handlers should follow the new pattern of being accessed through + the homeserver object and should not be added to the Handlers object. + + The remaining handlers should be moved out of the handlers object. """ def __init__(self, hs): @@ -50,19 +57,9 @@ class Handlers(object): self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) - self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - asapi = ApplicationServiceApi(hs) - self.appservice_handler = ApplicationServicesHandler( - hs, asapi, AppServiceScheduler( - clock=hs.get_clock(), - store=hs.get_datastore(), - as_api=asapi - ) - ) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c904c6c500..11081a0cd5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, Requester - - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) @@ -31,11 +31,15 @@ class BaseHandler(object): Common base class for the event handlers. Attributes: - store (synapse.storage.events.StateStore): + store (synapse.storage.DataStore): state_handler (synapse.state.StateHandler): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -120,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 75fc74c797..051ccdb380 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.appservice import ApplicationService -from synapse.types import UserID import logging @@ -35,16 +34,13 @@ def log_failure(failure): ) -# NB: Purposefully not inheriting BaseHandler since that contains way too much -# setup code which this handler does not need or use. This makes testing a lot -# easier. class ApplicationServicesHandler(object): - def __init__(self, hs, appservice_api, appservice_scheduler): + def __init__(self, hs): self.store = hs.get_datastore() - self.hs = hs - self.appservice_api = appservice_api - self.scheduler = appservice_scheduler + self.is_mine_id = hs.is_mine_id + self.appservice_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False @defer.inlineCallbacks @@ -169,8 +165,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def _is_unknown_user(self, user_id): - user = UserID.from_string(user_id) - if not self.hs.is_mine(user): + if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. defer.returnValue(False) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d0d78fc6..82998a81ce 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,8 +18,9 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -28,6 +29,12 @@ import bcrypt import pymacaroons import simplejson +try: + import ldap3 +except ImportError: + ldap3 = None + pass + import synapse.util.stringutils as stringutils @@ -38,6 +45,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, @@ -50,19 +61,23 @@ class AuthHandler(BaseHandler): self.INVALID_TOKEN_HTTP_STATUS = 401 self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_tls = hs.config.ldap_tls - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property - - if self.ldap_enabled is True: - import ldap - logger.info("Import ldap version: %s", ldap.__version__) + if self.ldap_enabled: + if not ldap3: + raise RuntimeError( + 'Missing ldap3 library. This is required for LDAP Authentication.' + ) + self.ldap_mode = hs.config.ldap_mode + self.ldap_uri = hs.config.ldap_uri + self.ldap_start_tls = hs.config.ldap_start_tls + self.ldap_base = hs.config.ldap_base + self.ldap_filter = hs.config.ldap_filter + self.ldap_attributes = hs.config.ldap_attributes + if self.ldap_mode == LDAPMode.SEARCH: + self.ldap_bind_dn = hs.config.ldap_bind_dn + self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -220,7 +235,6 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -230,11 +244,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -270,8 +280,17 @@ class AuthHandler(BaseHandler): data = pde.response resp_body = simplejson.loads(data) - if 'success' in resp_body and resp_body['success']: - defer.returnValue(True) + if 'success' in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body['success'] else "Failed", + resp_body.get('hostname') + ) + if resp_body['success']: + defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) @defer.inlineCallbacks @@ -338,67 +357,84 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. + + Creates a new access/refresh token for the user. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. + + The device will be recorded in the table if it is not there already. Args: - user_id (str): User ID + user_id (str): canonical User ID + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: - The user's ID. The access token for the user's session. The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks - def does_user_exist(self, user_id): + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) + res = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(res[0]) except LoginError: - defer.returnValue(False) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): @@ -428,84 +464,232 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - """ + """Authenticate a user against the LDAP and local databases. + + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id Returns: - True if the user_id successfully authenticated + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect """ valid_ldap = yield self._check_ldap_password(user_id, password) if valid_ldap: - defer.returnValue(True) - - valid_local_password = yield self._check_local_password(user_id, password) - if valid_local_password: - defer.returnValue(True) + defer.returnValue(user_id) - defer.returnValue(False) + result = yield self._check_local_password(user_id, password) + defer.returnValue(result) @defer.inlineCallbacks def _check_local_password(self, user_id, password): - try: - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(self.validate_hash(password, password_hash)) - except LoginError: - defer.returnValue(False) + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will throw if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect + """ + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + result = self.validate_hash(password, password_hash) + if not result: + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if not self.ldap_enabled: - logger.debug("LDAP not configured") + """ Attempt to authenticate a user against an LDAP Server + and register an account if none exists. + + Returns: + True if authentication against LDAP was successful + """ + + if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - import ldap + if self.ldap_mode not in LDAPMode.LIST: + raise RuntimeError( + 'Invalid ldap mode specified: {mode}'.format( + mode=self.ldap_mode + ) + ) - logger.info("Authenticating %s with LDAP" % user_id) try: - ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) - logger.debug("Connecting LDAP server at %s" % ldap_url) - l = ldap.initialize(ldap_url) - if self.ldap_tls: - logger.debug("Initiating TLS") - self._connection.start_tls_s() - - local_name = UserID.from_string(user_id).localpart - - dn = "%s=%s, %s" % ( - self.ldap_search_property, - local_name, - self.ldap_search_base) - logger.debug("DN for LDAP authentication: %s" % dn) - - l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) - - if not (yield self.does_user_exist(user_id)): - handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield handler.register(localpart=local_name) + server = ldap3.Server(self.ldap_uri) + logger.debug( + "Attempting ldap connection with %s", + self.ldap_uri + ) + + localpart = UserID.from_string(user_id).localpart + if self.ldap_mode == LDAPMode.SIMPLE: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established ldap connection in simple mode: %s", + conn ) + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in simple mode through StartTLS: %s", + conn + ) + + conn.bind() + + elif self.ldap_mode == LDAPMode.SEARCH: + # connect with preconfigured credentials and search for local user + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established ldap connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in search mode through StartTLS: %s", + conn + ) + + conn.bind() + + # find matching dn + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug("ldap search filter: %s", query) + result = conn.search(self.ldap_base, query) + + if result and len(conn.response) == 1: + # found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('ldap search found dn: %s', user_dn) + + # unbind and reconnect, rebind with found dn + conn.unbind() + conn = ldap3.Connection( + server, + user_dn, + password, + auto_bind=True + ) + else: + # found 0 or > 1 results, abort! + logger.warn( + "ldap search returned unexpected (%d!=1) amount of results", + len(conn.response) + ) + defer.returnValue(False) + + logger.info( + "User authenticated against ldap server: %s", + conn + ) + + # check for existing account, if none exists, create one + if not (yield self.check_user_exists(user_id)): + # query user metadata for account creation + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + + if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: + query = "(&{filter}{user_filter})".format( + filter=query, + user_filter=self.ldap_filter + ) + logger.debug("ldap registration filter: %s", query) + + result = conn.search( + search_base=self.ldap_base, + search_filter=query, + attributes=[ + self.ldap_attributes['name'], + self.ldap_attributes['mail'] + ] + ) + + if len(conn.response) == 1: + attrs = conn.response[0]['attributes'] + mail = attrs[self.ldap_attributes['mail']][0] + name = attrs[self.ldap_attributes['name']][0] + + # create account + registration_handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield registration_handler.register(localpart=localpart) + ) + + # TODO: bind email, set displayname with data from ldap directory + + logger.info( + "ldap registration successful: %d: %s (%s, %)", + user_id, + localpart, + name, + mail + ) + else: + logger.warn( + "ldap registration failed: unexpected (%d!=1) amount of results", + len(result) + ) + defer.returnValue(False) + defer.returnValue(True) - except ldap.LDAPError, e: - logger.warn("LDAP error: %s", e) + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) @@ -529,14 +713,20 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: macaroon = pymacaroons.Macaroon.deserialize(login_token) - auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( @@ -547,23 +737,18 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) except_access_token_ids = [requester.access_token_id] if requester else [] - yield self.store.user_set_password_hash(user_id, password_hash) + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e yield self.store.user_delete_access_tokens( user_id, except_access_token_ids ) @@ -603,7 +788,8 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password + self.hs.config.password_pepper, + bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -616,6 +802,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password, stored_hash) == stored_hash + return bcrypt.hashpw(password + self.hs.config.password_pepper, + stored_hash.encode('utf-8')) == stored_hash else: return False diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d630c6b1a --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api import errors +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name=None): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except errors.StoreError: + attempts += 1 + + raise errors.StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: list[dict[str, X]]: info on each device + """ + + device_map = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in device_map.keys()) + ) + + devices = device_map.values() + for device in devices: + _update_device_from_client_ips(device, ips) + + defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) + + yield self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=device_id + ) + + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8eeb225811..4bea7f2b19 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler): super(DirectoryHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.appservice_handler = hs.get_application_service_handler() self.federation = hs.get_replication_layer() self.federation.register_query_handler( @@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler): ) if not result: # Query AS to see if it exists - as_handler = self.hs.get_handlers().appservice_handler + as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py new file mode 100644 index 0000000000..2c7bfd91ed --- /dev/null +++ b/synapse/handlers/e2e_keys.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import logging + +from twisted.internet import defer + +from synapse.api import errors +import synapse.types + +logger = logging.getLogger(__name__) + + +class E2eKeysHandler(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation = hs.get_replication_layer() + self.is_mine_id = hs.is_mine_id + self.server_name = hs.hostname + + # doesn't really work as part of the generic query API, because the + # query request requires an object POST, but we abuse the + # "query handler" interface. + self.federation.register_query_handler( + "client_keys", self.on_federation_query_client_keys + ) + + @defer.inlineCallbacks + def query_devices(self, query_body): + """ Handle a device key query from a client + + { + "device_keys": { + "<user_id>": ["<device_id>"] + } + } + -> + { + "device_keys": { + "<user_id>": { + "<device_id>": { + ... + } + } + } + } + """ + device_keys_query = query_body.get("device_keys", {}) + + # separate users by domain. + # make a map from domain to user_id to device_ids + queries_by_domain = collections.defaultdict(dict) + for user_id, device_ids in device_keys_query.items(): + user = synapse.types.UserID.from_string(user_id) + queries_by_domain[user.domain][user_id] = device_ids + + # do the queries + # TODO: do these in parallel + results = {} + for destination, destination_query in queries_by_domain.items(): + if destination == self.server_name: + res = yield self.query_local_devices(destination_query) + else: + res = yield self.federation.query_client_keys( + destination, {"device_keys": destination_query} + ) + res = res["device_keys"] + for user_id, keys in res.items(): + if user_id in destination_query: + results[user_id] = keys + + defer.returnValue((200, {"device_keys": results})) + + @defer.inlineCallbacks + def query_local_devices(self, query): + """Get E2E device keys for local users + + Args: + query (dict[string, list[string]|None): map from user_id to a list + of devices to query (None for all devices) + + Returns: + defer.Deferred: (resolves to dict[string, dict[string, dict]]): + map from user_id -> device_id -> device details + """ + local_query = [] + + result_dict = {} + for user_id, device_ids in query.items(): + if not self.is_mine_id(user_id): + logger.warning("Request for keys for non-local user %s", + user_id) + raise errors.SynapseError(400, "Not a user here") + + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + + results = yield self.store.get_e2e_device_keys(local_query) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + for user_id, device_keys in results.items(): + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + result_dict[user_id][device_id] = r + + defer.returnValue(result_dict) + + @defer.inlineCallbacks + def on_federation_query_client_keys(self, query_body): + """ Handle a device key query from a federated server + """ + device_keys_query = query_body.get("device_keys", {}) + res = yield self.query_local_devices(device_keys_query) + defer.returnValue({"device_keys": res}) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 648a505e65..ff6bb475b5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -66,10 +66,6 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe("user_joined_room", self.user_joined_room) - - self.waiting_for_join_list = {} - self.store = hs.get_datastore() self.replication_layer = hs.get_replication_layer() self.state_handler = hs.get_state_handler() @@ -128,7 +124,7 @@ class FederationHandler(BaseHandler): try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -253,7 +249,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -339,29 +335,58 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: + logger.info("Missing auth for backfill: %r", missing_auth) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + + results = yield defer.gatherResults( + [ + self.replication_layer.get_pdu( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results}) + required_auth.update( + a_id for event in results for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): @@ -374,6 +399,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -385,6 +411,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) @@ -403,7 +430,7 @@ class FederationHandler(BaseHandler): # previous to work out the state. # TODO: We can probably do something more clever here. yield self._handle_new_event( - dest, event + dest, event, backfilled=True, ) defer.returnValue(events) @@ -639,7 +666,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -690,7 +717,9 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create join %r because %s", event, e) raise e - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) defer.returnValue(event) @@ -920,7 +949,9 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1020,13 +1046,16 @@ class FederationHandler(BaseHandler): res = results.values() for event in res: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) defer.returnValue(res) else: @@ -1064,16 +1093,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( @@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( + origin, event.room_id, [event] + ) + + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) @@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @log_function - def user_joined_room(self, user, room_id): - waiters = self.waiting_for_join_list.get( - (user.to_string(), room_id), - [] - ) - while waiters: - waiters.pop().callback(None) - @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, @@ -1122,11 +1149,12 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id - ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` @@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler): event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event @@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler): local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 656ce124f9..559e5d5a71 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,7 +21,7 @@ from synapse.api.errors import ( ) from ._base import BaseHandler from synapse.util.async import run_on_reactor -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import json import logging @@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler): hs.config.use_insecure_ssl_client_just_for_testing_do_not_use ) + def _should_trust_id_server(self, id_server): + if id_server not in self.trusted_id_servers: + if self.trust_any_id_server_just_for_testing_do_not_use: + logger.warn( + "Trusting untrustworthy ID server %r even though it isn't" + " in the trusted id list for testing because" + " 'use_insecure_ssl_client_just_for_testing_do_not_use'" + " is set in the config", + id_server, + ) + else: + return False + return True + @defer.inlineCallbacks def threepid_from_creds(self, creds): yield run_on_reactor() @@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler): else: raise SynapseError(400, "No client_secret in creds") - if id_server not in self.trusted_id_servers: - if self.trust_any_id_server_just_for_testing_do_not_use: - logger.warn( - "Trusting untrustworthy ID server %r even though it isn't" - " in the trusted id list for testing because" - " 'use_insecure_ssl_client_just_for_testing_do_not_use'" - " is set in the config", - id_server, - ) - else: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server) - defer.returnValue(None) + if not self._should_trust_id_server(id_server): + logger.warn( + '%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', id_server + ) + defer.returnValue(None) data = {} try: @@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): yield run_on_reactor() + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, + Codes.SERVER_NOT_TRUSTED + ) + params = { 'email': email, 'client_secret': client_secret, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c41dafdef5..dc76d34a52 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -26,9 +26,9 @@ from synapse.types import ( UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id ) from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import preserve_fn from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -50,9 +50,23 @@ class MessageHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + self.pagination_lock = ReadWriteLock() + + @defer.inlineCallbacks + def purge_history(self, room_id, event_id): + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + depth = event.depth + + with (yield self.pagination_lock.write(room_id)): + yield self.store.delete_old_state(room_id, depth) + @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -61,11 +75,11 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key @@ -85,42 +99,48 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) + with (yield self.pagination_lock.read(room_id)): + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) - if source_config.direction == 'b': - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - max_topo = room_token.topological - else: - max_topo = yield self.store.get_max_topological_token_for_stream_and_room( - room_id, room_token.stream - ) + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token( + room_id, room_token.stream + ) + + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < max_topo: + source_config.from_key = str(leave_token) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, max_topo ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: - source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, max_topo + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id - ) - - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key + ) if not events: defer.returnValue({ @@ -129,6 +149,9 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) + if event_filter: + events = event_filter.filter(events) + events = yield filter_events_for_client( self.store, user_id, @@ -908,13 +931,16 @@ class MessageHandler(BaseHandler): "Failed to get destination from event %s", s.event_id ) - with PreserveLoggingContext(): - # Don't block waiting on waking up all the listeners. + @defer.inlineCallbacks + def _notify(): + yield run_on_reactor() self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) + preserve_fn(_notify)() + # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37f57301fb..6b70fa3817 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,6 +50,8 @@ timers_fired_counter = metrics.register_counter("timers_fired") federation_presence_counter = metrics.register_counter("federation_presence") bump_active_time_counter = metrics.register_counter("bump_active_time") +get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -68,6 +70,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000 # How often to resend presence to remote servers FEDERATION_PING_INTERVAL = 25 * 60 * 1000 +# How long we will wait before assuming that the syncs from an external process +# are dead. +EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 + assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER @@ -158,15 +164,26 @@ class PresenceHandler(object): self.serial_to_user = {} self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs. While this is non zero - # a user will never go offline. + # Keeps track of the number of *ongoing* syncs on this process. While + # this is non zero a user will never go offline. self.user_to_num_current_syncs = {} + # Keeps track of the number of *ongoing* syncs on other processes. + # While any sync is ongoing on another process the user will never + # go offline. + # Each process has a unique identifier and an update frequency. If + # no update is received from that process within the update period then + # we assume that all the sync requests on that process have stopped. + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs = {} + self.external_process_last_updated_ms = {} + # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 30 * 1000, + 30, self.clock.looping_call, self._handle_timeouts, 5000, @@ -266,31 +283,48 @@ class PresenceHandler(object): """Checks the presence of users that have timed out and updates as appropriate. """ + logger.info("Handling presence timeouts") now = self.clock.time_msec() - with Measure(self.clock, "presence_handle_timeouts"): - # Fetch the list of users that *may* have timed out. Things may have - # changed since the timeout was set, so we won't necessarily have to - # take any action. - users_to_check = self.wheel_timer.fetch(now) + try: + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_updated_ms.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_last_updated_ms.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) - states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) - for user_id in set(users_to_check) - ] + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in users_to_check + ] - timers_fired_counter.inc_by(len(states)) + timers_fired_counter.inc_by(len(states)) - changes = handle_timeouts( - states, - is_mine_fn=self.is_mine_id, - user_to_num_current_syncs=self.user_to_num_current_syncs, - now=now, - ) + changes = handle_timeouts( + states, + is_mine_fn=self.is_mine_id, + syncing_user_ids=self.get_currently_syncing_users(), + now=now, + ) - preserve_fn(self._update_states)(changes) + preserve_fn(self._update_states)(changes) + except: + logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks def bump_presence_active_time(self, user): @@ -363,6 +397,74 @@ class PresenceHandler(object): defer.returnValue(_user_syncing()) + def get_currently_syncing_users(self): + """Get the set of user ids that are currently syncing on this HS. + Returns: + set(str): A set of user_id strings. + """ + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + + @defer.inlineCallbacks + def update_external_syncs(self, process_id, syncing_user_ids): + """Update the syncing users for an external process + + Args: + process_id(str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + syncing_user_ids(set(str)): The set of user_ids that are + currently syncing on that server. + """ + + # Grab the previous list of user_ids that were syncing on that process + prev_syncing_user_ids = ( + self.external_process_to_current_syncs.get(process_id, set()) + ) + # Grab the current presence state for both the users that are syncing + # now and the users that were syncing before this update. + prev_states = yield self.current_state_for_users( + syncing_user_ids | prev_syncing_user_ids + ) + updates = [] + time_now_ms = self.clock.time_msec() + + # For each new user that is syncing check if we need to mark them as + # being online. + for new_user_id in syncing_user_ids - prev_syncing_user_ids: + prev_state = prev_states[new_user_id] + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + # For each user that is still syncing or stopped syncing update the + # last sync time so that we will correctly apply the grace period when + # they stop syncing. + for old_user_id in prev_syncing_user_ids: + prev_state = prev_states[old_user_id] + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + yield self._update_states(updates) + + # Update the last updated time for the process. We expire the entries + # if we don't receive an update in the given timeframe. + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -879,13 +981,13 @@ class PresenceEventSource(object): user_ids_changed = set() changed = None - if from_key and max_token - from_key < 100: - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list + if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) - # get_all_entities_changed can return None - if changed is not None: + if changed is not None and len(changed) < 500: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + get_updates_counter.inc("stream") for other_user_id in changed: if other_user_id in friends: user_ids_changed.add(other_user_id) @@ -897,6 +999,8 @@ class PresenceEventSource(object): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. + get_updates_counter.inc("full") + user_ids_to_check = set() for room_id in room_ids: users = yield self.store.get_users_in_room(room_id) @@ -935,15 +1039,14 @@ class PresenceEventSource(object): return self.get_new_events(user, from_key=None, include_offline=False) -def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): +def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as appropriate. Args: user_states(list): List of UserPresenceState's to check. is_mine_fn (fn): Function that returns if a user_id is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -954,21 +1057,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): for state in user_states: is_mine = is_mine_fn(state.user_id) - new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + new_state = handle_timeout(state, is_mine, syncing_user_ids, now) if new_state: changes[state.user_id] = new_state return changes.values() -def handle_timeout(state, is_mine, user_to_num_current_syncs, now): +def handle_timeout(state, is_mine, syncing_user_ids, now): """Checks the presence of the user to see if any of the timers have elapsed Args: state (UserPresenceState) is_mine (bool): Whether the user is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -1002,7 +1104,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline - if not user_to_num_current_syncs.get(user_id, 0): + if user_id not in syncing_user_ids: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e37409170d..d9ac09078d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) @@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) - distributor = hs.get_distributor() - - distributor.observe("registered_user", self.registered_user) - - def registered_user(self, user): - return self.store.create_profile(user.localpart) - @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -172,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5883b9111e..dd75c4fecf 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,19 +14,19 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID +import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient -from synapse.util.distributor import registered_user - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -37,8 +37,6 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() - self.distributor = hs.get_distributor() - self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -55,6 +53,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME ) + if localpart[0] == '_': + raise SynapseError( + 400, + "User ID may not begin with _", + Codes.INVALID_USERNAME + ) + user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -93,7 +98,8 @@ class RegistrationHandler(BaseHandler): password=None, generate_token=True, guest_access_token=None, - make_guest=False + make_guest=False, + admin=False, ): """Registers a new client on the server. @@ -101,8 +107,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -140,9 +151,12 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, + create_profile_with_localpart=( + # If the user was a guest then they already have a profile + None if was_guest else user.localpart + ), + admin=admin, ) - - yield registered_user(self.distributor, user) else: # autogen a sequential user ID attempts = 0 @@ -160,7 +174,8 @@ class RegistrationHandler(BaseHandler): user_id=user_id, token=token, password_hash=password_hash, - make_guest=make_guest + make_guest=make_guest, + create_profile_with_localpart=user.localpart, ) except SynapseError: # if user id is taken, just generate another @@ -168,7 +183,6 @@ class RegistrationHandler(BaseHandler): user_id = None token = None attempts += 1 - yield registered_user(self.distributor, user) # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding @@ -195,15 +209,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): @@ -248,9 +260,9 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=None, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) except Exception as e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors @@ -359,8 +371,10 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds): - """Creates a new user or returns an access token for an existing one + def get_or_create_user(self, localpart, displayname, duration_in_ms, + password_hash=None): + """Creates a new user if the user does not exist, + else revokes all previous access tokens and generates a new one. Args: localpart : The local part of the user ID to register. If None, @@ -387,32 +401,32 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - auth_handler = self.hs.get_handlers().auth_handler - token = auth_handler.generate_short_term_login_token(user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=password_hash, + create_profile_with_localpart=user.localpart, ) - - yield registered_user(self.distributor, user) else: - yield self.store.flush_user(user_id=user_id) + yield self.store.user_delete_access_tokens(user_id=user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler + requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, user, displayname + user, requester, displayname ) defer.returnValue((user_id, token)) def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d63b3c513..bf6b1c1535 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,7 +20,7 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset, Membership, ) from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils @@ -36,6 +36,8 @@ import string logger = logging.getLogger(__name__) +REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 + id_server_scheme = "https://" @@ -343,9 +345,15 @@ class RoomCreationHandler(BaseHandler): class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache() + self.response_cache = ResponseCache(hs) + self.remote_list_request_cache = ResponseCache(hs) + self.remote_list_cache = {} + self.fetch_looping_call = hs.get_clock().looping_call( + self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL + ) + self.fetch_all_remote_lists() - def get_public_room_list(self): + def get_local_public_room_list(self): result = self.response_cache.get(()) if not result: result = self.response_cache.set((), self._get_public_room_list()) @@ -359,14 +367,10 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def handle_room(room_id): - # We pull each bit of state out indvidually to avoid pulling the - # full state into memory. Due to how the caching works this should - # be fairly quick, even if not originally in the cache. - def get_state(etype, state_key): - return self.state_handler.get_current_state(room_id, etype, state_key) + current_state = yield self.state_handler.get_current_state(room_id) # Double check that this is actually a public room. - join_rules_event = yield get_state(EventTypes.JoinRules, "") + join_rules_event = current_state.get((EventTypes.JoinRules, "")) if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) if join_rule and join_rule != JoinRules.PUBLIC: @@ -374,47 +378,51 @@ class RoomListHandler(BaseHandler): result = {"room_id": room_id} - joined_users = yield self.store.get_users_in_room(room_id) - if len(joined_users) == 0: + num_joined_users = len([ + 1 for _, event in current_state.items() + if event.type == EventTypes.Member + and event.membership == Membership.JOIN + ]) + if num_joined_users == 0: return - result["num_joined_members"] = len(joined_users) + result["num_joined_members"] = num_joined_users aliases = yield self.store.get_aliases_for_room(room_id) if aliases: result["aliases"] = aliases - name_event = yield get_state(EventTypes.Name, "") + name_event = yield current_state.get((EventTypes.Name, "")) if name_event: name = name_event.content.get("name", None) if name: result["name"] = name - topic_event = yield get_state(EventTypes.Topic, "") + topic_event = current_state.get((EventTypes.Topic, "")) if topic_event: topic = topic_event.content.get("topic", None) if topic: result["topic"] = topic - canonical_event = yield get_state(EventTypes.CanonicalAlias, "") + canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) if canonical_event: canonical_alias = canonical_event.content.get("alias", None) if canonical_alias: result["canonical_alias"] = canonical_alias - visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "") + visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) visibility = None if visibility_event: visibility = visibility_event.content.get("history_visibility", None) result["world_readable"] = visibility == "world_readable" - guest_event = yield get_state(EventTypes.GuestAccess, "") + guest_event = current_state.get((EventTypes.GuestAccess, "")) guest = None if guest_event: guest = guest_event.content.get("guest_access", None) result["guest_can_join"] = guest == "can_join" - avatar_event = yield get_state("m.room.avatar", "") + avatar_event = current_state.get(("m.room.avatar", "")) if avatar_event: avatar_url = avatar_event.content.get("url", None) if avatar_url: @@ -427,6 +435,55 @@ class RoomListHandler(BaseHandler): # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": results}) + @defer.inlineCallbacks + def fetch_all_remote_lists(self): + deferred = self.hs.get_replication_layer().get_public_rooms( + self.hs.config.secondary_directory_servers + ) + self.remote_list_request_cache.set((), deferred) + self.remote_list_cache = yield deferred + + @defer.inlineCallbacks + def get_aggregated_public_room_list(self): + """ + Get the public room list from this server and the servers + specified in the secondary_directory_servers config option. + XXX: Pagination... + """ + # We return the results from out cache which is updated by a looping call, + # unless we're missing a cache entry, in which case wait for the result + # of the fetch if there's one in progress. If not, omit that server. + wait = False + for s in self.hs.config.secondary_directory_servers: + if s not in self.remote_list_cache: + logger.warn("No cached room list from %s: waiting for fetch", s) + wait = True + break + + if wait and self.remote_list_request_cache.get(()): + yield self.remote_list_request_cache.get(()) + + public_rooms = yield self.get_local_public_room_list() + + # keep track of which room IDs we've seen so we can de-dup + room_ids = set() + + # tag all the ones in our list with our server name. + # Also add the them to the de-deping set + for room in public_rooms['chunk']: + room["server_name"] = self.hs.hostname + room_ids.add(room["room_id"]) + + # Now add the results from federation + for server_name, server_result in self.remote_list_cache.items(): + for room in server_result["chunk"]: + if room["room_id"] not in room_ids: + room["server_name"] = server_name + public_rooms["chunk"].append(room) + room_ids.add(room["room_id"]) + + defer.returnValue(public_rooms) + class RoomContextHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7e616f44fd..8cec8fc4ed 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -14,24 +14,22 @@ # limitations under the License. -from twisted.internet import defer +import logging -from ._base import BaseHandler +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 -from synapse.types import UserID, RoomID, Requester +import synapse.types from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - -from unpaddedbase64 import decode_base64 - -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler): ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: - requester = Requester(target_user, None, False) + requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9ebfccc8bf..0ee4ebe504 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute from synapse.util.logcontext import LoggingContext @@ -139,7 +138,7 @@ class SyncHandler(object): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache() + self.response_cache = ResponseCache(hs) def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -194,195 +193,15 @@ class SyncHandler(object): Returns: A Deferred SyncResult. """ - if since_token is None or full_state: - return self.full_state_sync(sync_config, since_token) - else: - return self.incremental_sync_with_gap(sync_config, since_token) - - @defer.inlineCallbacks - def full_state_sync(self, sync_config, timeline_since_token): - """Get a sync for a client which is starting without any state. - - If a 'message_since_token' is given, only timeline events which have - happened since that token will be returned. - - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token - ) - - presence_stream = self.event_sources.sources["presence"] - # TODO (mjark): This looks wrong, shouldn't we be getting the presence - # UP to the present rather than after the present? - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user=sync_config.user, - pagination_config=pagination_config.get_source_config("presence"), - key=None - ) - - membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN - ) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=sync_config.user.to_string(), - membership_list=membership_list - ) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) - ) - - account_data['m.push_rules'] = yield self.push_rules_for_user( - sync_config.user - ) - - tags_by_room = yield self.store.get_tags_for_user( - sync_config.user.to_string() - ) - - ignored_users = account_data.get( - "m.ignored_user_list", {} - ).get("ignored_users", {}).keys() - - joined = [] - invited = [] - archived = [] - - user_id = sync_config.user.to_string() - - @defer.inlineCallbacks - def _generate_room_entry(event): - if event.membership == Membership.JOIN: - room_result = yield self.full_state_sync_for_joined_room( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - joined.append(room_result) - elif event.membership == Membership.INVITE: - if event.sender in ignored_users: - return - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if user_id == event.sender: - return - - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_result = yield self.full_state_sync_for_archived_room( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - archived.append(room_result) - - yield concurrently_execute(_generate_room_entry, room_list, 10) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def full_state_sync_for_joined_room(self, room_id, sync_config, - now_token, timeline_since_token, - ephemeral_by_room, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred JoinedSyncResult. - """ - - batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, since_token=timeline_since_token - ) - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id, sync_config, - now_token=now_token, - since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=True, - ) - - defer.returnValue(room_sync) + return self.generate_sync_result(sync_config, since_token, full_state) @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() - rawrules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - rules = format_push_rules_for_user(user, rawrules, enabled_map) + rules = yield self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules) defer.returnValue(rules) - def account_data_for_user(self, account_data): - account_data_events = [] - - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - - def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): - account_data_events = [] - tags = tags_by_room.get(room_id) - if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): """Get the ephemeral events for each room the user is in @@ -445,258 +264,22 @@ class SyncHandler(object): defer.returnValue((now_token, ephemeral_by_room)) - def full_state_sync_for_archived_room(self, room_id, sync_config, - leave_event_id, leave_token, - timeline_since_token, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred ArchivedSyncResult. - """ - - return self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room, - account_data_by_room, full_state=True, leave_token=leave_token, - ) - - @defer.inlineCallbacks - def incremental_sync_with_gap(self, sync_config, since_token): - """ Get the incremental delta needed to bring the client up to - date with the server. - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) - room_ids = [room.room_id for room in rooms] - - presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events( - user=sync_config.user, - from_key=since_token.presence_key, - limit=sync_config.filter_collection.presence_limit(), - room_ids=room_ids, - is_guest=sync_config.is_guest, - ) - now_token = now_token.copy_and_replace("presence_key", presence_key) - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token, since_token - ) - - app_service = yield self.store.get_app_service_by_user_id( - sync_config.user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - rooms = yield self.store.get_rooms_for_user( - sync_config.user.to_string() - ) - joined_room_ids = set(r.room_id for r in rooms) - - user_id = sync_config.user.to_string() - - timeline_limit = sync_config.filter_collection.timeline_limit() - - tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, - ) - - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, - ) - ) - - push_rules_changed = yield self.store.have_push_rules_changed_for_user( - user_id, int(since_token.push_rules_key) - ) - - if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( - sync_config.user - ) - - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, - ) - - if ignored_account_data: - ignored_users = ignored_account_data.get("ignored_users", {}).keys() - else: - ignored_users = frozenset() - - # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key - ) - - mem_change_events_by_room_id = {} - for event in rooms_changed: - mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - - newly_joined_rooms = [] - archived = [] - invited = [] - for room_id, events in mem_change_events_by_room_id.items(): - non_joins = [e for e in events if e.membership != Membership.JOIN] - has_join = len(non_joins) != len(events) - - # We want to figure out if we joined the room at some point since - # the last sync (even if we have since left). This is to make sure - # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: - newly_joined_rooms.append(room_id) - - if room_id in joined_room_ids: - continue - - if not non_joins: - continue - - # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE - if should_invite: - if event.sender not in ignored_users: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) - - # Always include leave/ban events. Just take the last one. - # TODO: How do we handle ban -> leave in same batch? - leave_events = [ - e for e in non_joins - if e.membership in (Membership.LEAVE, Membership.BAN) - ] - - if leave_events: - leave_event = leave_events[-1] - room_sync = yield self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event.event_id, since_token, - tags_by_room, account_data_by_room, - full_state=room_id in newly_joined_rooms - ) - if room_sync: - archived.append(room_sync) - - # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, - from_key=since_token.room_key, - to_key=now_token.room_key, - limit=timeline_limit + 1, - ) - - joined = [] - # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. - for room_id in joined_room_ids: - room_entry = room_to_events.get(room_id, None) - - if room_entry: - events, start_key = room_entry - - prev_batch_token = now_token.copy_and_replace("room_key", start_key) - - newly_joined_room = room_id in newly_joined_rooms - full_state = newly_joined_room - - batch = yield self.load_filtered_recents( - room_id, sync_config, prev_batch_token, - since_token=since_token, - recents=events, - newly_joined_room=newly_joined_room, - ) - else: - batch = TimelineBatch( - events=[], - prev_batch=since_token, - limited=False, - ) - full_state = False - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id=room_id, - sync_config=sync_config, - since_token=since_token, - now_token=now_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=full_state, - ) - if room_sync: - joined.append(room_sync) - - # For each newly joined room, we want to send down presence of - # existing users. - presence_handler = self.presence_handler - extra_presence_users = set() - for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(event.room_id) - extra_presence_users.update(users) - - # For each new member, send down presence. - for joined_sync in joined: - it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) - for event in it: - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - extra_presence_users.add(event.state_key) - - states = yield presence_handler.get_states( - [u for u in extra_presence_users if u != user_id], - as_event=True, - ) - presence.extend(states) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - @defer.inlineCallbacks - def load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None, recents=None, newly_joined_room=False): """ Returns: a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): - filtering_factor = 2 timeline_limit = sync_config.filter_collection.timeline_limit() - load_limit = max(timeline_limit * filtering_factor, 10) - max_repeat = 5 # Only try a few times per room, otherwise - room_key = now_token.room_key - end_key = room_key if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True else: limited = False - if recents is not None: + if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) recents = yield filter_events_for_client( self.store, @@ -706,6 +289,19 @@ class SyncHandler(object): else: recents = [] + if not limited: + defer.returnValue(TimelineBatch( + events=recents, + prev_batch=now_token, + limited=False + )) + + filtering_factor = 2 + load_limit = max(timeline_limit * filtering_factor, 10) + max_repeat = 5 # Only try a few times per room, otherwise + room_key = now_token.room_key + end_key = room_key + since_key = None if since_token and not newly_joined_room: since_key = since_token.room_key @@ -749,103 +345,6 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def incremental_sync_with_gap_for_room(self, room_id, sync_config, - since_token, now_token, - ephemeral_by_room, tags_by_room, - account_data_by_room, - batch, full_state=False): - state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - ephemeral = sync_config.filter_collection.filter_room_ephemeral( - ephemeral_by_room.get(room_id, []) - ) - - unread_notifications = {} - room_sync = JoinedSyncResult( - room_id=room_id, - timeline=batch, - state=state, - ephemeral=ephemeral, - account_data=account_data, - unread_notifications=unread_notifications, - ) - - if room_sync: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) - - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks - def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id, - since_token, tags_by_room, - account_data_by_room, full_state, - leave_token=None): - """ Get the incremental delta needed to bring the client up to date for - the archived room. - Returns: - A Deferred ArchivedSyncResult - """ - - if not leave_token: - stream_token = yield self.store.get_stream_token_for_event( - leave_event_id - ) - - leave_token = since_token.copy_and_replace("room_key", stream_token) - - if since_token and since_token.is_after(leave_token): - defer.returnValue(None) - - batch = yield self.load_filtered_recents( - room_id, sync_config, leave_token, since_token, - ) - - logger.debug("Recents %r", batch) - - state_events_delta = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, leave_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - room_sync = ArchivedSyncResult( - room_id=room_id, - timeline=batch, - state=state_events_delta, - account_data=account_data, - ) - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks def get_state_after_event(self, event): """ Get the room state after the given event @@ -970,26 +469,6 @@ class SyncHandler(object): for e in sync_config.filter_collection.filter_room_state(state.values()) }) - def check_joined_room(self, sync_config, state_delta): - """ - Check if the user has just joined the given room (so should - be given the full state) - - Args: - sync_config(synapse.handlers.sync.SyncConfig): - state_delta(dict[(str,str), synapse.events.FrozenEvent]): the - difference in state since the last sync - - Returns: - A deferred Tuple (state_delta, limited) - """ - join_event = state_delta.get(( - EventTypes.Member, sync_config.user.to_string()), None) - if join_event is not None: - if join_event.content["membership"] == Membership.JOIN: - return True - return False - @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): @@ -1010,6 +489,551 @@ class SyncHandler(object): # count is whatever it was last time. defer.returnValue(None) + @defer.inlineCallbacks + def generate_sync_result(self, sync_config, since_token=None, full_state=False): + """Generates a sync result. + + Args: + sync_config (SyncConfig) + since_token (StreamToken) + full_state (bool) + + Returns: + Deferred(SyncResult) + """ + + # NB: The now_token gets changed by some of the generate_sync_* methods, + # this is due to some of the underlying streams not supporting the ability + # to query up to a given point. + # Always use the `now_token` in `SyncResultBuilder` + now_token = yield self.event_sources.get_current_token() + + sync_result_builder = SyncResultBuilder( + sync_config, full_state, + since_token=since_token, + now_token=now_token, + ) + + account_data_by_room = yield self._generate_sync_entry_for_account_data( + sync_result_builder + ) + + res = yield self._generate_sync_entry_for_rooms( + sync_result_builder, account_data_by_room + ) + newly_joined_rooms, newly_joined_users = res + + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) + + defer.returnValue(SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + next_batch=sync_result_builder.now_token, + )) + + @defer.inlineCallbacks + def _generate_sync_entry_for_account_data(self, sync_result_builder): + """Generates the account data portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + sync_config = sync_result_builder.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + + if since_token and not sync_result_builder.full_state: + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + user_id, + since_token.account_data_key, + ) + ) + + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + else: + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + + account_data_for_user = sync_config.filter_collection.filter_account_data([ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ]) + + sync_result_builder.account_data = account_data_for_user + + defer.returnValue(account_data_by_room) + + @defer.inlineCallbacks + def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, + newly_joined_users): + """Generates the presence portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + newly_joined_rooms(list): List of rooms that the user has joined + since the last sync (or empty if an initial sync) + newly_joined_users(list): List of users that have joined rooms + since the last sync (or empty if an initial sync) + """ + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + user = sync_result_builder.sync_config.user + + presence_source = self.event_sources.sources["presence"] + + since_token = sync_result_builder.since_token + if since_token and not sync_result_builder.full_state: + presence_key = since_token.presence_key + include_offline = True + else: + presence_key = None + include_offline = False + + presence, presence_key = yield presence_source.get_new_events( + user=user, + from_key=presence_key, + is_guest=sync_config.is_guest, + include_offline=include_offline, + ) + sync_result_builder.now_token = now_token.copy_and_replace( + "presence_key", presence_key + ) + + extra_users_ids = set(newly_joined_users) + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(room_id) + extra_users_ids.update(users) + extra_users_ids.discard(user.to_string()) + + states = yield self.presence_handler.get_states( + extra_users_ids, + as_event=True, + ) + presence.extend(states) + + # Deduplicate the presence entries so that there's at most one per user + presence = {p["content"]["user_id"]: p for p in presence}.values() + + presence = sync_config.filter_collection.filter_presence( + presence + ) + + sync_result_builder.presence = presence + + @defer.inlineCallbacks + def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + """Generates the rooms portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + account_data_by_room(dict): Dictionary of per room account data + + Returns: + Deferred(tuple): Returns a 2-tuple of + `(newly_joined_rooms, newly_joined_users)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token + + ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + "m.ignored_user_list", user_id=user_id, + ) + + if ignored_account_data: + ignored_users = ignored_account_data.get("ignored_users", {}).keys() + else: + ignored_users = frozenset() + + if sync_result_builder.since_token: + res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_updated_tags( + user_id, + sync_result_builder.since_token.account_data_key, + ) + else: + res = yield self._get_all_rooms(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + def handle_room_entries(room_entry): + return self._generate_room_entry( + sync_result_builder, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builder.full_state, + ) + + yield concurrently_execute(handle_room_entries, room_entries, 10) + + sync_result_builder.invited.extend(invited) + + # Now we want to get any newly joined users + newly_joined_users = set() + if sync_result_builder.since_token: + for joined_sync in sync_result_builder.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + newly_joined_users.add(event.state_key) + + defer.returnValue((newly_joined_rooms, newly_joined_users)) + + @defer.inlineCallbacks + def _get_rooms_changed(self, sync_result_builder, ignored_users): + """Gets the the changes that have happened since the last sync. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + assert since_token + + app_service = yield self.store.get_app_service_by_user_id(user_id) + if app_service: + rooms = yield self.store.get_app_service_rooms(app_service) + joined_room_ids = set(r.room_id for r in rooms) + else: + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + + # Get a list of membership change events that have happened. + rooms_changed = yield self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key + ) + + mem_change_events_by_room_id = {} + for event in rooms_changed: + mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) + + newly_joined_rooms = [] + room_entries = [] + invited = [] + for room_id, events in mem_change_events_by_room_id.items(): + non_joins = [e for e in events if e.membership != Membership.JOIN] + has_join = len(non_joins) != len(events) + + # We want to figure out if we joined the room at some point since + # the last sync (even if we have since left). This is to make sure + # we do send down the room, and with full state, where necessary + if room_id in joined_room_ids or has_join: + old_state = yield self.get_state_at(room_id, since_token) + old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: + newly_joined_rooms.append(room_id) + + if room_id in joined_room_ids: + continue + + if not non_joins: + continue + + # Only bother if we're still currently invited + should_invite = non_joins[-1].membership == Membership.INVITE + if should_invite: + if event.sender not in ignored_users: + room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if room_sync: + invited.append(room_sync) + + # Always include leave/ban events. Just take the last one. + # TODO: How do we handle ban -> leave in same batch? + leave_events = [ + e for e in non_joins + if e.membership in (Membership.LEAVE, Membership.BAN) + ] + + if leave_events: + leave_event = leave_events[-1] + leave_stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + leave_token = since_token.copy_and_replace( + "room_key", leave_stream_token + ) + + if since_token and since_token.is_after(leave_token): + continue + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=None, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + )) + + timeline_limit = sync_config.filter_collection.timeline_limit() + + # Get all events for rooms we're currently joined to. + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=joined_room_ids, + from_key=since_token.room_key, + to_key=now_token.room_key, + limit=timeline_limit + 1, + ) + + # We loop through all room ids, even if there are no new events, in case + # there are non room events taht we need to notify about. + for room_id in joined_room_ids: + room_entry = room_to_events.get(room_id, None) + + if room_entry: + events, start_key = room_entry + + prev_batch_token = now_token.copy_and_replace("room_key", start_key) + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=None if room_id in newly_joined_rooms else since_token, + upto_token=prev_batch_token, + )) + else: + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=[], + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=since_token, + )) + + defer.returnValue((room_entries, invited, newly_joined_rooms)) + + @defer.inlineCallbacks + def _get_all_rooms(self, sync_result_builder, ignored_users): + """Returns entries for all rooms for the user. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], [])` + """ + + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + membership_list = ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + ) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, + membership_list=membership_list + ) + + room_entries = [] + invited = [] + + for event in room_list: + if event.membership == Membership.JOIN: + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + )) + elif event.membership == Membership.INVITE: + if event.sender in ignored_users: + continue + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + continue + + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + )) + + defer.returnValue((room_entries, invited, [])) + + @defer.inlineCallbacks + def _generate_room_entry(self, sync_result_builder, ignored_users, + room_builder, ephemeral, tags, account_data, + always_include=False): + """Populates the `joined` and `archived` section of `sync_result_builder` + based on the `room_builder`. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + room_builder(RoomSyncResultBuilder) + ephemeral(list): List of new ephemeral events for room + tags(list): List of *all* tags for room, or None if there has been + no change. + account_data(list): List of new account data for room + always_include(bool): Always include this room in the sync response, + even if empty. + """ + newly_joined = room_builder.newly_joined + full_state = ( + room_builder.full_state + or newly_joined + or sync_result_builder.full_state + ) + events = room_builder.events + + # We want to shortcut out as early as possible. + if not (always_include or account_data or ephemeral or full_state): + if events == [] and tags is None: + return + + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + room_id = room_builder.room_id + since_token = room_builder.since_token + upto_token = room_builder.upto_token + + batch = yield self._load_filtered_recents( + room_id, sync_config, + now_token=upto_token, + since_token=since_token, + recents=events, + newly_joined_room=newly_joined, + ) + + account_data_events = [] + if tags is not None: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data_events + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + + if not (always_include or batch or account_data or ephemeral or full_state): + return + + state = yield self.compute_state_delta( + room_id, batch, sync_config, since_token, now_token, + full_state=full_state + ) + + if room_builder.rtype == "joined": + unread_notifications = {} + room_sync = JoinedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + ephemeral=ephemeral, + account_data=account_data_events, + unread_notifications=unread_notifications, + ) + + if room_sync or always_include: + notifs = yield self.unread_notifs_for_room_id( + room_id, sync_config + ) + + if notifs is not None: + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + sync_result_builder.joined.append(room_sync) + elif room_builder.rtype == "archived": + room_sync = ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + account_data=account_data, + ) + if room_sync or always_include: + sync_result_builder.archived.append(room_sync) + else: + raise Exception("Unrecognized rtype: %r", room_builder.rtype) + def _action_has_highlight(actions): for action in actions: @@ -1057,3 +1081,51 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): (e.type, e.state_key): e for e in evs } + + +class SyncResultBuilder(object): + "Used to help build up a new SyncResult for a user" + def __init__(self, sync_config, full_state, since_token, now_token): + """ + Args: + sync_config(SyncConfig) + full_state(bool): The full_state flag as specified by user + since_token(StreamToken): The token supplied by user, or None. + now_token(StreamToken): The token to sync up to. + """ + self.sync_config = sync_config + self.full_state = full_state + self.since_token = since_token + self.now_token = now_token + + self.presence = [] + self.account_data = [] + self.joined = [] + self.invited = [] + self.archived = [] + + +class RoomSyncResultBuilder(object): + """Stores information needed to create either a `JoinedSyncResult` or + `ArchivedSyncResult`. + """ + def __init__(self, room_id, rtype, events, newly_joined, full_state, + since_token, upto_token): + """ + Args: + room_id(str) + rtype(str): One of `"joined"` or `"archived"` + events(list): List of events to include in the room, (more events + may be added when generating result). + newly_joined(bool): If the user has newly joined the room + full_state(bool): Whether the full state should be sent in result + since_token(StreamToken): Earliest point to return events from, or None + upto_token(StreamToken): Latest point to return events from. + """ + self.room_id = room_id + self.rtype = rtype + self.events = events + self.newly_joined = newly_joined + self.full_state = full_state + self.since_token = since_token + self.upto_token = upto_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d46f05f426..5589296c09 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user")) +RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) class TypingHandler(object): @@ -38,7 +38,7 @@ class TypingHandler(object): self.store = hs.get_datastore() self.server_name = hs.config.server_name self.auth = hs.get_auth() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() self.clock = hs.get_clock() @@ -67,20 +67,23 @@ class TypingHandler(object): @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has started typing in %s", target_user.to_string(), room_id + "%s has started typing in %s", target_user_id, room_id ) until = self.clock.time_msec() + timeout - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) was_present = member in self._member_typing_until @@ -104,25 +107,28 @@ class TypingHandler(object): yield self._push_update( room_id=room_id, - user=target_user, + user_id=target_user_id, typing=True, ) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has stopped typing in %s", target_user.to_string(), room_id + "%s has stopped typing in %s", target_user_id, room_id ) - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) if member in self._member_typing_timer: self.clock.cancel_call_later(self._member_typing_timer[member]) @@ -132,8 +138,9 @@ class TypingHandler(object): @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.is_mine(user): - member = RoomMember(room_id=room_id, user=user) + user_id = user.to_string() + if self.is_mine_id(user_id): + member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks @@ -144,7 +151,7 @@ class TypingHandler(object): yield self._push_update( room_id=member.room_id, - user=member.user, + user_id=member.user_id, typing=False, ) @@ -156,7 +163,7 @@ class TypingHandler(object): del self._member_typing_timer[member] @defer.inlineCallbacks - def _push_update(self, room_id, user, typing): + def _push_update(self, room_id, user_id, typing): domains = yield self.store.get_joined_hosts_for_room(room_id) deferreds = [] @@ -164,7 +171,7 @@ class TypingHandler(object): if domain == self.server_name: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=typing ) else: @@ -173,7 +180,7 @@ class TypingHandler(object): edu_type="m.typing", content={ "room_id": room_id, - "user_id": user.to_string(), + "user_id": user_id, "typing": typing, }, )) @@ -183,23 +190,26 @@ class TypingHandler(object): @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = UserID.from_string(content["user_id"]) + user_id = content["user_id"] + + # Check that the string is a valid user id + UserID.from_string(user_id) domains = yield self.store.get_joined_hosts_for_room(room_id) if self.server_name in domains: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=content["typing"] ) - def _push_update_local(self, room_id, user, typing): + def _push_update_local(self, room_id, user_id, typing): room_set = self._room_typing.setdefault(room_id, set()) if typing: - room_set.add(user) + room_set.add(user_id) else: - room_set.discard(user) + room_set.discard(user_id) self._latest_room_serial += 1 self._room_serials[room_id] = self._latest_room_serial @@ -211,13 +221,14 @@ class TypingHandler(object): def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. + if last_id == current_id: + return [] + rows = [] for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps([ - u.to_string() for u in typing - ], ensure_ascii=False) + typing_bytes = json.dumps(list(typing), ensure_ascii=False) rows.append((serial, room_id, typing_bytes)) rows.sort() return rows @@ -239,7 +250,7 @@ class TypingNotificationEventSource(object): "type": "m.typing", "room_id": room_id, "content": { - "user_ids": [u.to_string() for u in typing], + "user_ids": list(typing), }, } |