diff options
Diffstat (limited to 'synapse/handlers')
36 files changed, 2575 insertions, 2876 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index dca337ec61..c29c78bd65 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -94,14 +94,15 @@ class BaseHandler(object): burst_count = self.hs.config.rc_message.burst_count allowed, time_allowed = self.ratelimiter.can_do_action( - user_id, time_now, + user_id, + time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) @defer.inlineCallbacks @@ -139,7 +140,7 @@ class BaseHandler(object): if member_event.content["membership"] not in { Membership.JOIN, - Membership.INVITE + Membership.INVITE, }: continue @@ -156,8 +157,7 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True) + requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 7fa5d44d29..e62e6cab77 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -20,7 +20,7 @@ class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks @@ -34,29 +34,22 @@ class AccountDataEventSource(object): tags = yield self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): - results.append({ - "type": "m.tag", - "content": {"tags": room_tags}, - "room_id": room_id, - }) + results.append( + {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + ) account_data, room_account_data = ( yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) ) for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - }) + results.append({"type": account_data_type, "content": content}) for room_id, account_data in room_account_data.items(): for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - "room_id": room_id, - }) + results.append( + {"type": account_data_type, "content": content, "room_id": room_id} + ) defer.returnValue((results, current_stream_id)) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5e0b92eb1c..0719da3ab7 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -49,12 +49,10 @@ class AccountValidityHandler(object): app_name = self.hs.config.email_app_name self._subject = self._account_validity.renew_email_subject % { - "app": app_name, + "app": app_name } - self._from_string = self.hs.config.email_notif_from % { - "app": app_name, - } + self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. self._subject = self._account_validity.renew_email_subject @@ -69,10 +67,7 @@ class AccountValidityHandler(object): ) # Check the renewal emails to send and send them every 30min. - self.clock.looping_call( - self.send_renewal_emails, - 30 * 60 * 1000, - ) + self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000) @defer.inlineCallbacks def send_renewal_emails(self): @@ -86,8 +81,7 @@ class AccountValidityHandler(object): if expiring_users: for user in expiring_users: yield self._send_renewal_email( - user_id=user["user_id"], - expiration_ts=user["expiration_ts_ms"], + user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) @defer.inlineCallbacks @@ -146,32 +140,33 @@ class AccountValidityHandler(object): for address in addresses: raw_to = email.utils.parseaddr(address)[1] - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = self._subject - multipart_msg['From'] = self._from_string - multipart_msg['To'] = address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = self._subject + multipart_msg["From"] = self._from_string + multipart_msg["To"] = address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) - - yield self.store.set_renewal_mail_status( - user_id=user_id, - email_sent=True, - ) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + self._raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) + + yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) @defer.inlineCallbacks def _get_email_addresses_for_user(self, user_id): @@ -248,9 +243,7 @@ class AccountValidityHandler(object): expiration_ts = self.clock.time_msec() + self._account_validity.period yield self.store.set_account_validity_for_user( - user_id=user_id, - expiration_ts=expiration_ts, - email_sent=email_sent, + user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) defer.returnValue(expiration_ts) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 813777bf18..fbef2f3d38 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -15,14 +15,9 @@ import logging -import attr -from zope.interface import implementer - import twisted import twisted.internet.error from twisted.internet import defer -from twisted.python.filepath import FilePath -from twisted.python.url import URL from twisted.web import server, static from twisted.web.resource import Resource @@ -30,27 +25,6 @@ from synapse.app import check_bind_error logger = logging.getLogger(__name__) -try: - from txacme.interfaces import ICertificateStore - - @attr.s - @implementer(ICertificateStore) - class ErsatzStore(object): - """ - A store that only stores in memory. - """ - - certs = attr.ib(default=attr.Factory(dict)) - - def store(self, server_name, pem_objects): - self.certs[server_name] = [o.as_bytes() for o in pem_objects] - return defer.succeed(None) - - -except ImportError: - # txacme is missing - pass - class AcmeHandler(object): def __init__(self, hs): @@ -60,6 +34,7 @@ class AcmeHandler(object): @defer.inlineCallbacks def start_listening(self): + from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug # from eliot import add_destinations @@ -67,50 +42,27 @@ class AcmeHandler(object): # # add_destinations(TwistedDestination()) - from txacme.challenges import HTTP01Responder - from txacme.service import AcmeIssuingService - from txacme.endpoint import load_or_create_client_key - from txacme.client import Client - from josepy.jwa import RS256 - - self._store = ErsatzStore() - responder = HTTP01Responder() - - self._issuer = AcmeIssuingService( - cert_store=self._store, - client_creator=( - lambda: Client.from_url( - reactor=self.reactor, - url=URL.from_text(self.hs.config.acme_url), - key=load_or_create_client_key( - FilePath(self.hs.config.config_dir_path) - ), - alg=RS256, - ) - ), - clock=self.reactor, - responders=[responder], + well_known = Resource() + + self._issuer = acme_issuing_service.create_issuing_service( + self.reactor, + acme_url=self.hs.config.acme_url, + account_key_file=self.hs.config.acme_account_key_file, + well_known_resource=well_known, ) - well_known = Resource() - well_known.putChild(b'acme-challenge', responder.resource) responder_resource = Resource() - responder_resource.putChild(b'.well-known', well_known) - responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) - + responder_resource.putChild(b".well-known", well_known) + responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain")) srv = server.Site(responder_resource) bind_addresses = self.hs.config.acme_bind_addresses for host in bind_addresses: logger.info( - "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port, + "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP( - self.hs.config.acme_port, - srv, - interface=host, - ) + self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) @@ -132,7 +84,7 @@ class AcmeHandler(object): logger.exception("Fail!") raise logger.warning("Reprovisioned %s, saving.", self._acme_domain) - cert_chain = self._store.certs[self._acme_domain] + cert_chain = self._issuer.cert_store.certs[self._acme_domain] try: with open(self.hs.config.tls_private_key_file, "wb") as private_key_file: diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py new file mode 100644 index 0000000000..e1d4224e74 --- /dev/null +++ b/synapse/handlers/acme_issuing_service.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility function to create an ACME issuing service. + +This file contains the unconditional imports on the acme and cryptography bits that we +only need (and may only have available) if we are doing ACME, so is designed to be +imported conditionally. +""" +import logging + +import attr +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from josepy import JWKRSA +from josepy.jwa import RS256 +from txacme.challenges import HTTP01Responder +from txacme.client import Client +from txacme.interfaces import ICertificateStore +from txacme.service import AcmeIssuingService +from txacme.util import generate_private_key +from zope.interface import implementer + +from twisted.internet import defer +from twisted.python.filepath import FilePath +from twisted.python.url import URL + +logger = logging.getLogger(__name__) + + +def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): + """Create an ACME issuing service, and attach it to a web Resource + + Args: + reactor: twisted reactor + acme_url (str): URL to use to request certificates + account_key_file (str): where to store the account key + well_known_resource (twisted.web.IResource): web resource for .well-known. + we will attach a child resource for "acme-challenge". + + Returns: + AcmeIssuingService + """ + responder = HTTP01Responder() + + well_known_resource.putChild(b"acme-challenge", responder.resource) + + store = ErsatzStore() + + return AcmeIssuingService( + cert_store=store, + client_creator=( + lambda: Client.from_url( + reactor=reactor, + url=URL.from_text(acme_url), + key=load_or_create_client_key(account_key_file), + alg=RS256, + ) + ), + clock=reactor, + responders=[responder], + ) + + +@attr.s +@implementer(ICertificateStore) +class ErsatzStore(object): + """ + A store that only stores in memory. + """ + + certs = attr.ib(default=attr.Factory(dict)) + + def store(self, server_name, pem_objects): + self.certs[server_name] = [o.as_bytes() for o in pem_objects] + return defer.succeed(None) + + +def load_or_create_client_key(key_file): + """Load the ACME account key from a file, creating it if it does not exist. + + Args: + key_file (str): name of the file to use as the account key + """ + # this is based on txacme.endpoint.load_or_create_client_key, but doesn't + # hardcode the 'client.key' filename + acme_key_file = FilePath(key_file) + if acme_key_file.exists(): + logger.info("Loading ACME account key from '%s'", acme_key_file) + key = serialization.load_pem_private_key( + acme_key_file.getContent(), password=None, backend=default_backend() + ) + else: + logger.info("Saving new ACME account key to '%s'", acme_key_file) + key = generate_private_key("rsa") + acme_key_file.setContent( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + return JWKRSA(key=key) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5d629126fc..941ebfa107 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): - def __init__(self, hs): super(AdminHandler, self).__init__(hs) @@ -33,23 +32,17 @@ class AdminHandler(BaseHandler): sessions = yield self.store.get_user_ip_and_agents(user) for session in sessions: - connections.append({ - "ip": session["ip"], - "last_seen": session["last_seen"], - "user_agent": session["user_agent"], - }) + connections.append( + { + "ip": session["ip"], + "last_seen": session["last_seen"], + "user_agent": session["user_agent"], + } + ) ret = { "user_id": user.to_string(), - "devices": { - "": { - "sessions": [ - { - "connections": connections, - } - ] - }, - }, + "devices": {"": {"sessions": [{"connections": connections}]}}, } defer.returnValue(ret) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 17eedf4dbf..5cc89d43f6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id @@ -101,9 +100,10 @@ class ApplicationServicesHandler(object): yield self._check_user_exists(event.state_key) if not self.started_scheduler: + def start_scheduler(): return self.scheduler.start().addErrback( - log_failure, "Application Services Failure", + log_failure, "Application Services Failure" ) run_as_background_process("as_scheduler", start_scheduler) @@ -118,10 +118,15 @@ class ApplicationServicesHandler(object): for event in events: yield handle_event(event) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) + ) yield self.store.set_appservice_last_pos(upper_bound) @@ -129,20 +134,23 @@ class ApplicationServicesHandler(object): ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( - "appservice_sender").set(upper_bound) + "appservice_sender" + ).set(upper_bound) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "appservice_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("appservice_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("appservice_sender").inc() synapse.metrics.event_processing_lag.labels( - "appservice_sender").set(now - ts) + "appservice_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "appservice_sender").set(ts) + "appservice_sender" + ).set(ts) finally: self.is_processing = False @@ -155,13 +163,9 @@ class ApplicationServicesHandler(object): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user( - user_id=user_id - ) + user_query_services = yield self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user( - user_service, user_id - ) + is_known_user = yield self.appservice_api.query_user(user_service, user_id) if is_known_user: defer.returnValue(True) defer.returnValue(False) @@ -179,9 +183,7 @@ class ApplicationServicesHandler(object): room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if ( - s.is_interested_in_alias(room_alias_str) - ) + s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( @@ -189,22 +191,24 @@ class ApplicationServicesHandler(object): ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) defer.returnValue(result) @defer.inlineCallbacks def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable(defer.DeferredList([ - run_in_background( - self.appservice_api.query_3pe, - service, kind, protocol, fields, + results = yield make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.query_3pe, service, kind, protocol, fields + ) + for service in services + ], + consumeErrors=True, ) - for service in services - ], consumeErrors=True)) + ) ret = [] for (success, result) in results: @@ -276,18 +280,12 @@ class ApplicationServicesHandler(object): def _get_services_for_user(self, user_id): services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - s.is_interested_in_user(user_id) - ) - ] + interested_list = [s for s in services if (s.is_interested_in_user(user_id))] return defer.succeed(interested_list) def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() - interested_list = [ - s for s in services if s.is_interested_in_protocol(protocol) - ] + interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] return defer.succeed(interested_list) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a0cf37a9f9..c8c1ed3246 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -134,13 +134,9 @@ class AuthHandler(BaseHandler): """ # build a list of supported flows - flows = [ - [login_type] for login_type in self._supported_login_types - ] + flows = [[login_type] for login_type in self._supported_login_types] - result, params, _ = yield self.check_auth( - flows, request_body, clientip, - ) + result, params, _ = yield self.check_auth(flows, request_body, clientip) # find the completed login type for login_type in self._supported_login_types: @@ -151,9 +147,7 @@ class AuthHandler(BaseHandler): break else: # this can't happen - raise Exception( - "check_auth returned True but no successful login type", - ) + raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token if user_id != requester.user.to_string(): @@ -215,11 +209,11 @@ class AuthHandler(BaseHandler): authdict = None sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - del clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + del clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] session = self._get_session_info(sid) if len(clientdict) > 0: @@ -232,27 +226,27 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - session['clientdict'] = clientdict + session["clientdict"] = clientdict self._save_session(session) - elif 'clientdict' in session: - clientdict = session['clientdict'] + elif "clientdict" in session: + clientdict = session["clientdict"] if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session), + self._auth_dict_for_flows(flows, session) ) - if 'creds' not in session: - session['creds'] = {} - creds = session['creds'] + if "creds" not in session: + session["creds"] = {} + creds = session["creds"] # check auth type currently being presented errordict = {} - if 'type' in authdict: - login_type = authdict['type'] + if "type" in authdict: + login_type = authdict["type"] try: result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet, + authdict, clientip, password_servlet=password_servlet ) if result: creds[login_type] = result @@ -281,16 +275,15 @@ class AuthHandler(BaseHandler): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, list(clientdict) + creds, + list(clientdict), ) - defer.returnValue((creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session["id"])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = list(creds) + ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError( - ret, - ) + raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -300,15 +293,13 @@ class AuthHandler(BaseHandler): """ if stagetype not in self.checkers: raise LoginError(400, "", Codes.MISSING_PARAM) - if 'session' not in authdict: + if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info( - authdict['session'] - ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + sess = self._get_session_info(authdict["session"]) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] result = yield self.checkers[stagetype](authdict, clientip) if result: @@ -329,10 +320,10 @@ class AuthHandler(BaseHandler): not send a session ID, returns None. """ sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] return sid def set_session_data(self, session_id, key, value): @@ -347,7 +338,7 @@ class AuthHandler(BaseHandler): value (any): The data to store """ sess = self._get_session_info(session_id) - sess.setdefault('serverdict', {})[key] = value + sess.setdefault("serverdict", {})[key] = value self._save_session(sess) def get_session_data(self, session_id, key, default=None): @@ -360,7 +351,7 @@ class AuthHandler(BaseHandler): default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault('serverdict', {}).get(key, default) + return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks def _check_auth_dict(self, authdict, clientip, password_servlet=False): @@ -378,15 +369,13 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - login_type = authdict['type'] + login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: # XXX: Temporary workaround for having Synapse handle password resets # See AuthHandler.check_auth for further details res = yield checker( - authdict, - clientip=clientip, - password_servlet=password_servlet, + authdict, clientip=clientip, password_servlet=password_servlet ) defer.returnValue(res) @@ -408,13 +397,11 @@ class AuthHandler(BaseHandler): # Client tried to provide captcha but didn't give the parameter: # bad request. raise LoginError( - 400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED ) logger.info( - "Submitting recaptcha response %s with remoteip %s", - user_response, clientip + "Submitting recaptcha response %s with remoteip %s", user_response, clientip ) # TODO: get this from the homeserver rather than creating a new one for @@ -424,34 +411,34 @@ class AuthHandler(BaseHandler): resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ - 'secret': self.hs.config.recaptcha_private_key, - 'response': user_response, - 'remoteip': clientip, - } + "secret": self.hs.config.recaptcha_private_key, + "response": user_response, + "remoteip": clientip, + }, ) except PartialDownloadError as pde: # Twisted is silly data = pde.response resp_body = json.loads(data) - if 'success' in resp_body: + if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly # intend the CAPTCHA to be presented by whatever client the # user is using, we just care that they have completed a CAPTCHA. logger.info( "%s reCAPTCHA from hostname %s", - "Successful" if resp_body['success'] else "Failed", - resp_body.get('hostname') + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), ) - if resp_body['success']: + if resp_body["success"]: defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid('email', authdict, **kwargs) + return self._check_threepid("email", authdict, **kwargs) def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid('msisdn', authdict) + return self._check_threepid("msisdn", authdict) def _check_dummy_auth(self, authdict, **kwargs): return defer.succeed(True) @@ -461,10 +448,10 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if 'threepid_creds' not in authdict: + if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - threepid_creds = authdict['threepid_creds'] + threepid_creds = authdict["threepid_creds"] identity_handler = self.hs.get_handlers().identity_handler @@ -482,31 +469,36 @@ class AuthHandler(BaseHandler): validated=True, ) - threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } if row else None + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) if row: # Valid threepid returned, delete from the db yield self.store.delete_threepid_session(threepid_creds["sid"]) else: - raise SynapseError(400, "Password resets are not enabled on this homeserver") + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - if threepid['medium'] != medium: + if threepid["medium"] != medium: raise LoginError( 401, - "Expecting threepid of type '%s', got '%s'" % ( - medium, threepid['medium'], - ), - errcode=Codes.UNAUTHORIZED + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, ) - threepid['threepid_creds'] = authdict['threepid_creds'] + threepid["threepid_creds"] = authdict["threepid_creds"] defer.returnValue(threepid) @@ -520,13 +512,14 @@ class AuthHandler(BaseHandler): "version": self.hs.config.user_consent_version, "en": { "name": self.hs.config.user_consent_policy_name, - "url": "%s_matrix/consent?v=%s" % ( + "url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), }, - }, - }, + } + } } def _auth_dict_for_flows(self, flows, session): @@ -547,9 +540,9 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session['id'], + "session": session["id"], "flows": [{"stages": f} for f in public_flows], - "params": params + "params": params, } def _get_session_info(self, session_id): @@ -560,9 +553,7 @@ class AuthHandler(BaseHandler): # create a new session while session_id is None or session_id in self.sessions: session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - } + self.sessions[session_id] = {"id": session_id} return self.sessions[session_id] @@ -652,7 +643,8 @@ class AuthHandler(BaseHandler): logger.warn( "Attempted to login as %s but it matches more than one user " "inexactly: %r", - user_id, user_infos.keys() + user_id, + user_infos.keys(), ) defer.returnValue(result) @@ -690,12 +682,10 @@ class AuthHandler(BaseHandler): user is too high too proceed. """ - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: - qualified_user_id = UserID( - username, self.hs.hostname - ).to_string() + qualified_user_id = UserID(username, self.hs.hostname).to_string() self.ratelimit_login_per_account(qualified_user_id) @@ -713,17 +703,15 @@ class AuthHandler(BaseHandler): raise SynapseError(400, "Missing parameter: password") for provider in self.password_providers: - if (hasattr(provider, "check_password") - and login_type == LoginType.PASSWORD): + if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password( - qualified_user_id, password, - ) + is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: defer.returnValue((qualified_user_id, None)) - if (not hasattr(provider, "get_supported_login_types") - or not hasattr(provider, "check_auth")): + if not hasattr(provider, "get_supported_login_types") or not hasattr( + provider, "check_auth" + ): # this password provider doesn't understand custom login types continue @@ -744,25 +732,22 @@ class AuthHandler(BaseHandler): login_dict[f] = login_submission[f] if missing_fields: raise SynapseError( - 400, "Missing parameters for login type %s: %s" % ( - login_type, - missing_fields, - ), + 400, + "Missing parameters for login type %s: %s" + % (login_type, missing_fields), ) - result = yield provider.check_auth( - username, login_type, login_dict, - ) + result = yield provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) defer.returnValue(result) - if login_type == LoginType.PASSWORD: + if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True canonical_user_id = yield self._check_local_password( - qualified_user_id, password, + qualified_user_id, password ) if canonical_user_id: @@ -773,7 +758,8 @@ class AuthHandler(BaseHandler): # unknown username or invalid password. self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), + qualified_user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, @@ -781,10 +767,7 @@ class AuthHandler(BaseHandler): # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. - raise LoginError( - 403, "Invalid password", - errcode=Codes.FORBIDDEN - ) + raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks def check_password_provider_3pid(self, medium, address, password): @@ -810,9 +793,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth( - medium, address, password, - ) + result = yield provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -853,8 +834,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, - device_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) @defer.inlineCallbacks @@ -896,12 +876,13 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: yield self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"], ) + str(user_info["user"]), (user_info["token_id"],) ) @defer.inlineCallbacks - def delete_access_tokens_for_user(self, user_id, except_token_id=None, - device_id=None): + def delete_access_tokens_for_user( + self, user_id, except_token_id=None, device_id=None + ): """Invalidate access tokens belonging to a user Args: @@ -915,7 +896,7 @@ class AuthHandler(BaseHandler): Deferred """ tokens_and_devices = yield self.store.user_delete_access_tokens( - user_id, except_token_id=except_token_id, device_id=device_id, + user_id, except_token_id=except_token_id, device_id=device_id ) # see if any of our auth providers want to know about this @@ -923,14 +904,12 @@ class AuthHandler(BaseHandler): if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( - user_id=user_id, - device_id=device_id, - access_token=token, + user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens yield self.hs.get_pusherpool().remove_pushers_by_access_token( - user_id, (token_id for _, token_id, _ in tokens_and_devices), + user_id, (token_id for _, token_id, _ in tokens_and_devices) ) @defer.inlineCallbacks @@ -944,12 +923,11 @@ class AuthHandler(BaseHandler): # of specific types of threepid (and fixes the fact that checking # for the presence of an email address during password reset was # case sensitive). - if medium == 'email': + if medium == "email": address = address.lower() yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() + user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) @defer.inlineCallbacks @@ -973,22 +951,15 @@ class AuthHandler(BaseHandler): """ # 'Canonicalise' email addresses as per above - if medium == 'email': + if medium == "email": address = address.lower() identity_handler = self.hs.get_handlers().identity_handler result = yield identity_handler.try_unbind_threepid( - user_id, - { - 'medium': medium, - 'address': address, - 'id_server': id_server, - }, + user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid( - user_id, medium, address, - ) + yield self.store.user_delete_threepid(user_id, medium, address) defer.returnValue(result) def _save_session(self, session): @@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler): Returns: Deferred(unicode): Hashed password. """ + def _do_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), - ).decode('ascii') + ).decode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) @@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler): Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ + def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + stored_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): - stored_hash = stored_hash.encode('ascii') + stored_hash = stored_hash.encode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: @@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler): for this user is too high too proceed. """ self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, update=True, @@ -1083,9 +1058,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) + macaroon.add_first_party_caveat( + "nonce = %s" % (stringutils.random_string_with_symbols(16),) + ) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() @@ -1116,7 +1091,8 @@ class MacaroonGenerator(object): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7378b56c1d..e8f9da6098 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -78,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler): result = yield self._identity_handler.try_unbind_threepid( user_id, { - 'medium': threepid['medium'], - 'address': threepid['address'], - 'id_server': id_server, + "medium": threepid["medium"], + "address": threepid["address"], + "id_server": id_server, }, ) identity_server_supports_unbinding &= result @@ -89,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler): logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") yield self.store.user_delete_threepid( - user_id, threepid['medium'], threepid['address'], + user_id, threepid["medium"], threepid["address"] ) # delete any devices belonging to the user, which will also @@ -183,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler): except Exception: logger.exception( "Failed to part user %r from room %r: ignoring and continuing", - user_id, room_id, + user_id, + room_id, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b398848079..99e8413092 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler): device_map = yield self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id=None - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler): device = yield self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id, - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) defer.returnValue(device) @@ -105,22 +101,24 @@ class DeviceWorkerHandler(BaseHandler): room_ids = yield self.store.get_rooms_for_user(user_id) - # First we check if any devices have changed - changed = yield self.store.get_user_whose_devices_changed( - from_token.device_list_key + # First we check if any devices have changed for users that we share + # rooms with. + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + changed = yield self.store.get_users_whose_devices_changed( + from_token.device_list_key, users_who_share_room ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) member_events = yield self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_room_key, + user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key - ).stream + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream possibly_changed = set(changed) possibly_left = set() @@ -194,10 +192,6 @@ class DeviceWorkerHandler(BaseHandler): break if possibly_changed or possibly_left: - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id - ) - # Take the intersection of the users whose devices may have changed # and those that actually still share a room with the user possibly_joined = possibly_changed & users_who_share_room @@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = [] possibly_left = [] - defer.returnValue({ - "changed": list(possibly_joined), - "left": list(possibly_left), - }) + defer.returnValue( + {"changed": list(possibly_joined), "left": list(possibly_left)} + ) class DeviceHandler(DeviceWorkerHandler): @@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self._edu_updater.incoming_device_list_update, + "m.device_list_update", self._edu_updater.incoming_device_list_update ) federation_registry.register_query_handler( - "user_devices", self.on_federation_query_user_devices, + "user_devices", self.on_federation_query_user_devices ) hs.get_distributor().observe("user_left_room", self.user_left_room) @defer.inlineCallbacks - def check_device_registered(self, user_id, device_id, - initial_device_display_name=None): + def check_device_registered( + self, user_id, device_id, initial_device_display_name=None + ): """ If the given device has not been registered, register it with the supplied display name. @@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler): raise yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=device_id - ) + yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) yield self.notify_device_update(user_id, [device_id]) @@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): # considered as part of a critical path. for device_id in device_ids: yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler): try: yield self.store.update_device( - user_id, - device_id, - new_display_name=content.get("display_name") + user_id, device_id, new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: @@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler): for device_id in device_ids: logger.debug( - "Notifying about update %r/%r, ID: %r", user_id, device_id, - position, + "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) + yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) if hosts: - logger.info("Sending device list update notif for %r to: %r", user_id, hosts) + logger.info( + "Sending device list update notif for %r to: %r", user_id, hosts + ) for host in hosts: self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + defer.returnValue( + {"user_id": user_id, "stream_id": stream_id, "devices": devices} + ) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) class DeviceListEduUpdater(object): @@ -481,13 +465,15 @@ class DeviceListEduUpdater(object): device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) - prev_ids = [str(p) for p in prev_ids] # They may come as ints + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? logger.warning( "Got device list update edu for %r/%r from %r", - user_id, device_id, origin, + user_id, + device_id, + origin, ) return @@ -497,13 +483,12 @@ class DeviceListEduUpdater(object): # probably won't get any further updates. logger.warning( "Got device list update edu for %r/%r, but don't share a room", - user_id, device_id, + user_id, + device_id, ) return - logger.debug( - "Received device list update for %r/%r", user_id, device_id, - ) + logger.debug("Received device list update for %r/%r", user_id, device_id) self._pending_updates.setdefault(user_id, []).append( (device_id, stream_id, prev_ids, edu_content) @@ -525,7 +510,10 @@ class DeviceListEduUpdater(object): for device_id, stream_id, prev_ids, content in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", - user_id, device_id, stream_id, prev_ids, + user_id, + device_id, + stream_id, + prev_ids, ) # Given a list of updates we check if we need to resync. This @@ -540,13 +528,13 @@ class DeviceListEduUpdater(object): try: result = yield self.federation.query_user_devices(origin, user_id) except ( - NotRetryingDestination, RequestSendFailed, HttpResponseException, + NotRetryingDestination, + RequestSendFailed, + HttpResponseException, ): # TODO: Remember that we are now out of sync and try again # later - logger.warn( - "Failed to handle device list update for %s", user_id, - ) + logger.warn("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -582,18 +570,21 @@ class DeviceListEduUpdater(object): if len(devices) > 1000: logger.warn( "Ignoring device list snapshot for %s as it has >1K devs (%d)", - user_id, len(devices) + user_id, + len(devices), ) devices = [] for device in devices: logger.debug( "Handling resync update %r/%r, ID: %r", - user_id, device["device_id"], stream_id, + user_id, + device["device_id"], + stream_id, ) yield self.store.update_remote_device_list_cache( - user_id, devices, stream_id, + user_id, devices, stream_id ) device_ids = [device["device_id"] for device in devices] yield self.device_handler.notify_device_update(user_id, device_ids) @@ -606,7 +597,7 @@ class DeviceListEduUpdater(object): # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + user_id, device_id, content, stream_id ) yield self.device_handler.notify_device_update( @@ -624,14 +615,9 @@ class DeviceListEduUpdater(object): """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) + extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) - logger.debug( - "Current extremity for %r: %r", - user_id, extremity, - ) + logger.debug("Current extremity for %r: %r", user_id, extremity) stream_id_in_updates = set() # stream_ids in updates list for _, stream_id, prev_ids, _ in updates: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 2e2e5261de..e1ebb6346c 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -25,7 +25,6 @@ logger = logging.getLogger(__name__) class DeviceMessageHandler(object): - def __init__(self, hs): """ Args: @@ -47,15 +46,15 @@ class DeviceMessageHandler(object): if origin != get_domain_from_id(sender_user_id): logger.warn( "Dropping device message from %r with spoofed sender %r", - origin, sender_user_id + origin, + sender_user_id, ) message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") messages_by_device = { diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a12f9508d8..42d5b3db30 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,7 +36,6 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): - def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) @@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler): raise SynapseError(400, "Failed to get server list") yield self.store.create_room_alias_association( - room_alias, - room_id, - servers, - creator=creator, + room_alias, room_id, servers, creator=creator ) @defer.inlineCallbacks - def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True, check_membership=True): + def create_association( + self, + requester, + room_alias, + room_id, + servers=None, + send_event=True, + check_membership=True, + ): """Attempt to create a new alias Args: @@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler): if service: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE + 400, + "This application service has not reserved" " this kind of alias.", + errcode=Codes.EXCLUSIVE, ) else: if self.require_membership and check_membership: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise AuthError( - 403, "This user is not permitted to create this alias", - ) + raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string(), + user_id, room_id, room_alias.to_string() ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to create alias", - ) + raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_create = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) yield self._create_association(room_alias, room_id, servers, creator=user_id) if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) @defer.inlineCallbacks def delete_association(self, requester, room_alias, send_event=True): @@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler): raise if not can_delete: - raise AuthError( - 403, "You don't have permission to delete the alias.", - ) + raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) room_id = yield self._delete_association(room_alias) try: if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) yield self._update_canonical_alias( - requester, - requester.user.to_string(), - room_id, - room_alias, + requester, requester.user.to_string(), room_id, room_alias ) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 400, "This application service has not reserved this kind of alias", - errcode=Codes.EXCLUSIVE + errcode=Codes.EXCLUSIVE, ) yield self._delete_association(room_alias) @@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler): def get_association(self, room_alias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler): result = yield self.federation.make_query( destination=room_alias.domain, query_type="directory", - args={ - "room_alias": room_alias.to_string(), - }, + args={"room_alias": room_alias.to_string()}, retry_on_dns_fail=False, ignore_backoff=True, ) @@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 404, "Room alias %s not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) users = yield self.state.get_current_users_in_room(room_id) @@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler): # If this server is in the list of servers, return it first. if self.server_name in servers: - servers = ( - [self.server_name] + - [s for s in servers if s != self.server_name] - ) + servers = [self.server_name] + [s for s in servers if s != self.server_name] else: servers = list(servers) - defer.returnValue({ - "room_id": room_id, - "servers": servers, - }) + defer.returnValue({"room_id": room_id, "servers": servers}) return @defer.inlineCallbacks def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): - raise SynapseError( - 400, "Room Alias is not hosted on this Home Server" - ) + raise SynapseError(400, "Room Alias is not hosted on this Home Server") - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result is not None: - defer.returnValue({ - "room_id": result.room_id, - "servers": result.servers, - }) + defer.returnValue({"room_id": result.room_id, "servers": result.servers}) else: raise SynapseError( 404, "Room alias %r not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) @defer.inlineCallbacks @@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler): "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks @@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler): "sender": user_id, "content": {}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler @@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler): if not self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) if requester.is_guest: @@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler): if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) room = yield self.store.get_room(room_id) @@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler): room_aliases.append(canonical_alias) if not self.config.is_publishing_room_allowed( - user_id, room_id, room_aliases, + user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to publish room", - ) + raise SynapseError(403, "Not allowed to publish room") yield self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks - def edit_published_appservice_room_list(self, appservice_id, network_id, - room_id, visibility): + def edit_published_appservice_room_list( + self, appservice_id, network_id, room_id, visibility + ): """Add or remove a room from the appservice/network specific public room list. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9dc46aa15f..807900fe52 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,9 +99,7 @@ class E2eKeysHandler(object): query_list.append((user_id, None)) user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache( - query_list - ) + yield self.store.get_user_devices_from_cache(query_list) ) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) @@ -126,9 +124,7 @@ class E2eKeysHandler(object): destination_query = remote_queries_not_in_cache[destination] try: remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout + destination, {"device_keys": destination_query}, timeout=timeout ) for user_id, keys in remote_result["device_keys"].items(): @@ -138,14 +134,17 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(do_remote_query, destination) + for destination in remote_queries_not_in_cache + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - "device_keys": results, "failures": failures, - }) + defer.returnValue({"device_keys": results, "failures": failures}) @defer.inlineCallbacks def query_local_devices(self, query): @@ -165,8 +164,7 @@ class E2eKeysHandler(object): for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") if not device_ids: @@ -231,9 +229,7 @@ class E2eKeysHandler(object): device_keys = remote_queries[destination] try: remote_result = yield self.federation.claim_client_keys( - destination, - {"one_time_keys": device_keys}, - timeout=timeout + destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: @@ -241,25 +237,29 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(claim_client_keys, destination) - for destination in remote_queries - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(claim_client_keys, destination) + for destination in remote_queries + ], + consumeErrors=True, + ) + ) logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) - defer.returnValue({ - "one_time_keys": json_result, - "failures": failures - }) + defer.returnValue({"one_time_keys": json_result, "failures": failures}) @defer.inlineCallbacks def upload_keys_for_user(self, user_id, device_id, keys): @@ -270,11 +270,13 @@ class E2eKeysHandler(object): if device_keys: logger.info( "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now + device_id, + user_id, + time_now, ) # TODO: Sign the JSON with the server key changed = yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, device_keys, + user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed @@ -283,7 +285,7 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: yield self._upload_one_time_keys_for_user( - user_id, device_id, time_now, one_time_keys, + user_id, device_id, time_now, one_time_keys ) # the device should have been registered already, but it may have been @@ -298,20 +300,22 @@ class E2eKeysHandler(object): defer.returnValue({"one_time_key_counts": result}) @defer.inlineCallbacks - def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, - one_time_keys): + def _upload_one_time_keys_for_user( + self, user_id, device_id, time_now, one_time_keys + ): logger.info( "Adding one_time_keys %r for device %r for user %r at %d", - one_time_keys.keys(), device_id, user_id, time_now, + one_time_keys.keys(), + device_id, + user_id, + time_now, ) # make a list of (alg, id, key) tuples key_list = [] for key_id, key_obj in one_time_keys.items(): algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, key_obj - )) + key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. existing_key_map = yield self.store.get_e2e_one_time_keys( @@ -325,42 +329,35 @@ class E2eKeysHandler(object): if not _one_time_keys_match(ex_json, key): raise SynapseError( 400, - ("One time key %s:%s already exists. " - "Old key: %s; new key: %r") % - (algorithm, key_id, ex_json, key) + ( + "One time key %s:%s already exists. " + "Old key: %s; new key: %r" + ) + % (algorithm, key_id, ex_json, key), ) else: - new_keys.append(( - algorithm, key_id, encode_canonical_json(key).decode('ascii'))) + new_keys.append( + (algorithm, key_id, encode_canonical_json(key).decode("ascii")) + ) - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, new_keys - ) + yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) def _exception_to_failure(e): if isinstance(e, CodeMessageException): - return { - "status": e.code, "message": str(e), - } + return {"status": e.code, "message": str(e)} if isinstance(e, NotRetryingDestination): - return { - "status": 503, "message": "Not ready for retry", - } + return {"status": 503, "message": "Not ready for retry"} if isinstance(e, FederationDeniedError): - return { - "status": 403, "message": "Federation Denied", - } + return {"status": 403, "message": "Federation Denied"} # include ConnectionRefused and other errors # # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. - return { - "status": 503, "message": str(e), - } + return {"status": 503, "message": str(e)} def _one_time_keys_match(old_key_json, new_key): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 7bc174070e..ebd807bca6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object): else: raise - if version_info['version'] != version: + if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: version_info = yield self.store.get_e2e_room_keys_version_info( - user_id, version, + user_id, version ) # if we get this far, the version must exist - raise RoomKeysVersionError(current_version=version_info['version']) + raise RoomKeysVersionError(current_version=version_info["version"]) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object): # go through the room_keys. # XXX: this should/could be done concurrently, given we're in a lock. - for room_id, room in iteritems(room_keys['rooms']): - for session_id, session in iteritems(room['sessions']): + for room_id, room in iteritems(room_keys["rooms"]): + for session_id, session in iteritems(room["sessions"]): yield self._upload_room_key( user_id, version, room_id, session_id, session ) @@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object): # spelt out with if/elifs rather than nested boolean expressions # purely for legibility. - if room_key['is_verified'] and not current_room_key['is_verified']: + if room_key["is_verified"] and not current_room_key["is_verified"]: return True elif ( - room_key['first_message_index'] < - current_room_key['first_message_index'] + room_key["first_message_index"] + < current_room_key["first_message_index"] ): return True - elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + elif room_key["forwarded_count"] < current_room_key["forwarded_count"]: return True else: return False @@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object): A deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError( - 400, - "Missing version in body", - Codes.MISSING_PARAM - ) + raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) if version_info["version"] != version: raise SynapseError( - 400, - "Version in body does not match", - Codes.INVALID_PARAM + 400, "Version in body does not match", Codes.INVALID_PARAM ) with (yield self._upload_linearizer.queue(user_id)): try: @@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object): else: raise if old_info["algorithm"] != version_info["algorithm"]: - raise SynapseError( - 400, - "Algorithm does not match", - Codes.INVALID_PARAM - ) + raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version(user_id, version, version_info) + yield self.store.update_e2e_room_keys_version( + user_id, version, version_info + ) defer.returnValue({}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index eb525070cf..5836d3c639 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -31,7 +31,6 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): super(EventStreamHandler, self).__init__(hs) @@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False): + def get_stream( + self, + auth_user_id, + pagin_config, + timeout=0, + as_client_event=True, + affect_presence=True, + only_keys=None, + room_id=None, + is_guest=False, + ): """Fetches the events stream for a given user. If `only_keys` is not None, events from keys will be sent down. @@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( - auth_user_id, affect_presence=affect_presence, + auth_user_id, affect_presence=affect_presence ) with context: if timeout: @@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler): timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) events, tokens = yield self.notifier.get_events_for( - auth_user, pagin_config, timeout, + auth_user, + pagin_config, + timeout, only_keys=only_keys, - is_guest=is_guest, explicit_room_id=room_id + is_guest=is_guest, + explicit_room_id=room_id, ) # When the user joins a new room, or another user joins a currently @@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room(event.room_id) - states = yield presence_handler.get_states( - users, - as_event=True, + users = yield self.state.get_current_users_in_room( + event.room_id ) + states = yield presence_handler.get_states(users, as_event=True) to_add.extend(states) else: ev = yield presence_handler.get_state( - UserID.from_string(event.state_key), - as_event=True, + UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() chunks = yield self._event_serializer.serialize_events( - events, time_now, as_client_event=as_client_event, + events, + time_now, + as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -164,16 +173,10 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, - user.to_string(), - [event], - is_peeking=is_peeking + self.store, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") defer.returnValue(event) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2b173a3f86..e7db501d98 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -82,7 +82,7 @@ def shortstr(iterable, maxitems=5): items = list(itertools.islice(iterable, maxitems + 1)) if len(items) <= maxitems: return str(items) - return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" class FederationHandler(BaseHandler): @@ -115,14 +115,14 @@ class FederationHandler(BaseHandler): self.config = hs.config self.http_client = hs.get_simple_http_client() - self._send_events_to_master = ( - ReplicationFederationSendEventsRestServlet.make_client(hs) + self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( + hs ) - self._notify_user_membership_change = ( - ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( + hs ) - self._clean_room_for_join_client = ( - ReplicationCleanRoomRestServlet.make_client(hs) + self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( + hs ) # When joining a room we need to queue any events for that room up @@ -132,9 +132,7 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() @defer.inlineCallbacks - def on_receive_pdu( - self, origin, pdu, sent_to_us_directly=False, - ): + def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -151,26 +149,19 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info( - "[%s %s] handling received PDU: %s", - room_id, event_id, pdu, - ) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) + already_seen = existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() ) if already_seen: logger.debug("[%s %s]: Already seen pdu", room_id, event_id) @@ -182,20 +173,19 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) - raise FederationError( - "ERROR", - err.code, - err.msg, - affected=pdu.event_id, + logger.warn( + "[%s %s] Received event failed sanity checks", room_id, event_id ) + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, event_id, origin, + room_id, + event_id, + origin, ) self.room_queues[room_id].append((pdu, origin)) return @@ -206,14 +196,13 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room( - room_id, - self.server_name - ) + is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, event_id, origin, + room_id, + event_id, + origin, ) defer.returnValue(None) @@ -223,14 +212,9 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context( - pdu.room_id - ) + min_depth = yield self.get_min_depth_for_context(pdu.room_id) - logger.debug( - "[%s %s] min_depth: %d", - room_id, event_id, min_depth, - ) + logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) seen = yield self.store.have_seen_events(prevs) @@ -248,12 +232,17 @@ class FederationHandler(BaseHandler): # at a time. logger.info( "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, event_id, len(missing_prevs), + room_id, + event_id, + len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -267,12 +256,16 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( "[%s %s] Found all missing prev_events", - room_id, event_id, + room_id, + event_id, ) elif missing_prevs: logger.info( "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) if prevs - seen: @@ -303,7 +296,10 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warn( "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + room_id, + event_id, + len(prevs - seen), + shortstr(prevs - seen), ) raise FederationError( "ERROR", @@ -318,9 +314,7 @@ class FederationHandler(BaseHandler): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() - event_map = { - event_id: pdu, - } + event_map = {event_id: pdu} try: # Get the state of the events we know about ours = yield self.store.get_state_groups_ids(room_id, seen) @@ -337,7 +331,9 @@ class FederationHandler(BaseHandler): for p in prevs - seen: logger.info( "[%s %s] Requesting state at missing prev_event %s", - room_id, event_id, p, + room_id, + event_id, + p, ) room_version = yield self.store.get_room_version(room_id) @@ -348,19 +344,19 @@ class FederationHandler(BaseHandler): # by the get_pdu_cache in federation_client. remote_state, got_auth_chain = ( yield self.federation_client.get_state_for_room( - origin, room_id, p, + origin, room_id, p ) ) # we want the state *after* p; get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True, + [origin], p, room_version, outlier=True ) if remote_event is None: raise Exception( - "Unable to get missing prev_event %s" % (p, ) + "Unable to get missing prev_event %s" % (p,) ) if remote_event.is_state(): @@ -380,7 +376,9 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = yield resolve_events_with_store( - room_version, state_maps, event_map, + room_version, + state_maps, + event_map, state_res_store=StateResolutionStore(self.store), ) @@ -396,15 +394,15 @@ class FederationHandler(BaseHandler): ) event_map.update(evs) - state = [ - event_map[e] for e in six.itervalues(state_map) - ] + state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: logger.warn( "[%s %s] Error attempting to resolve state at missing " "prev_events", - room_id, event_id, exc_info=True, + room_id, + event_id, + exc_info=True, ) raise FederationError( "ERROR", @@ -414,10 +412,7 @@ class FederationHandler(BaseHandler): ) yield self._process_received_pdu( - origin, - pdu, - state=state, - auth_chain=auth_chain, + origin, pdu, state=state, auth_chain=auth_chain ) @defer.inlineCallbacks @@ -447,7 +442,10 @@ class FederationHandler(BaseHandler): logger.info( "[%s %s]: Requesting missing events between %s and %s", - room_id, event_id, shortstr(latest), event_id, + room_id, + event_id, + shortstr(latest), + event_id, ) # XXX: we set timeout to 10s to help workaround @@ -498,19 +496,29 @@ class FederationHandler(BaseHandler): # # All that said: Let's try increasing the timout to 60s and see what happens. - missing_events = yield self.federation_client.get_missing_events( - origin, - room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - timeout=60000, - ) + try: + missing_events = yield self.federation_client.get_missing_events( + origin, + room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + timeout=60000, + ) + except RequestSendFailed as e: + # We failed to get the missing events, but since we need to handle + # the case of `get_missing_events` not returning the necessary + # events anyway, it is safe to simply log the error and continue. + logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + return logger.info( "[%s %s]: Got %d prev_events: %s", - room_id, event_id, len(missing_events), shortstr(missing_events), + room_id, + event_id, + len(missing_events), + shortstr(missing_events), ) # We want to sort these by depth so we process them and @@ -520,20 +528,20 @@ class FederationHandler(BaseHandler): for ev in missing_events: logger.info( "[%s %s] Handling received prev_event %s", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) with logcontext.nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu( - origin, - ev, - sent_to_us_directly=False, - ) + yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warn( "[%s %s] Received prev_event %s failed history check.", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) else: raise @@ -546,10 +554,7 @@ class FederationHandler(BaseHandler): room_id = event.room_id event_id = event.event_id - logger.debug( - "[%s %s] Processing event: %s", - room_id, event_id, event, - ) + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) event_ids = set() if state: @@ -571,43 +576,32 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in auth_chain + (e.type, e.state_key): e + for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({ - "event": e, - "auth_events": auth, - }) + event_infos.append({"event": e, "auth_events": auth}) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", - room_id, event_id, [e["event"].event_id for e in event_infos] + room_id, + event_id, + [e["event"].event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event( - origin, - event, - state=state, - ) + context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False, + room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: logger.exception("Failed to store room.") @@ -621,12 +615,10 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_state_id = prev_state_ids.get( - (event.type, event.state_key) - ) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: prev_state = yield self.store.get_event( - prev_state_id, allow_none=True, + prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: newly_joined = False @@ -657,10 +649,7 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(room_id) events = yield self.federation_client.backfill( - dest, - room_id, - limit=limit, - extremities=extremities, + dest, room_id, limit=limit, extremities=extremities ) # ideally we'd sanity check the events here for excess prev_events etc, @@ -687,16 +676,9 @@ class FederationHandler(BaseHandler): event_ids = set(e.event_id for e in events) - edges = [ - ev.event_id - for ev in events - if set(ev.prev_event_ids()) - event_ids - ] + edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - logger.info( - "backfill: Got %d events with %d edges", - len(events), len(edges), - ) + logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) # For each edge get the current state. @@ -705,9 +687,7 @@ class FederationHandler(BaseHandler): events_to_state = {} for e_id in edges: state, auth = yield self.federation_client.get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id + destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) @@ -716,12 +696,14 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + list(state_events.values()) + list(auth_events.values()) + for event in events + + list(state_events.values()) + + list(auth_events.values()) for a_id in event.auth_event_ids() ) - auth_events.update({ - e_id: event_map[e_id] for e_id in required_auth if e_id in event_map - }) + auth_events.update( + {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} + ) missing_auth = required_auth - set(auth_events) failed_to_fetch = set() @@ -740,27 +722,30 @@ class FederationHandler(BaseHandler): if missing_auth - failed_to_fetch: logger.info( "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch + missing_auth - failed_to_fetch, ) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background( + self.federation_client.get_pdu, + [dest], + event_id, + room_version=room_version, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( a_id - for event in results if event + for event in results + if event for a_id in event.auth_event_ids() ) missing_auth = required_auth - set(auth_events) @@ -792,15 +777,19 @@ class FederationHandler(BaseHandler): continue a.internal_metadata.outlier = True - ev_infos.append({ - "event": a, - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events + ev_infos.append( + { + "event": a, + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in a.auth_event_ids() + if a_id in auth_events + }, } - }) + ) # Step 1b: persist the events in the chunk we fetched state for (i.e. # the backwards extremities) as non-outliers. @@ -808,23 +797,24 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers ev = event_map[e_id] - assert(not ev.internal_metadata.is_outlier()) - - ev_infos.append({ - "event": ev, - "state": events_to_state[e_id], - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events + assert not ev.internal_metadata.is_outlier() + + ev_infos.append( + { + "event": ev, + "state": events_to_state[e_id], + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in ev.auth_event_ids() + if a_id in auth_events + }, } - }) + ) - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + yield self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -835,14 +825,12 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers - assert(not event.internal_metadata.is_outlier()) + assert not event.internal_metadata.is_outlier() # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event( - dest, event, backfilled=True, - ) + yield self._handle_new_event(dest, event, backfilled=True) defer.returnValue(events) @@ -851,9 +839,7 @@ class FederationHandler(BaseHandler): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room( - room_id - ) + extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -885,31 +871,27 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events( - list(extremities), - ) + forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, - check_redacted=False, - get_prev_content=False, + forward_events, check_redacted=False, get_prev_content=False ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, self.server_name, list(extremities_events.values()), - redact=False, check_history_visibility_only=True, + self.store, + self.server_name, + list(extremities_events.values()), + redact=False, + check_history_visibility_only=True, ) if not filtered_extremities: defer.returnValue(False) # Check if we reached a point where we should start backfilling. - sorted_extremeties_tuple = sorted( - extremities.items(), - key=lambda e: -int(e[1]) - ) + sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] # We don't want to specify too many extremities as it causes the backfill @@ -918,8 +900,7 @@ class FederationHandler(BaseHandler): if current_depth > max_depth: logger.debug( - "Not backfilling as we don't need to. %d < %d", - max_depth, current_depth, + "Not backfilling as we don't need to. %d < %d", max_depth, current_depth ) return @@ -944,8 +925,7 @@ class FederationHandler(BaseHandler): joined_users = [ (state_key, int(event.depth)) for (e_type, state_key), event in iteritems(state) - if e_type == EventTypes.Member - and event.membership == Membership.JOIN + if e_type == EventTypes.Member and event.membership == Membership.JOIN ] joined_domains = {} @@ -965,8 +945,7 @@ class FederationHandler(BaseHandler): curr_domains = get_domains_from_state(curr_state) likely_domains = [ - domain for domain, depth in curr_domains - if domain != self.server_name + domain for domain, depth in curr_domains if domain != self.server_name ] @defer.inlineCallbacks @@ -975,28 +954,20 @@ class FederationHandler(BaseHandler): for dom in domains: try: yield self.backfill( - dom, room_id, - limit=100, - extremities=extremities, + dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. defer.returnValue(True) except SynapseError as e: - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except CodeMessageException as e: if 400 <= e.code < 500: raise - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except NotRetryingDestination as e: logger.info(str(e)) @@ -1005,10 +976,7 @@ class FederationHandler(BaseHandler): logger.info(e) continue except Exception as e: - logger.exception( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.exception("Failed to backfill from %s because %s", dom, e) continue defer.returnValue(False) @@ -1029,10 +997,11 @@ class FederationHandler(BaseHandler): resolve = logcontext.preserve_fn( self.state_handler.resolve_state_groups_for_events ) - states = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], - consumeErrors=True, - )) + states = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], consumeErrors=True + ) + ) # dict[str, dict[tuple, str]], a map from event_id to state map of # event_ids. @@ -1040,23 +1009,23 @@ class FederationHandler(BaseHandler): state_map = yield self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], - get_prev_content=False + get_prev_content=False, ) states = { key: { k: state_map[e_id] for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in iteritems(states) + } + for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill([ - dom for dom, _ in likely_domains - if dom not in tried_domains - ]) + success = yield try_backfill( + [dom for dom, _ in likely_domains if dom not in tried_domains] + ) if success: defer.returnValue(True) @@ -1081,20 +1050,20 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn("Rejecting event %s which has %i prev_events", - ev.event_id, len(ev.prev_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many prev_events", + logger.warn( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn("Rejecting event %s which has %i auth_events", - ev.event_id, len(ev.auth_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many auth_events", + logger.warn( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") @defer.inlineCallbacks def send_invite(self, target_host, event): @@ -1106,7 +1075,7 @@ class FederationHandler(BaseHandler): destination=target_host, room_id=event.room_id, event_id=event.event_id, - pdu=event + pdu=event, ) defer.returnValue(pdu) @@ -1115,8 +1084,7 @@ class FederationHandler(BaseHandler): def on_event_auth(self, event_id): event = yield self.store.get_event(event_id) auth = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) defer.returnValue([e for e in auth]) @@ -1142,15 +1110,13 @@ class FederationHandler(BaseHandler): joinee, "join", content, - params={ - "ver": KNOWN_ROOM_VERSIONS, - }, + params={"ver": KNOWN_ROOM_VERSIONS}, ) # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert (room_id not in self.room_queues) + assert room_id not in self.room_queues self.room_queues[room_id] = [] @@ -1167,7 +1133,7 @@ class FederationHandler(BaseHandler): except ValueError: pass ret = yield self.federation_client.send_join( - target_hosts, event, event_format_version, + target_hosts, event, event_format_version ) origin = ret["origin"] @@ -1186,17 +1152,13 @@ class FederationHandler(BaseHandler): try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + room_id=room_id, room_creator_user_id="", is_public=False ) except Exception: # FIXME pass - yield self._persist_auth_tree( - origin, auth_chain, state, event - ) + yield self._persist_auth_tree(origin, auth_chain, state, event) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1223,14 +1185,18 @@ class FederationHandler(BaseHandler): """ for p, origin in room_queue: try: - logger.info("Processing queued PDU %s which was received " - "while we were joining %s", p.event_id, p.room_id) + logger.info( + "Processing queued PDU %s which was received " + "while we were joining %s", + p.event_id, + p.room_id, + ) with logcontext.nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( - "Error handling queued PDU %s from %s: %s", - p.event_id, origin, e) + "Error handling queued PDU %s from %s: %s", p.event_id, origin, e + ) @defer.inlineCallbacks @log_function @@ -1251,30 +1217,30 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) try: event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) except AuthError as e: logger.warn("Failed to create join %r because %s", event, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Creation of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) defer.returnValue(event) @@ -1309,17 +1275,15 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1340,10 +1304,7 @@ class FederationHandler(BaseHandler): state = yield self.store.get_events(list(prev_state_ids.values())) - defer.returnValue({ - "state": list(state.values()), - "auth_chain": auth_chain, - }) + defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain}) @defer.inlineCallbacks def on_invite_request(self, origin, pdu): @@ -1364,7 +1325,7 @@ class FederationHandler(BaseHandler): raise SynapseError(403, "This server does not accept room invites") if not self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id, + event.sender, event.state_key, event.room_id ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1376,26 +1337,23 @@ class FederationHandler(BaseHandler): sender_domain = get_domain_from_id(event.sender) if sender_domain != origin: - raise SynapseError(400, "The invite event was not from the server sending it") + raise SynapseError( + 400, "The invite event was not from the server sending it" + ) if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True event.signatures.update( compute_event_signature( - event.get_pdu_json(), - self.hs.hostname, - self.hs.config.signing_key[0] + event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0] ) ) @@ -1407,10 +1365,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" + target_hosts, room_id, user_id, "leave" ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1425,10 +1380,7 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.federation_client.send_leave( - target_hosts, - event - ) + yield self.federation_client.send_leave(target_hosts, event) context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) @@ -1436,25 +1388,21 @@ class FederationHandler(BaseHandler): defer.returnValue(event) @defer.inlineCallbacks - def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, - content={}, params=None): + def _make_and_verify_event( + self, target_hosts, room_id, user_id, membership, content={}, params=None + ): origin, event, format_ver = yield self.federation_client.make_membership_event( - target_hosts, - room_id, - user_id, - membership, - content, - params=params, + target_hosts, room_id, user_id, membership, content, params=params ) logger.debug("Got response to make_%s: %s", membership, event) # We should assert some things. # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == user_id) - assert(event.state_key == user_id) - assert(event.room_id == room_id) + assert event.type == EventTypes.Member + assert event.user_id == user_id + assert event.state_key == user_id + assert event.room_id == room_id defer.returnValue((origin, event, format_ver)) @defer.inlineCallbacks @@ -1473,27 +1421,27 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning("Creation of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) @@ -1515,17 +1463,15 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1542,18 +1488,14 @@ class FederationHandler(BaseHandler): """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() - results = { - (e.type, e.state_key): e for e in state - } + results = {(e.type, e.state_key): e for e in state} if event.is_state(): # Get previous state @@ -1575,12 +1517,10 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1606,11 +1546,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self.store.get_backfill_events( - room_id, - pdu_list, - limit - ) + events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield filter_events_for_server(self.store, origin, events) @@ -1634,22 +1570,15 @@ class FederationHandler(BaseHandler): AuthError if the server is not currently in the room """ event = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) + in_room = yield self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server( - self.store, origin, [event], - ) + events = yield filter_events_for_server(self.store, origin, [event]) event = events[0] defer.returnValue(event) else: @@ -1659,13 +1588,11 @@ class FederationHandler(BaseHandler): return self.store.get_min_depth(context) @defer.inlineCallbacks - def _handle_new_event(self, origin, event, state=None, auth_events=None, - backfilled=False): + def _handle_new_event( + self, origin, event, state=None, auth_events=None, backfilled=False + ): context = yield self._prep_event( - origin, event, - state=state, - auth_events=auth_events, - backfilled=backfilled, + origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we @@ -1678,15 +1605,13 @@ class FederationHandler(BaseHandler): ) yield self.persist_events_and_notify( - [(event, context)], - backfilled=backfilled, + [(event, context)], backfilled=backfilled ) success = True finally: if not success: logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) defer.returnValue(context) @@ -1714,12 +1639,15 @@ class FederationHandler(BaseHandler): ) defer.returnValue(res) - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(prep, ev_info) - for ev_info in event_infos - ], consumeErrors=True, - )) + contexts = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) + for ev_info in event_infos + ], + consumeErrors=True, + ) + ) yield self.persist_events_and_notify( [ @@ -1754,8 +1682,7 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id] = ctx event_map = { - e.event_id: e - for e in itertools.chain(auth_events, state, [event]) + e.event_id: e for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1770,7 +1697,7 @@ class FederationHandler(BaseHandler): raise SynapseError(400, "No create event in state") room_version = create_event.content.get( - "room_version", RoomVersions.V1.identifier, + "room_version", RoomVersions.V1.identifier ) missing_auth_events = set() @@ -1781,11 +1708,7 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = yield self.federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, + [origin], e_id, room_version=room_version, outlier=True, timeout=10000 ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -1810,10 +1733,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn( - "Rejecting %s because %s", - e.event_id, err.msg - ) + logger.warn("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1823,16 +1743,14 @@ class FederationHandler(BaseHandler): [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = yield self.state_handler.compute_event_context( event, old_state=state ) - yield self.persist_events_and_notify( - [(event, new_event_context)], - ) + yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks def _prep_event(self, origin, event, state, auth_events, backfilled): @@ -1848,40 +1766,30 @@ class FederationHandler(BaseHandler): Returns: Deferred, which resolves to synapse.events.snapshot.EventContext """ - context = yield self.state_handler.compute_event_context( - event, old_state=state, - ) + context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_event_ids(): if len(event.prev_event_ids()) == 1 and event.depth < 5: c = yield self.store.get_event( - event.prev_event_ids()[0], - allow_none=True, + event.prev_event_ids()[0], allow_none=True ) if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c try: - yield self.do_auth( - origin, event, context, auth_events=auth_events - ) + yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn( - "[%s %s] Rejecting: %s", - event.room_id, event.event_id, e.msg - ) + logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) context.rejected = RejectedReason.AUTH_ERROR @@ -1912,9 +1820,7 @@ class FederationHandler(BaseHandler): # "soft-fail" the event. do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if do_soft_fail_check: - extrem_ids = yield self.store.get_latest_event_ids_in_room( - event.room_id, - ) + extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids) prev_event_ids = set(event.prev_event_ids()) @@ -1942,31 +1848,31 @@ class FederationHandler(BaseHandler): # like bans, especially with state res v2. state_sets = yield self.store.get_state_groups( - event.room_id, extrem_ids, + event.room_id, extrem_ids ) state_sets = list(state_sets.values()) state_sets.append(state) current_state_ids = yield self.state_handler.resolve_events( - room_version, state_sets, event, + room_version, state_sets, event ) current_state_ids = { k: e.event_id for k, e in iteritems(current_state_ids) } else: current_state_ids = yield self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids, + event.room_id, latest_event_ids=extrem_ids ) logger.debug( "Doing soft-fail check for %s: state %s", - event.event_id, current_state_ids, + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) - if k in auth_types + e for k, e in iteritems(current_state_ids) if k in auth_types ] current_auth_events = yield self.store.get_events(current_state_ids) @@ -1977,19 +1883,14 @@ class FederationHandler(BaseHandler): try: self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn( - "Soft-failing %r because %s", - event, e, - ) + logger.warn("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, - missing): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_query_auth( + self, origin, event_id, room_id, remote_auth_chain, rejects, missing + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2007,28 +1908,23 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. - ret = yield self.construct_auth_difference( - local_auth_chain, remote_auth_chain - ) + ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain) logger.debug("on_query_auth returning: %s", ret) defer.returnValue(ret) @defer.inlineCallbacks - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2042,7 +1938,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events, + self.store, origin, missing_events ) defer.returnValue(missing_events) @@ -2130,25 +2026,17 @@ class FederationHandler(BaseHandler): if missing_auth: # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections( - missing_auth - ) + have_events = yield self.store.get_seen_events_with_rejections(missing_auth) logger.debug("Got events %s from store", have_events) missing_auth.difference_update(have_events.keys()) else: have_events = {} - have_events.update({ - e.event_id: "" - for e in auth_events.values() - }) + have_events.update({e.event_id: "" for e in auth_events.values()}) if missing_auth: # If we don't have all the auth events, we need to get them. - logger.info( - "auth_events contains unknown events: %s", - missing_auth, - ) + logger.info("auth_events contains unknown events: %s", missing_auth) try: try: remote_auth_chain = yield self.federation_client.get_event_auth( @@ -2174,18 +2062,16 @@ class FederationHandler(BaseHandler): try: auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in remote_auth_chain + (e.type, e.state_key): e + for e in remote_auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", - event.event_id, e.event_id - ) - yield self._handle_new_event( - origin, e, auth_events=auth + "do_auth %s missing_auth: %s", event.event_id, e.event_id ) + yield self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e @@ -2221,35 +2107,36 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(event.room_id) different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) + defer.gatherResults( + [ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False + ) + for d in different_auth + if d in have_events and not have_events[d] + ], + consumeErrors=True, + ) ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) + remote_view.update( + {(d.type, d.state_key): d for d in different_events if d} + ) new_state = yield self.state_handler.resolve_events( room_version, [list(local_view.values()), list(remote_view.values())], - event + event, ) logger.info( "After state res: updating auth_events with new state %s", { - (d.type, d.state_key): d.event_id for d in new_state.values() + (d.type, d.state_key): d.event_id + for d in new_state.values() if auth_events.get((d.type, d.state_key)) != d }, ) @@ -2261,7 +2148,7 @@ class FederationHandler(BaseHandler): ) yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) if not different_auth: @@ -2295,21 +2182,14 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) + local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) try: # 2. Get remote difference. try: result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, + origin, event.room_id, event.event_id, local_auth_chain ) except RequestSendFailed as e: # The other side isn't around or doesn't implement the @@ -2334,19 +2214,15 @@ class FederationHandler(BaseHandler): auth = { (e.type, e.state_key): e for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create + if e.event_id in auth_ids or event.type == EventTypes.Create } ev.internal_metadata.outlier = True logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id + "do_auth %s different_auth: %s", event.event_id, e.event_id ) - yield self._handle_new_event( - origin, ev, auth_events=auth - ) + yield self._handle_new_event(origin, ev, auth_events=auth) if ev.event_id in event_auth_events: auth_events[(ev.type, ev.state_key)] = ev @@ -2361,12 +2237,11 @@ class FederationHandler(BaseHandler): # TODO. yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, - event_key): + def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. @@ -2383,8 +2258,7 @@ class FederationHandler(BaseHandler): this will not be included in the current_state in the context. """ state_updates = { - k: a.event_id for k, a in iteritems(auth_events) - if k != event_key + k: a.event_id for k, a in iteritems(auth_events) if k != event_key } current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = dict(current_state_ids) @@ -2394,9 +2268,7 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({ - k: a.event_id for k, a in iteritems(auth_events) - }) + prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) # create a new state group as a delta from the existing one. prev_group = context.state_group @@ -2545,30 +2417,23 @@ class FederationHandler(BaseHandler): logger.debug("construct_auth_difference returning") - defer.returnValue({ - "auth_chain": local_auth, - "rejects": { - e.event_id: { - "reason": reason_map[e.event_id], - "proof": None, - } - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - }) + defer.returnValue( + { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } + ) @defer.inlineCallbacks @log_function def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): - third_party_invite = { - "signed": signed, - } + third_party_invite = {"signed": signed} event_dict = { "type": EventTypes.Member, @@ -2591,7 +2456,7 @@ class FederationHandler(BaseHandler): ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info( @@ -2599,7 +2464,7 @@ class FederationHandler(BaseHandler): event, ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2624,9 +2489,7 @@ class FederationHandler(BaseHandler): else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.federation_client.forward_third_party_invite( - destinations, - room_id, - event_dict, + destinations, room_id, event_dict ) @defer.inlineCallbacks @@ -2647,19 +2510,18 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", - event, + "Exchange of threepid invite %s forbidden by third-party rules", event ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2681,11 +2543,12 @@ class FederationHandler(BaseHandler): yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks - def add_display_name_to_third_party_invite(self, room_version, event_dict, - event, context): + def add_display_name_to_third_party_invite( + self, room_version, event_dict, event, context + ): key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) original_invite = None prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -2699,8 +2562,7 @@ class FederationHandler(BaseHandler): event_dict["content"]["third_party_invite"]["display_name"] = display_name else: logger.info( - "Could not find invite event for third_party_invite: %r", - event_dict + "Could not find invite event for third_party_invite: %r", event_dict ) # We don't discard here as this is not the appropriate place to do # auth checks. If we need the invite and don't have it then the @@ -2709,7 +2571,7 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) EventValidator().validate_new(event) defer.returnValue((event, context)) @@ -2733,9 +2595,7 @@ class FederationHandler(BaseHandler): token = signed["token"] prev_state_ids = yield context.get_prev_state_ids(self.store) - invite_event_id = prev_state_ids.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None if invite_event_id: @@ -2744,25 +2604,59 @@ class FederationHandler(BaseHandler): if not invite_event: raise AuthError(403, "Could not find invite") + logger.debug("Checking auth on event %r", event.content) + last_exception = None + # for each public key in the 3pid invite event for public_key_object in self.hs.get_auth().get_public_keys(invite_event): try: + # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): for key_name, encoded_signature in signature_block.items(): if not key_name.startswith("ed25519:"): continue - public_key = public_key_object["public_key"] - verify_key = decode_verify_key_bytes( + logger.debug( + "Attempting to verify sig with key %s from %r " + "against pubkey %r", key_name, - decode_base64(public_key) + server, + public_key_object, ) - verify_signed_json(signed, server, verify_key) - if "key_validity_url" in public_key_object: - yield self._check_key_revocation( - public_key, - public_key_object["key_validity_url"] + + try: + public_key = public_key_object["public_key"] + verify_key = decode_verify_key_bytes( + key_name, decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + logger.debug( + "Successfully verified sig with key %s from %r " + "against pubkey %r", + key_name, + server, + public_key_object, ) + except Exception: + logger.info( + "Failed to verify sig with key %s from %r " + "against pubkey %r", + key_name, + server, + public_key_object, + ) + raise + try: + if "key_validity_url" in public_key_object: + yield self._check_key_revocation( + public_key, public_key_object["key_validity_url"] + ) + except Exception: + logger.info( + "Failed to query key_validity_url %s", + public_key_object["key_validity_url"], + ) + raise return except Exception as e: last_exception = e @@ -2783,15 +2677,9 @@ class FederationHandler(BaseHandler): for revocation. """ try: - response = yield self.http_client.get_json( - url, - {"public_key": public_key} - ) + response = yield self.http_client.get_json(url, {"public_key": public_key}) except Exception: - raise SynapseError( - 502, - "Third party certificate could not be checked" - ) + raise SynapseError(502, "Third party certificate could not be checked") if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") @@ -2812,12 +2700,11 @@ class FederationHandler(BaseHandler): yield self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, - backfilled=backfilled + backfilled=backfilled, ) else: max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, + event_and_contexts, backfilled=backfilled ) if not backfilled: # Never notify for backfilled events @@ -2851,13 +2738,10 @@ class FederationHandler(BaseHandler): event_stream_id = event.internal_metadata.stream_ordering self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) - return self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _clean_room_for_join(self, room_id): """Called to clean up any data in DB for a given room, ready for the @@ -2876,9 +2760,7 @@ class FederationHandler(BaseHandler): """ if self.config.worker_app: return self._notify_user_membership_change( - room_id=room_id, - user_id=user.to_string(), - change="joined", + room_id=room_id, user_id=user.to_string(), change="joined" ) else: return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index f60ace02e8..7da63bb643 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -30,6 +30,7 @@ def _create_rerouter(func_name): """Returns a function that looks at the group id and calls the function on federation or the local group server if the group is local """ + def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): return getattr(self.groups_server_handler, func_name)( @@ -58,6 +59,7 @@ def _create_rerouter(func_name): d.addErrback(http_response_errback) d.addErrback(request_failed_errback) return d + return f @@ -125,7 +127,7 @@ class GroupsLocalHandler(object): ) else: res = yield self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) group_server_name = get_domain_from_id(group_id) @@ -182,7 +184,7 @@ class GroupsLocalHandler(object): content["user_profile"] = yield self.profile_handler.get_profile(user_id) res = yield self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -195,16 +197,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=True, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue(res) @@ -221,7 +222,7 @@ class GroupsLocalHandler(object): group_server_name = get_domain_from_id(group_id) res = yield self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) chunk = res["chunk"] @@ -250,9 +251,7 @@ class GroupsLocalHandler(object): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group( - group_id, user_id, content - ) + yield self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -260,7 +259,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -276,16 +275,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -294,9 +292,7 @@ class GroupsLocalHandler(object): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite( - group_id, user_id, content - ) + yield self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -304,7 +300,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -320,16 +316,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -337,17 +332,17 @@ class GroupsLocalHandler(object): def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ - content = { - "requester_user_id": requester_user_id, - "config": config, - } + content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = yield self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: res = yield self.transport_client.invite_to_group( - get_domain_from_id(group_id), group_id, user_id, requester_user_id, + get_domain_from_id(group_id), + group_id, + user_id, + requester_user_id, content, ) @@ -370,13 +365,12 @@ class GroupsLocalHandler(object): local_profile["avatar_url"] = content["profile"]["avatar_url"] token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="invite", content={"profile": local_profile, "inviter": content["inviter"]}, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: @@ -391,25 +385,25 @@ class GroupsLocalHandler(object): """ if user_id == requester_user_id: token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) # TODO: Should probably remember that we tried to leave so that we can # retry if the group server is currently down. if self.is_mine_id(group_id): res = yield self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id res = yield self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), group_id, requester_user_id, - user_id, content, + get_domain_from_id(group_id), + group_id, + requester_user_id, + user_id, + content, ) defer.returnValue(res) @@ -420,12 +414,9 @@ class GroupsLocalHandler(object): """ # TODO: Check if user in group token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) @defer.inlineCallbacks def get_joined_groups(self, user_id): @@ -445,7 +436,7 @@ class GroupsLocalHandler(object): defer.returnValue({"groups": result}) else: bulk_result = yield self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id], + get_domain_from_id(user_id), [user_id] ) result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations @@ -460,9 +451,7 @@ class GroupsLocalHandler(object): if self.hs.is_mine_id(user_id): local_users.add(user_id) else: - destinations.setdefault( - get_domain_from_id(user_id), set() - ).add(user_id) + destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) if not proxy and destinations: raise SynapseError(400, "Some user_ids are not local") @@ -472,16 +461,14 @@ class GroupsLocalHandler(object): for destination, dest_user_ids in iteritems(destinations): try: r = yield self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids), + destination, list(dest_user_ids) ) results.update(r["users"]) except Exception: failed_results.extend(dest_user_ids) for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user( - uid - ) + results[uid] = yield self.store.get_publicised_groups_for_user(uid) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 04caf65793..c82b1933f2 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -36,7 +36,6 @@ logger = logging.getLogger(__name__) class IdentityHandler(BaseHandler): - def __init__(self, hs): super(IdentityHandler, self).__init__(hs) @@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def threepid_from_creds(self, creds): - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") if not self._should_trust_id_server(id_server): logger.warn( - '%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server + "%s is not a trusted ID server: rejecting 3pid " + "credentials", + id_server, ) defer.returnValue(None) try: data = yield self.http_client.get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'client_secret': client_secret} + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"), + {"sid": creds["sid"], "client_secret": client_secret}, ) except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) raise e.to_synapse_error() - if 'medium' in data: + if "medium" in data: defer.returnValue(data) defer.returnValue(None) @@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler): logger.debug("binding threepid %r to %s", creds, mxid) data = None - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") try: data = yield self.http_client.post_urlencoded_get_json( - "https://%s%s" % ( - id_server, "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'client_secret': client_secret, - 'mxid': mxid, - } + "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"), + {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, ) logger.debug("bound threepid %r to %s", creds, mxid) @@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler): id_servers = [threepid["id_server"]] else: id_servers = yield self.store.get_id_servers_user_bound( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], + user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) # We don't know where to unbind, so we don't have a choice but to return @@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler): changed = True for id_server in id_servers: changed &= yield self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server, + mxid, threepid, id_server ) defer.returnValue(changed) @@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler): url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, - "threepid": { - "medium": threepid["medium"], - "address": threepid["address"], - }, + "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, } # we abuse the federation http client to sign the request, but we have to send it @@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler): # 'browser-like' HTTPS. auth_headers = self.federation_http_client.build_auth_headers( destination=None, - method='POST', - url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), + method="POST", + url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"), content=content, destination_is=id_server, ) - headers = { - b"Authorization": auth_headers, - } + headers = {b"Authorization": auth_headers} try: - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + yield self.http_client.post_json_get_json(url, content, headers) changed = True except HttpResponseException as e: changed = False - if e.code in (400, 404, 501,): + if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) logger.warn("Received %d response while unbinding threepid", e.code) else: @@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestEmailToken( - self, - id_server, - email, - client_secret, - send_attempt, - next_link=None, + self, id_server, email, client_secret, send_attempt, next_link=None ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'email': email, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "email": email, + "client_secret": client_secret, + "send_attempt": send_attempt, } if next_link: - params.update({'next_link': next_link}) + params.update({"next_link": next_link}) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/email/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: @@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestMsisdnToken( - self, id_server, country, phone_number, - client_secret, send_attempt, **kwargs + self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'country': country, - 'phone_number': phone_number, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "country": country, + "phone_number": phone_number, + "client_secret": client_secret, + "send_attempt": send_attempt, } params.update(kwargs) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/msisdn/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index aaee5db0b7..a1fe9d116f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler): self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler): if result is not None: return result - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) + return self.snapshot_cache.set( + now_ms, + key, + self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + ), + ) @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def _snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler): "room_id": event.room_id, "membership": event.membership, "visibility": ( - "public" if event.room_id in public_room_ids - else "private" + "public" if event.room_id in public_room_ids else "private" ), } @@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler): invite_event = yield self.store.get_event(event.event_id) d["invite"] = yield self._event_serializer.serialize_event( - invite_event, time_now, as_client_event, + invite_event, time_now, as_client_event ) rooms_ret.append(d) @@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler): if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, - event.room_id, + self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, - [event.event_id], + self.store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( - self.store, user_id, messages - ) + messages = yield filter_events_for_client(self.store, user_id, messages) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( yield self._event_serializer.serialize_events( - messages, time_now=time_now, - as_client_event=as_client_event, + messages, time_now=time_now, as_client_event=as_client_event ) ), "start": start_token.to_string(), @@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler): d["state"] = yield self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event + as_client_event=as_client_event, ) account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append( + {"type": "m.tag", "content": {"tags": tags}} + ) account_data = account_data_by_room.get(event.room_id, {}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append( + {"type": account_data_type, "content": content} + ) d["account_data"] = account_data_events except Exception: @@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) now = self.clock.time_msec() @@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler): user_id = requester.user.to_string() membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, + room_id, user_id ) is_peeking = member_event_id is None @@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) account_data = yield self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) result["account_data"] = account_data_events defer.returnValue(result) @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], - ) + def _room_initial_sync_parted( + self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ): + room_state = yield self.store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) + stream_token = yield self.store.get_stream_token_for_event(member_event_id) messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + room_id, limit=limit, end_token=stream_token ) messages = yield filter_events_for_client( @@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": (yield self._event_serializer.serialize_events( - room_state.values(), time_now, - )), - "presence": [], - "receipts": [], - }) + defer.returnValue( + { + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now + ) + ), + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": ( + yield self._event_serializer.serialize_events( + room_state.values(), time_now + ) + ), + "presence": [], + "receipts": [], + } + ) @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) + def _room_initial_sync_joined( + self, user_id, room_id, pagin_config, membership, is_peeking + ): + current_state = yield self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() state = yield self._event_serializer.serialize_events( - current_state.values(), time_now, + current_state.values(), time_now ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler): limit = 10 room_members = [ - m for m in current_state.values() + m + for m in current_state.values() if m.type == EventTypes.Member and m.content["membership"] == Membership.JOIN ] @@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler): defer.returnValue([]) states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + [m.user_id for m in room_members], as_event=True ) defer.returnValue(states) @@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_receipts(): receipts = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=now_token.receipt_key, + room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] @@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler): room_id, limit=limit, end_token=now_token.room_key, - ) + ), ], consumeErrors=True, - ).addErrback(unwrapFirstError), + ).addErrback(unwrapFirstError) ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, + self.store, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) @@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler): ret = { "room_id": room_id, "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), + "chunk": ( + yield self._event_serializer.serialize_events(messages, time_now) + ), "start": start_token.to_string(), "end": end_token.to_string(), }, @@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 11650dc80c..683da6bf32 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,9 +34,10 @@ from synapse.api.errors import ( from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background @@ -60,8 +61,9 @@ class MessageHandler(object): self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): + def get_room_data( + self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False + ): """ Get data from a room. Args: @@ -76,9 +78,7 @@ class MessageHandler(object): ) if membership == Membership.JOIN: - data = yield self.state.get_current_state( - room_id, event_type, state_key - ) + data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( @@ -90,8 +90,12 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, state_filter=StateFilter.all(), - at_token=None, is_guest=False, + self, + user_id, + room_id, + state_filter=StateFilter.all(), + at_token=None, + is_guest=False, ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -123,50 +127,48 @@ class MessageHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=at_token.room_key, limit=1, + room_id, end_token=at_token.room_key, limit=1 ) if not last_events: - raise NotFoundError("Can't find event for token %s" % (at_token, )) + raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events, + self.store, user_id, last_events ) event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], state_filter=state_filter, + [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] else: raise AuthError( 403, - "User %s not allowed to view events in room %s at token %s" % ( - user_id, room_id, at_token, - ) + "User %s not allowed to view events in room %s at token %s" + % (user_id, room_id, at_token), ) else: membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable( - room_id, user_id, - ) + yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, state_filter=state_filter, + room_id, state_filter=state_filter ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], state_filter=state_filter, + [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() events = yield self._event_serializer.serialize_events( - room_state.values(), now, + room_state.values(), + now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. bundle_aggregations=False, @@ -210,13 +212,15 @@ class MessageHandler(object): # Loop fell through, AS has no interested users in room raise AuthError(403, "Appservice not in room") - defer.returnValue({ - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, + defer.returnValue( + { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in iteritems(users_with_profile) } - for user_id, profile in iteritems(users_with_profile) - }) + ) class EventCreationHandler(object): @@ -261,9 +265,28 @@ class EventCreationHandler(object): if self._block_events_without_consent_error: self._consent_uri_builder = ConsentURIBuilder(self.config) + if ( + not self.config.worker_app + and self.config.cleanup_extremities_with_dummy_events + ): + self.clock.looping_call( + lambda: run_as_background_process( + "send_dummy_events_to_fill_extremities", + self._send_dummy_events_to_fill_extremities, + ), + 5 * 60 * 1000, + ) + @defer.inlineCallbacks - def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, require_consent=True): + def create_event( + self, + requester, + event_dict, + token_id=None, + txn_id=None, + prev_events_and_hashes=None, + require_consent=True, + ): """ Given a dict from a client, create a new event. @@ -323,8 +346,7 @@ class EventCreationHandler(object): content["avatar_url"] = yield profile.get_avatar_url(target) except Exception as e: logger.info( - "Failed to get profile information for %r: %s", - target, e + "Failed to get profile information for %r: %s", target, e ) is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) @@ -360,16 +382,17 @@ class EventCreationHandler(object): prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event or prev_event.membership != Membership.JOIN: logger.warning( - ("Attempt to send `m.room.aliases` in room %s by user %s but" - " membership is %s"), + ( + "Attempt to send `m.room.aliases` in room %s by user %s but" + " membership is %s" + ), event.room_id, event.sender, prev_event.membership if prev_event else None, ) raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) self.validator.validate_new(event) @@ -436,8 +459,8 @@ class EventCreationHandler(object): # exempt the system notices user if ( - self.config.server_notices_mxid is not None and - user_id == self.config.server_notices_mxid + self.config.server_notices_mxid is not None + and user_id == self.config.server_notices_mxid ): return @@ -450,15 +473,10 @@ class EventCreationHandler(object): return consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart, - ) - msg = self._block_events_without_consent_error % { - 'consent_uri': consent_uri, - } - raise ConsentNotGivenError( - msg=msg, - consent_uri=consent_uri, + requester.user.localpart ) + msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} + raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) @defer.inlineCallbacks def send_nonmember_event(self, requester, event, context, ratelimit=True): @@ -473,8 +491,7 @@ class EventCreationHandler(object): """ if event.type == EventTypes.Member: raise SynapseError( - 500, - "Tried to send member event through non-member codepath" + 500, "Tried to send member event through non-member codepath" ) user = UserID.from_string(event.sender) @@ -486,15 +503,13 @@ class EventCreationHandler(object): if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", - event.event_id, prev_state.event_id, + event.event_id, + prev_state.event_id, ) defer.returnValue(prev_state) yield self.handle_new_client_event( - requester=requester, - event=event, - context=context, - ratelimit=ratelimit, + requester=requester, event=event, context=context, ratelimit=ratelimit ) @defer.inlineCallbacks @@ -520,11 +535,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_and_send_nonmember_event( - self, - requester, - event_dict, - ratelimit=True, - txn_id=None + self, requester, event_dict, ratelimit=True, txn_id=None ): """ Creates an event, then sends it. @@ -539,32 +550,25 @@ class EventCreationHandler(object): # taking longer. with (yield self.limiter.queue(event_dict["room_id"])): event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, string_types): spam_error = "Spam is not permitted here" - raise SynapseError( - 403, spam_error, Codes.FORBIDDEN - ) + raise SynapseError(403, spam_error, Codes.FORBIDDEN) yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, + requester, event, context, ratelimit=ratelimit ) defer.returnValue(event) @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, - prev_events_and_hashes=None): + def create_new_client_event( + self, builder, requester=None, prev_events_and_hashes=None + ): """Create a new event for a local client Args: @@ -584,22 +588,21 @@ class EventCreationHandler(object): """ if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, \ - "Attempting to create an event with %i prev_events" % ( - len(prev_events_and_hashes), + assert len(prev_events_and_hashes) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = \ - yield self.store.get_prev_events_for_room(builder.room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + builder.room_id + ) prev_events = [ (event_id, prev_hashes) for event_id, prev_hashes, _ in prev_events_and_hashes ] - event = yield builder.build( - prev_event_ids=[p for p, _ in prev_events], - ) + event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -615,29 +618,19 @@ class EventCreationHandler(object): aggregation_key = relation["key"] already_exists = yield self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender, + relates_to, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug( - "Created event %s", - event.event_id, - ) + logger.debug("Created event %s", event.event_id) - defer.returnValue( - (event, context,) - ) + defer.returnValue((event, context)) @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -653,19 +646,20 @@ class EventCreationHandler(object): extra_users (list(UserID)): Any extra users to notify about event """ - if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get( - "room_version", RoomVersions.V1.identifier - ) + if event.is_state() and (event.type, event.state_key) == ( + EventTypes.Create, + "", + ): + room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: room_version = yield self.store.get_room_version(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: @@ -682,9 +676,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + yield self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -705,11 +697,7 @@ class EventCreationHandler(object): return yield self.persist_and_notify_client_event( - requester, - event, - context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True @@ -718,18 +706,12 @@ class EventCreationHandler(object): # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) @defer.inlineCallbacks def persist_and_notify_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -754,20 +736,16 @@ class EventCreationHandler(object): if mapping["room_id"] != event.room_id: raise SynapseError( 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) + "Room alias %s does not point to the room" % (room_alias_str,), ) federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) + return e.type == EventTypes.Member and e.sender == event.sender current_state_ids = yield context.get_current_state_ids(self.store) @@ -797,26 +775,21 @@ class EventCreationHandler(object): # to get them to sign the event. returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, + invitee.domain, event ) event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) + event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) if self.auth.check_redaction(room_version, event, auth_events=auth_events): original_event = yield self.store.get_event( @@ -824,13 +797,10 @@ class EventCreationHandler(object): check_redacted=False, get_prev_content=False, allow_rejected=False, - allow_none=False + allow_none=False, ) if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") # We've already checked. event.internal_metadata.recheck_redaction = False @@ -838,24 +808,18 @@ class EventCreationHandler(object): if event.type == EventTypes.Create: prev_state_ids = yield context.get_prev_state_ids(self.store) if prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + raise AuthError(403, "Changing the room create event is forbidden") (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context ) - yield self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -874,3 +838,54 @@ class EventCreationHandler(object): yield presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") + + @defer.inlineCallbacks + def _send_dummy_events_to_fill_extremities(self): + """Background task to send dummy events into rooms that have a large + number of extremities + """ + + room_ids = yield self.store.get_rooms_with_many_extremities( + min_count=10, limit=5 + ) + + for room_id in room_ids: + # For each room we need to find a joined member we can use to send + # the dummy event with. + + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + + members = yield self.state.get_current_users_in_room( + room_id, latest_event_ids=latest_event_ids + ) + + user_id = None + for member in members: + if self.hs.is_mine_id(member): + user_id = member + break + + if not user_id: + # We don't have a joined user. + # TODO: We should do something here to stop the room from + # appearing next time. + continue + + requester = create_requester(user_id) + + event, context = yield self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + + event.internal_metadata.proactively_send = False + + yield self.send_nonmember_event(requester, event, context, ratelimit=False) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8f811e24fe..76ee97ddd3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -55,9 +55,7 @@ class PurgeStatus(object): self.status = PurgeStatus.STATUS_ACTIVE def asdict(self): - return { - "status": PurgeStatus.STATUS_TEXT[self.status] - } + return {"status": PurgeStatus.STATUS_TEXT[self.status]} class PaginationHandler(object): @@ -79,8 +77,7 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() - def start_purge_history(self, room_id, token, - delete_local_events=False): + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. Args: @@ -95,8 +92,7 @@ class PaginationHandler(object): """ if room_id in self._purges_in_progress_by_room: raise SynapseError( - 400, - "History purge already in progress for %s" % (room_id, ), + 400, "History purge already in progress for %s" % (room_id,) ) purge_id = random_string(16) @@ -107,14 +103,12 @@ class PaginationHandler(object): self._purges_by_id[purge_id] = PurgeStatus() run_in_background( - self._purge_history, - purge_id, room_id, token, delete_local_events, + self._purge_history, purge_id, room_id, token, delete_local_events ) return purge_id @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, - delete_local_events): + def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -130,16 +124,13 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history( - room_id, token, delete_local_events, - ) + yield self.store.purge_history(room_id, token, delete_local_events) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: f = Failure() logger.error( - "[purge] failed", - exc_info=(f.type, f.value, f.getTracebackObject()), + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: @@ -148,6 +139,7 @@ class PaginationHandler(object): # remove the purge from the list 24 hours after it completes def clear_purge(): del self._purges_by_id[purge_id] + self.hs.get_reactor().callLater(24 * 3600, clear_purge) def get_purge_status(self, purge_id): @@ -162,8 +154,14 @@ class PaginationHandler(object): return self._purges_by_id.get(purge_id) @defer.inlineCallbacks - def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True, event_filter=None): + def get_messages( + self, + requester, + room_id=None, + pagin_config=None, + as_client_event=True, + event_filter=None, + ): """Get messages in a room. Args: @@ -182,9 +180,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_room( - room_id=room_id - ) + yield self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -201,7 +197,7 @@ class PaginationHandler(object): room_id, user_id ) - if source_config.direction == 'b': + if source_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -235,27 +231,24 @@ class PaginationHandler(object): event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) if events: if event_filter: events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), + self.store, user_id, events, is_peeking=(member_event_id is None) ) if not events: - defer.returnValue({ - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - }) + defer.returnValue( + { + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } + ) state = None if event_filter and event_filter.lazy_load_members() and len(events) > 0: @@ -263,12 +256,11 @@ class PaginationHandler(object): # FIXME: we also care about invite targets etc. state_filter = StateFilter.from_types( - (EventTypes.Member, event.sender) - for event in events + (EventTypes.Member, event.sender) for event in events ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, state_filter=state_filter, + events[0].event_id, state_filter=state_filter ) if state_ids: @@ -280,8 +272,7 @@ class PaginationHandler(object): chunk = { "chunk": ( yield self._event_serializer.serialize_events( - events, time_now, - as_client_event=as_client_event, + events, time_now, as_client_event=as_client_event ) ), "start": pagin_config.from_token.to_string(), @@ -291,8 +282,7 @@ class PaginationHandler(object): if state: chunk["state"] = ( yield self._event_serializer.serialize_events( - state, time_now, - as_client_event=as_client_event, + state, time_now, as_client_event=as_client_event ) ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 557fb5f83d..c80dc2eba0 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,16 +50,20 @@ logger = logging.getLogger(__name__) notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") federation_presence_out_counter = Counter( - "synapse_handler_presence_federation_presence_out", "") + "synapse_handler_presence_federation_presence_out", "" +) presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") -federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "") +federation_presence_counter = Counter( + "synapse_handler_presence_federation_presence", "" +) bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"]) + "synapse_handler_presence_notify_reason", "", ["reason"] +) state_transition_counter = Counter( "synapse_handler_presence_state_transition", "", ["from", "to"] ) @@ -90,7 +94,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): - def __init__(self, hs): """ @@ -110,31 +113,26 @@ class PresenceHandler(object): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler( - "m.presence", self.incoming_presence - ) + federation_registry.register_edu_handler("m.presence", self.incoming_presence) active_presence = self.store.take_presence_startup_info() # A dictionary of the current state of users. This is prefilled with # non-offline presence from the DB. We should fetch from the DB if # we can't find a users presence in here. - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} LaterGauge( - "synapse_handlers_presence_user_to_current_state_size", "", [], - lambda: len(self.user_to_current_state) + "synapse_handlers_presence_user_to_current_state_size", + "", + [], + lambda: len(self.user_to_current_state), ) now = self.clock.time_msec() for state in active_presence: self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_active_ts + IDLE_TIMER, + now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) self.wheel_timer.insert( now=now, @@ -193,27 +191,21 @@ class PresenceHandler(object): "handle_presence_timeouts", self._handle_timeouts ) - self.clock.call_later( - 30, - self.clock.looping_call, - run_timeout_handler, - 5000, - ) + self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000) def run_persister(): return run_as_background_process( "persist_presence_changes", self._persist_unpersisted_changes ) - self.clock.call_later( - 60, - self.clock.looping_call, - run_persister, - 60 * 1000, - ) + self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) - LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], - lambda: len(self.wheel_timer)) + LaterGauge( + "synapse_handlers_presence_wheel_timer_size", + "", + [], + lambda: len(self.wheel_timer), + ) # Used to handle sending of presence to newly joined users/servers if hs.config.use_presence: @@ -241,15 +233,17 @@ class PresenceHandler(object): logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", - len(self.user_to_current_state) + len(self.user_to_current_state), ) if self.unpersisted_users_changes: - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in self.unpersisted_users_changes - ]) + yield self.store.update_presence( + [ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ] + ) logger.info("Finished _on_shutdown") @defer.inlineCallbacks @@ -261,13 +255,10 @@ class PresenceHandler(object): self.unpersisted_users_changes = set() if unpersisted: - logger.info( - "Persisting %d upersisted presence updates", len(unpersisted) + logger.info("Persisting %d upersisted presence updates", len(unpersisted)) + yield self.store.update_presence( + [self.user_to_current_state[user_id] for user_id in unpersisted] ) - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in unpersisted - ]) @defer.inlineCallbacks def _update_states(self, new_states): @@ -303,10 +294,11 @@ class PresenceHandler(object): ) new_state, should_notify, should_ping = handle_update( - prev_state, new_state, + prev_state, + new_state, is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, - now=now + now=now, ) self.user_to_current_state[user_id] = new_state @@ -328,7 +320,8 @@ class PresenceHandler(object): self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { - user_id: state for user_id, state in to_federation_ping.items() + user_id: state + for user_id, state in to_federation_ping.items() if user_id not in to_notify } if to_federation_ping: @@ -351,8 +344,8 @@ class PresenceHandler(object): # Check whether the lists of syncing processes from an external # process have expired. expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_updated_ms.items() + process_id + for process_id, last_update in self.external_process_last_updated_ms.items() if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: @@ -362,9 +355,7 @@ class PresenceHandler(object): self.external_process_last_update.pop(process_id) states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) + self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) for user_id in users_to_check ] @@ -394,9 +385,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "last_active_ts": self.clock.time_msec(), - } + new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE @@ -430,15 +419,23 @@ class PresenceHandler(object): if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states([prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) else: - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) @defer.inlineCallbacks def _end(): @@ -446,9 +443,13 @@ class PresenceHandler(object): self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) except Exception: logger.exception("Error updating presence after sync") @@ -469,7 +470,8 @@ class PresenceHandler(object): """ if self.hs.config.use_presence: syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() + user_id + for user_id, count in self.user_to_num_current_syncs.items() if count } for user_ids in self.external_process_to_current_syncs.values(): @@ -479,7 +481,9 @@ class PresenceHandler(object): return set() @defer.inlineCallbacks - def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + def update_external_syncs_row( + self, process_id, user_id, is_syncing, sync_time_msec + ): """Update the syncing users for an external process as a delta. Args: @@ -500,20 +504,22 @@ class PresenceHandler(object): updates = [] if is_syncing and user_id not in process_presence: if prev_state.state == PresenceState.OFFLINE: - updates.append(prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + ) + ) else: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) process_presence.add(user_id) elif user_id in process_presence: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) if not is_syncing: process_presence.discard(user_id) @@ -537,12 +543,12 @@ class PresenceHandler(object): prev_states = yield self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states([ - prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - ) - for prev_state in itervalues(prev_states) - ]) + yield self._update_states( + [ + prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) + for prev_state in itervalues(prev_states) + ] + ) self.external_process_last_updated_ms.pop(process_id, None) @defer.inlineCallbacks @@ -574,8 +580,7 @@ class PresenceHandler(object): missing = [user_id for user_id, state in iteritems(states) if not state] if missing: new = { - user_id: UserPresenceState.default(user_id) - for user_id in missing + user_id: UserPresenceState.default(user_id) for user_id in missing } states.update(new) self.user_to_current_state.update(new) @@ -593,8 +598,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) self._push_to_remotes(states) @@ -605,8 +612,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) def _push_to_remotes(self, states): @@ -631,15 +640,15 @@ class PresenceHandler(object): user_id = push.get("user_id", None) if not user_id: logger.info( - "Got presence update from %r with no 'user_id': %r", - origin, push, + "Got presence update from %r with no 'user_id': %r", origin, push ) continue if get_domain_from_id(user_id) != origin: logger.info( "Got presence update from %r with bad 'user_id': %r", - origin, user_id, + origin, + user_id, ) continue @@ -647,14 +656,12 @@ class PresenceHandler(object): if not presence_state: logger.info( "Got presence update from %r with no 'presence_state': %r", - origin, push, + origin, + push, ) continue - new_fields = { - "state": presence_state, - "last_federation_update_ts": now, - } + new_fields = {"state": presence_state, "last_federation_update_ts": now} last_active_ago = push.get("last_active_ago", None) if last_active_ago is not None: @@ -672,10 +679,7 @@ class PresenceHandler(object): @defer.inlineCallbacks def get_state(self, target_user, as_event=False): - results = yield self.get_states( - [target_user.to_string()], - as_event=as_event, - ) + results = yield self.get_states([target_user.to_string()], as_event=as_event) defer.returnValue(results[0]) @@ -699,13 +703,15 @@ class PresenceHandler(object): now = self.clock.time_msec() if as_event: - defer.returnValue([ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ]) + defer.returnValue( + [ + { + "type": "m.presence", + "content": format_user_presence_state(state, now), + } + for state in updates + ] + ) else: defer.returnValue(updates) @@ -717,7 +723,9 @@ class PresenceHandler(object): presence = state["presence"] valid_presence = ( - PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") @@ -726,9 +734,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "state": presence - } + new_fields = {"state": presence} if not ignore_status_msg: msg = status_msg if presence != PresenceState.OFFLINE else None @@ -877,8 +883,7 @@ class PresenceHandler(object): hosts = set(host for host in hosts if host != self.server_name) self.federation.send_presence_to_destinations( - states=[state], - destinations=hosts, + states=[state], destinations=hosts ) else: # A remote user has joined the room, so we need to: @@ -904,7 +909,8 @@ class PresenceHandler(object): # default state. now = self.clock.time_msec() states = [ - state for state in states.values() + state + for state in states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None @@ -912,8 +918,7 @@ class PresenceHandler(object): if states: self.federation.send_presence_to_destinations( - states=states, - destinations=[get_domain_from_id(user_id)], + states=states, destinations=[get_domain_from_id(user_id)] ) @@ -937,7 +942,10 @@ def should_notify(old_state, new_state): notify_reason_counter.labels("current_active_change").inc() return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if ( + new_state.last_active_ts - old_state.last_active_ts + > LAST_ACTIVE_GRANULARITY + ): # Only notify about last active bumps if we're not currently acive if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() @@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True): The "user_id" is optional so that this function can be used to format presence updates for client /sync responses and for federation /send requests. """ - content = { - "presence": state.state, - } + content = {"presence": state.state} if include_user_id: content["user_id"] = state.user_id if state.last_active_ts: @@ -986,8 +992,15 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - explicit_room_id=None, **kwargs): + def get_new_events( + self, + user, + from_key, + room_ids=None, + include_offline=True, + explicit_room_id=None, + **kwargs + ): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1004,11 +1017,28 @@ class PresenceEventSource(object): if from_key is not None: from_key = int(from_key) + max_token = self.store.get_current_presence_token() + if from_key == max_token: + # This is necessary as due to the way stream ID generators work + # we may get updates that have a stream ID greater than the max + # token (e.g. max_token is N but stream generator may return + # results for N+2, due to N+1 not having finished being + # persisted yet). + # + # This is usually fine, as it just means that we may send down + # some presence updates multiple times. However, we need to be + # careful that the sync stream either actually does make some + # progress or doesn't return, otherwise clients will end up + # tight looping calling /sync due to it immediately returning + # the same token repeatedly. + # + # Hence this guard where we just return nothing so that the sync + # doesn't return. C.f. #5503. + defer.returnValue(([], max_token)) + presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - max_token = self.store.get_current_presence_token() - users_interested_in = yield self._get_interested_in(user, explicit_room_id) user_ids_changed = set() @@ -1030,7 +1060,7 @@ class PresenceEventSource(object): if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key, + users_interested_in, from_key ) else: user_ids_changed = users_interested_in @@ -1040,10 +1070,16 @@ class PresenceEventSource(object): if include_offline: defer.returnValue((list(updates.values()), max_token)) else: - defer.returnValue(([ - s for s in itervalues(updates) - if s.state != PresenceState.OFFLINE - ], max_token)) + defer.returnValue( + ( + [ + s + for s in itervalues(updates) + if s.state != PresenceState.OFFLINE + ], + max_token, + ) + ) def get_current_key(self): return self.store.get_current_presence_token() @@ -1061,13 +1097,13 @@ class PresenceEventSource(object): users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: user_ids = yield self.store.get_users_in_room( - explicit_room_id, on_invalidate=cache_context.invalidate, + explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1123,9 +1159,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): if now - state.last_active_ts > IDLE_TIMER: # Currently online, but last activity ages ago so auto # idle - state = state.copy_and_replace( - state=PresenceState.UNAVAILABLE, - ) + state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) changed = True elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: # So that we send down a notification that we've @@ -1145,8 +1179,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, + state=PresenceState.OFFLINE, status_msg=None ) changed = True else: @@ -1155,10 +1188,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) changed = True return state if changed else None @@ -1193,21 +1223,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): if new_state.state == PresenceState.ONLINE: # Idle timer wheel_timer.insert( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER + now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER ) active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY - new_state = new_state.copy_and_replace( - currently_active=active, - ) + new_state = new_state.copy_and_replace(currently_active=active) if active: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, ) if new_state.state != PresenceState.OFFLINE: @@ -1215,29 +1241,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) last_federate = new_state.last_federation_update_ts if now - last_federate > FEDERATION_PING_INTERVAL: # Been a while since we've poked remote servers - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True else: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT, ) # Check whether the change was something worth notifying about if should_notify(prev_state, new_state): - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 3e04233394..d8462b75ec 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -73,18 +73,13 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: try: result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) defer.returnValue(result) @@ -113,10 +108,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: profile = yield self.store.get_from_remote_profile_cache(user_id) defer.returnValue(profile or {}) @@ -139,10 +131,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "displayname", - }, + args={"user_id": target_user.to_string(), "field": "displayname"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -170,15 +159,13 @@ class BaseProfileHandler(BaseHandler): if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( - 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ), + 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - if new_displayname == '': + if new_displayname == "": new_displayname = None - yield self.store.set_profile_displayname( - target_user.localpart, new_displayname - ) + yield self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -205,10 +192,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "avatar_url", - }, + args={"user_id": target_user.to_string(), "field": "avatar_url"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -230,12 +214,10 @@ class BaseProfileHandler(BaseHandler): if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( - 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ), + 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - yield self.store.set_profile_avatar_url( - target_user.localpart, new_avatar_url - ) + yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -278,9 +260,7 @@ class BaseProfileHandler(BaseHandler): yield self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user( - target_user.to_string(), - ) + room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() @@ -296,8 +276,7 @@ class BaseProfileHandler(BaseHandler): ) except Exception as e: logger.warn( - "Failed to update join event for room %s - %s", - room_id, str(e) + "Failed to update join event for room %s - %s", room_id, str(e) ) @defer.inlineCallbacks @@ -325,11 +304,9 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user( - requester.to_string() - ) + requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) target_user_rooms = yield self.store.get_rooms_for_user( - target_user.to_string(), + target_user.to_string() ) # Check if the room lists have no elements in common. @@ -353,12 +330,12 @@ class MasterProfileHandler(BaseProfileHandler): assert hs.config.worker_app is None self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS ) def _start_update_remote_profile_cache(self): return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache, + "Update remote profile", self._update_remote_profile_cache ) @defer.inlineCallbacks @@ -372,7 +349,7 @@ class MasterProfileHandler(BaseProfileHandler): for user_id, displayname, avatar_url in entries: is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( - user_id, + user_id ) if not is_subscribed: yield self.store.maybe_delete_remote_profile_cache(user_id) @@ -382,9 +359,7 @@ class MasterProfileHandler(BaseProfileHandler): profile = yield self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) except Exception: @@ -399,6 +374,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache( - user_id, new_name, new_avatar - ) + yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 32108568c6..3e4d8c93a4 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -43,7 +43,7 @@ class ReadMarkerHandler(BaseHandler): with (yield self.read_marker_linearizer.queue((room_id, user_id))): existing_read_marker = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, "m.fully_read", + user_id, room_id, "m.fully_read" ) should_update = True @@ -51,14 +51,11 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream should_update = yield self.store.is_event_after( - event_id, - existing_read_marker['event_id'] + event_id, existing_read_marker["event_id"] ) if should_update: - content = { - "event_id": event_id - } + content = {"event_id": event_id} max_id = yield self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 274d2946ad..a85dd8cdee 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -88,19 +88,16 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r.room_id for r in receipts])) - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) + self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids, + min_batch_id, max_batch_id, affected_room_ids ) defer.returnValue(True) @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, - event_id): + def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -109,9 +106,7 @@ class ReceiptsHandler(BaseHandler): receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={ - "ts": int(self.clock.time_msec()), - }, + data={"ts": int(self.clock.time_msec())}, ) is_new = yield self._handle_new_receipts([receipt]) @@ -125,8 +120,7 @@ class ReceiptsHandler(BaseHandler): """Gets all receipts for a room, upto the given key. """ result = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=to_key, + room_id, to_key=to_key ) if not result: @@ -148,14 +142,12 @@ class ReceiptEventSource(object): defer.returnValue(([], to_key)) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @defer.inlineCallbacks @@ -169,9 +161,7 @@ class ReceiptEventSource(object): room_ids = yield self.store.get_rooms_for_user(user.to_string()) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9a388ea013..e487b90c08 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,6 @@ logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): """ @@ -69,44 +68,37 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer", + name="_generate_user_id_linearizer" ) self._server_notices_mxid = hs.config.server_notices_mxid if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) + self._register_device_client = RegisterDeviceReplicationServlet.make_client( + hs ) - self._post_registration_client = ( - ReplicationPostRegisterActionsServlet.make_client(hs) + self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( + hs ) else: self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, - assigned_user_id=None): + def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_USERNAME + Codes.INVALID_USERNAME, ) if not localpart: - raise SynapseError( - 400, - "User ID cannot be empty", - Codes.INVALID_USERNAME - ) + raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) - if localpart[0] == '_': + if localpart[0] == "_": raise SynapseError( - 400, - "User ID may not begin with _", - Codes.INVALID_USERNAME + 400, "User ID may not begin with _", Codes.INVALID_USERNAME ) user = UserID(localpart, self.hs.hostname) @@ -126,19 +118,15 @@ class RegistrationHandler(BaseHandler): if len(user_id) > MAX_USERID_LENGTH: raise SynapseError( 400, - "User ID may not be longer than %s characters" % ( - MAX_USERID_LENGTH, - ), - Codes.INVALID_USERNAME + "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,), + Codes.INVALID_USERNAME, ) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( - 400, - "User ID already taken.", - errcode=Codes.USER_IN_USE, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: @@ -203,8 +191,7 @@ class RegistrationHandler(BaseHandler): try: int(localpart) raise RegistrationError( - 400, - "Numeric user IDs are reserved for guest users." + 400, "Numeric user IDs are reserved for guest users." ) except ValueError: pass @@ -283,9 +270,7 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid( - user_id, threepid_dict, None, False, - ) + yield self._register_email_threepid(user_id, threepid_dict, None, False) defer.returnValue((user_id, token)) @@ -318,8 +303,8 @@ class RegistrationHandler(BaseHandler): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: logger.warning( - 'Cannot create room alias %s, ' - 'it does not match server domain', + "Cannot create room alias %s, " + "it does not match server domain", r, ) else: @@ -332,7 +317,7 @@ class RegistrationHandler(BaseHandler): fake_requester, config={ "preset": "public_chat", - "room_alias_name": room_alias_localpart + "room_alias_name": room_alias_localpart, }, ratelimit=False, ) @@ -364,8 +349,9 @@ class RegistrationHandler(BaseHandler): raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): raise SynapseError( - 400, "Invalid user localpart for this application service.", - errcode=Codes.EXCLUSIVE + 400, + "Invalid user localpart for this application service.", + errcode=Codes.EXCLUSIVE, ) service_id = service.id if service.is_exclusive_user(user_id) else None @@ -391,17 +377,15 @@ class RegistrationHandler(BaseHandler): """ captcha_response = yield self._validate_captcha( - ip, - private_key, - challenge, - response + ip, private_key, challenge, response ) if not captcha_response["valid"]: - logger.info("Invalid captcha entered from %s. Error: %s", - ip, captcha_response["error_url"]) - raise InvalidCaptchaError( - error_url=captcha_response["error_url"] + logger.info( + "Invalid captcha entered from %s. Error: %s", + ip, + captcha_response["error_url"], ) + raise InvalidCaptchaError(error_url=captcha_response["error_url"]) else: logger.info("Valid captcha entered from %s", ip) @@ -414,8 +398,11 @@ class RegistrationHandler(BaseHandler): """ for c in threepidCreds: - logger.info("validating threepidcred sid %s on id server %s", - c['sid'], c['idServer']) + logger.info( + "validating threepidcred sid %s on id server %s", + c["sid"], + c["idServer"], + ) try: threepid = yield self.identity_handler.threepid_from_creds(c) except Exception: @@ -424,13 +411,14 @@ class RegistrationHandler(BaseHandler): if not threepid: raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid with medium '%s' and address '%s'", - threepid['medium'], threepid['address']) + logger.info( + "got threepid with medium '%s' and address '%s'", + threepid["medium"], + threepid["address"], + ) - if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): - raise RegistrationError( - 403, "Third party identifier is not allowed" - ) + if not check_3pid_allowed(self.hs, threepid["medium"], threepid["address"]): + raise RegistrationError(403, "Third party identifier is not allowed") @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): @@ -449,23 +437,23 @@ class RegistrationHandler(BaseHandler): if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: raise SynapseError( - 400, "This user ID is reserved.", - errcode=Codes.EXCLUSIVE + 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE ) # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = self.store.get_app_services() interested_services = [ - s for s in services - if s.is_interested_in_user(user_id) - and s != allowed_appservice + s + for s in services + if s.is_interested_in_user(user_id) and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): raise SynapseError( - 400, "This user ID is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) @defer.inlineCallbacks @@ -491,14 +479,13 @@ class RegistrationHandler(BaseHandler): dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. """ - response = yield self._submit_captcha(ip_addr, private_key, challenge, - response) + response = yield self._submit_captcha(ip_addr, private_key, challenge, response) # parse Google's response. Lovely format.. - lines = response.split('\n') + lines = response.split("\n") json = { - "valid": lines[0] == 'true', - "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + - "error=%s" % lines[1] + "valid": lines[0] == "true", + "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + + "error=%s" % lines[1], } defer.returnValue(json) @@ -510,17 +497,16 @@ class RegistrationHandler(BaseHandler): data = yield self.captcha_client.post_urlencoded_get_raw( "http://www.recaptcha.net:80/recaptcha/api/verify", args={ - 'privatekey': private_key, - 'remoteip': ip_addr, - 'challenge': challenge, - 'response': response - } + "privatekey": private_key, + "remoteip": ip_addr, + "challenge": challenge, + "response": response, + }, ) defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, - password_hash=None): + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -565,7 +551,7 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) yield self.profile_handler.set_displayname( - user, requester, displayname, by_admin=True, + user, requester, displayname, by_admin=True ) defer.returnValue((user_id, token)) @@ -587,15 +573,12 @@ class RegistrationHandler(BaseHandler): """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - user_info = yield self.auth.get_user_by_access_token( - access_token - ) + user_info = yield self.auth.get_user_by_access_token(access_token) defer.returnValue((user_info["user"].to_string(), access_token)) user_id, access_token = yield self.register( - generate_token=True, - make_guest=True + generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id @@ -616,9 +599,9 @@ class RegistrationHandler(BaseHandler): ) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield room_member_handler.update_membership( requester=requester, @@ -629,10 +612,19 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None, address=None): + def register_with_store( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + address=None, + ): """Register user in the datastore. Args: @@ -661,14 +653,15 @@ class RegistrationHandler(BaseHandler): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - address, time_now_s=time_now, + address, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) if self.hs.config.worker_app: @@ -698,8 +691,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, - is_guest=False): + def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. Args: @@ -732,14 +724,15 @@ class RegistrationHandler(BaseHandler): ) else: access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, + user_id, device_id=device_id ) defer.returnValue((device_id, access_token)) @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token, - bind_email, bind_msisdn): + def post_registration_actions( + self, user_id, auth_result, access_token, bind_email, bind_msisdn + ): """A user has completed registration Args: @@ -773,20 +766,15 @@ class RegistrationHandler(BaseHandler): yield self.store.upsert_monthly_active_user(user_id) yield self._register_email_threepid( - user_id, threepid, access_token, - bind_email, + user_id, threepid, access_token, bind_email ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid( - user_id, threepid, bind_msisdn, - ) + yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented( - user_id, self.hs.config.user_consent_version, - ) + yield self._on_user_consented(user_id, self.hs.config.user_consent_version) @defer.inlineCallbacks def _on_user_consented(self, user_id, consent_version): @@ -798,9 +786,7 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version( - user_id, consent_version, - ) + yield self.store.user_set_consent_version(user_id, consent_version) yield self.post_consent_actions(user_id) @defer.inlineCallbacks @@ -824,33 +810,30 @@ class RegistrationHandler(BaseHandler): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') + reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") return yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) # And we add an email pusher for them by default, but only # if email notifications are enabled (so people don't start # getting mail spam where they weren't before if email # notifs are set up on a home server) - if (self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - and token): + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + and token + ): # Pull the ID of the access token back out of the db # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) + user_tuple = yield self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] yield self.pusher_pool.add_pusher( @@ -867,11 +850,9 @@ class RegistrationHandler(BaseHandler): if bind_email: logger.info("bind_email specified: binding") - logger.debug("Binding emails %s to %s" % ( - threepid, user_id - )) + logger.debug("Binding emails %s to %s" % (threepid, user_id)) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_email not specified: not binding email") @@ -894,7 +875,7 @@ class RegistrationHandler(BaseHandler): defer.Deferred: """ try: - assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) + assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) except SynapseError as ex: if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response @@ -903,17 +884,14 @@ class RegistrationHandler(BaseHandler): raise yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) if bind_msisdn: logger.info("bind_msisdn specified: binding") logger.debug("Binding msisdn %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_msisdn not specified: not binding msisdn") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 74793bab33..db3f8cb76b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -32,6 +32,7 @@ from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils from synapse.util.async_helpers import Linearizer +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -40,6 +41,8 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" +FIVE_MINUTES_IN_MS = 5 * 60 * 1000 + class RoomCreationHandler(BaseHandler): @@ -75,6 +78,12 @@ class RoomCreationHandler(BaseHandler): # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") + # If a user tries to update the same room multiple times in quick + # succession, only process the first attempt and return its result to + # subsequent requests + self._upgrade_response_cache = ResponseCache( + hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS + ) self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -95,70 +104,100 @@ class RoomCreationHandler(BaseHandler): user_id = requester.user.to_string() - with (yield self._upgrade_linearizer.queue(old_room_id)): - # start by allocating a new room id - r = yield self.store.get_room(old_room_id) - if r is None: - raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], - ) + # Check if this room is already being upgraded by another person + for key in self._upgrade_response_cache.pending_result_cache: + if key[0] == old_room_id and key[1] != user_id: + # Two different people are trying to upgrade the same room. + # Send the second an error. + # + # Note that this of course only gets caught if both users are + # on the same homeserver. + raise SynapseError( + 400, "An upgrade for this room is currently in progress" + ) - logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) + # Upgrade the room + # + # If this user has sent multiple upgrade requests for the same room + # and one of them is not complete yet, cache the response and + # return it to all subsequent requests + ret = yield self._upgrade_response_cache.wrap( + (old_room_id, user_id), + self._upgrade_room, + requester, + old_room_id, + new_version, # args for _upgrade_room + ) + defer.returnValue(ret) - # we create and auth the tombstone event before properly creating the new - # room, to check our user has perms in the old room. - tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - } - }, - token_id=requester.access_token_id, - ) - ) - old_room_version = yield self.store.get_room_version(old_room_id) - yield self.auth.check_from_context( - old_room_version, tombstone_event, tombstone_context, - ) + @defer.inlineCallbacks + def _upgrade_room(self, requester, old_room_id, new_version): + user_id = requester.user.to_string() + + # start by allocating a new room id + r = yield self.store.get_room(old_room_id) + if r is None: + raise NotFoundError("Unknown room id %s" % (old_room_id,)) + new_room_id = yield self._generate_room_id( + creator_id=user_id, is_public=r["is_public"] + ) - yield self.clone_existing_room( + logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) + + # we create and auth the tombstone event before properly creating the new + # room, to check our user has perms in the old room. + tombstone_event, tombstone_context = ( + yield self.event_creation_handler.create_event( requester, - old_room_id=old_room_id, - new_room_id=new_room_id, - new_room_version=new_version, - tombstone_event_id=tombstone_event.event_id, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, + }, + }, + token_id=requester.access_token_id, ) + ) + old_room_version = yield self.store.get_room_version(old_room_id) + yield self.auth.check_from_context( + old_room_version, tombstone_event, tombstone_context + ) - # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( - requester, tombstone_event, tombstone_context, - ) + yield self.clone_existing_room( + requester, + old_room_id=old_room_id, + new_room_id=new_room_id, + new_room_version=new_version, + tombstone_event_id=tombstone_event.event_id, + ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + # now send the tombstone + yield self.event_creation_handler.send_nonmember_event( + requester, tombstone_event, tombstone_context + ) - # update any aliases - yield self._move_aliases_to_new_room( - requester, old_room_id, new_room_id, old_room_state, - ) + old_room_state = yield tombstone_context.get_current_state_ids(self.store) - # and finally, shut down the PLs in the old room, and update them in the new - # room. - yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, - ) + # update any aliases + yield self._move_aliases_to_new_room( + requester, old_room_id, new_room_id, old_room_state + ) + + # and finally, shut down the PLs in the old room, and update them in the new + # room. + yield self._update_upgraded_room_pls( + requester, old_room_id, new_room_id, old_room_state + ) - defer.returnValue(new_room_id) + defer.returnValue(new_room_id) @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): """Send updated power levels in both rooms after an upgrade @@ -176,7 +215,7 @@ class RoomCreationHandler(BaseHandler): if old_room_pl_event_id is None: logger.warning( "Not supported: upgrading a room with no PL event. Not setting PLs " - "in old room.", + "in old room." ) return @@ -197,45 +236,48 @@ class RoomCreationHandler(BaseHandler): if current < restricted_level: logger.info( "Setting level for %s in %s to %i (was %i)", - v, old_room_id, restricted_level, current, + v, + old_room_id, + restricted_level, + current, ) pl_content[v] = restricted_level updated = True else: - logger.info( - "Not setting level for %s (already %i)", - v, current, - ) + logger.info("Not setting level for %s (already %i)", v, current) if updated: try: yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": old_room_id, "sender": requester.user.to_string(), "content": pl_content, - }, ratelimit=False, + }, + ratelimit=False, ) except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) logger.info("Setting correct PLs in new room") yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), "content": old_room_pl_state.content, - }, ratelimit=False, + }, + ratelimit=False, ) @defer.inlineCallbacks def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, - tombstone_event_id, + self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id ): """Populate a new room based on an old room @@ -257,10 +299,7 @@ class RoomCreationHandler(BaseHandler): creation_content = { "room_version": new_room_version, - "predecessor": { - "room_id": old_room_id, - "event_id": tombstone_event_id, - } + "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } # Check if old room was non-federatable @@ -289,7 +328,7 @@ class RoomCreationHandler(BaseHandler): ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy), + old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) @@ -302,11 +341,9 @@ class RoomCreationHandler(BaseHandler): yield self._send_events_for_new_room( requester, new_room_id, - # we expect to override all the presets with initial_state, so this is # somewhat arbitrary. preset_config=RoomCreationPreset.PRIVATE_CHAT, - invite_list=[], initial_state=initial_state, creation_content=creation_content, @@ -314,20 +351,22 @@ class RoomCreationHandler(BaseHandler): # Transfer membership events old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]), + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent old_room_member_state_events = yield self.store.get_events( - old_room_member_state_ids.values(), + old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): # Only transfer ban events - if ("membership" in old_event.content and - old_event.content["membership"] == "ban"): + if ( + "membership" in old_event.content + and old_event.content["membership"] == "ban" + ): yield self.room_member_handler.update_membership( requester, - UserID.from_string(old_event['state_key']), + UserID.from_string(old_event["state_key"]), new_room_id, "ban", ratelimit=False, @@ -339,7 +378,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _move_aliases_to_new_room( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): directory_handler = self.hs.get_handlers().directory_handler @@ -370,14 +409,11 @@ class RoomCreationHandler(BaseHandler): alias = RoomAlias.from_string(alias_str) try: yield directory_handler.delete_association( - requester, alias, send_event=False, + requester, alias, send_event=False ) removed_aliases.append(alias_str) except SynapseError as e: - logger.warning( - "Unable to remove alias %s from old room: %s", - alias, e, - ) + logger.warning("Unable to remove alias %s from old room: %s", alias, e) # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest # of this. @@ -393,30 +429,26 @@ class RoomCreationHandler(BaseHandler): # as when you remove an alias from the directory normally - it just means that # the aliases event gets out of sync with the directory # (cf https://github.com/vector-im/riot-web/issues/2369) - yield directory_handler.send_room_alias_update_event( - requester, old_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, old_room_id) except AuthError as e: - logger.warning( - "Failed to send updated alias event on old room: %s", e, - ) + logger.warning("Failed to send updated alias event on old room: %s", e) # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: yield directory_handler.create_association( - requester, RoomAlias.from_string(alias), - new_room_id, servers=(self.hs.hostname, ), - send_event=False, check_membership=False, + requester, + RoomAlias.from_string(alias), + new_room_id, + servers=(self.hs.hostname,), + send_event=False, + check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: # I'm not really expecting this to happen, but it could if the spam # checking module decides it shouldn't, or similar. - logger.error( - "Error adding alias %s to new room: %s", - alias, e, - ) + logger.error("Error adding alias %s to new room: %s", alias, e) try: if canonical_alias and (canonical_alias in removed_aliases): @@ -427,24 +459,19 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": {"alias": canonical_alias, }, + "content": {"alias": canonical_alias}, }, - ratelimit=False + ratelimit=False, ) - yield directory_handler.send_room_alias_update_event( - requester, new_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, new_room_id) except SynapseError as e: # again I'm not really expecting this to fail, but if it does, I'd rather # we returned the new room to the client at this point. - logger.error( - "Unable to send updated alias events in new room: %s", e, - ) + logger.error("Unable to send updated alias events in new room: %s", e) @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, - creator_join_profile=None): + def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): """ Creates a new room. Args: @@ -474,25 +501,23 @@ class RoomCreationHandler(BaseHandler): yield self.auth.check_auth_blocking(user_id) - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. yield self.third_party_event_rules.on_create_room( - requester, - config, - is_requester_admin=is_requester_admin, + requester, config, is_requester_admin=is_requester_admin ) if not is_requester_admin and not self.spam_checker.user_may_create_room( - user_id, + user_id ): raise SynapseError(403, "You are not permitted to create rooms") @@ -500,16 +525,11 @@ class RoomCreationHandler(BaseHandler): yield self.ratelimit(requester) room_version = config.get( - "room_version", - self.config.default_room_version.identifier, + "room_version", self.config.default_room_version.identifier ) if not isinstance(room_version, string_types): - raise SynapseError( - 400, - "room_version must be a string", - Codes.BAD_JSON, - ) + raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( @@ -523,20 +543,11 @@ class RoomCreationHandler(BaseHandler): if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias( - config["room_alias_name"], - self.hs.hostname, - ) - mapping = yield self.store.get_association_from_room_alias( - room_alias - ) + room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) + mapping = yield self.store.get_association_from_room_alias(room_alias) if mapping: - raise SynapseError( - 400, - "Room alias already taken", - Codes.ROOM_IN_USE - ) + raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) else: room_alias = None @@ -547,9 +558,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy( - requester, - ) + yield self.event_creation_handler.assert_accepted_privacy_policy(requester) invite_3pid_list = config.get("invite_3pid", []) @@ -573,7 +582,7 @@ class RoomCreationHandler(BaseHandler): "preset", RoomCreationPreset.PRIVATE_CHAT if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT + else RoomCreationPreset.PUBLIC_CHAT, ) raw_initial_state = config.get("initial_state", []) @@ -610,7 +619,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"name": name}, }, - ratelimit=False) + ratelimit=False, + ) if "topic" in config: topic = config["topic"] @@ -623,7 +633,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"topic": topic}, }, - ratelimit=False) + ratelimit=False, + ) for invitee in invite_list: content = {} @@ -658,30 +669,25 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - yield directory_handler.send_room_alias_update_event( - requester, room_id - ) + yield directory_handler.send_room_alias_update_event(requester, room_id) defer.returnValue(result) @defer.inlineCallbacks def _send_events_for_new_room( - self, - creator, # A Requester object. - room_id, - preset_config, - invite_list, - initial_state, - creation_content, - room_alias=None, - power_level_content_override=None, - creator_join_profile=None, + self, + creator, # A Requester object. + room_id, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias=None, + power_level_content_override=None, + creator_join_profile=None, ): def create(etype, content, **kwargs): - e = { - "type": etype, - "content": content, - } + e = {"type": etype, "content": content} e.update(event_keys) e.update(kwargs) @@ -693,26 +699,17 @@ class RoomCreationHandler(BaseHandler): event = create(etype, content, **kwargs) logger.info("Sending %s in new room", etype) yield self.event_creation_handler.create_and_send_nonmember_event( - creator, - event, - ratelimit=False + creator, event, ratelimit=False ) config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.user.to_string() - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } + event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send( - etype=EventTypes.Create, - content=creation_content, - ) + yield send(etype=EventTypes.Create, content=creation_content) logger.info("Sending %s in new room", EventTypes.Member) yield self.room_member_handler.update_membership( @@ -726,17 +723,12 @@ class RoomCreationHandler(BaseHandler): # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. - pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None) + pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send( - etype=EventTypes.PowerLevels, - content=pl_content, - ) + yield send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { - "users": { - creator_id: 100, - }, + "users": {creator_id: 100}, "users_default": 0, "events": { EventTypes.Name: 50, @@ -760,42 +752,33 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send( - etype=EventTypes.PowerLevels, - content=power_level_content, - ) + yield send(etype=EventTypes.PowerLevels, content=power_level_content) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - if (EventTypes.JoinRules, '') not in initial_state: + if (EventTypes.JoinRules, "") not in initial_state: yield send( - etype=EventTypes.JoinRules, - content={"join_rule": config["join_rules"]}, + etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: + if (EventTypes.RoomHistoryVisibility, "") not in initial_state: yield send( etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]} + content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: - if (EventTypes.GuestAccess, '') not in initial_state: + if (EventTypes.GuestAccess, "") not in initial_state: yield send( - etype=EventTypes.GuestAccess, - content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send( - etype=etype, - state_key=state_key, - content=content, - ) + yield send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks def _generate_room_id(self, creator_id, is_public): @@ -805,12 +788,9 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID( - random_string, - self.hs.hostname, - ).to_string() + gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode('utf-8') + gen_room_id = gen_room_id.decode("utf-8") yield self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -844,7 +824,7 @@ class RoomContextHandler(object): Returns: dict, or None if the event isn't found """ - before_limit = math.floor(limit / 2.) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = yield self.store.get_users_in_room(room_id) @@ -852,24 +832,19 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, - user.to_string(), - events, - is_peeking=is_peeking + self.store, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event(event_id, get_prev_content=True, - allow_none=True) + event = yield self.store.get_event( + event_id, get_prev_content=True, allow_none=True + ) if not event: defer.returnValue(None) return - filtered = yield(filter_evts([event])) + filtered = yield (filter_evts([event])) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") results = yield self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter @@ -901,7 +876,7 @@ class RoomContextHandler(object): # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], state_filter=state_filter, + [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -913,9 +888,7 @@ class RoomContextHandler(object): "room_key", results["start"] ).to_string() - results["end"] = token.copy_and_replace( - "room_key", results["end"] - ).to_string() + results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() defer.returnValue(results) @@ -926,13 +899,7 @@ class RoomEventSource(object): @defer.inlineCallbacks def get_new_events( - self, - user, - from_key, - limit, - room_ids, - is_guest, - explicit_room_id=None, + self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. @@ -943,9 +910,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = self.store.get_app_service_by_user_id( - user.to_string() - ) + app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: # We no longer support AS users using /sync directly. # See https://github.com/matrix-org/matrix-doc/issues/1144 @@ -960,7 +925,7 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit or 10, - order='ASC', + order="ASC", ) events = list(room_events) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 617d1c9ef8..aae696a7e8 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -46,13 +46,18 @@ class RoomListHandler(BaseHandler): super(RoomListHandler, self).__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") - self.remote_response_cache = ResponseCache(hs, "remote_room_list", - timeout_ms=30 * 1000) + self.remote_response_cache = ResponseCache( + hs, "remote_room_list", timeout_ms=30 * 1000 + ) - def get_local_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False): + def get_local_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + ): """Generate a local public room list. There are multiple different lists: the main one plus one per third @@ -68,14 +73,14 @@ class RoomListHandler(BaseHandler): Setting to None returns all public rooms across all lists. """ if not self.enable_room_list_search: - return defer.succeed({ - "chunk": [], - "total_room_count_estimate": 0, - }) + return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", - limit, since_token, bool(search_filter), network_tuple, + limit, + since_token, + bool(search_filter), + network_tuple, ) if search_filter: @@ -88,24 +93,33 @@ class RoomListHandler(BaseHandler): # solution at some point timeout = self.clock.time() + 60 return self._get_public_room_list( - limit, since_token, search_filter, - network_tuple=network_tuple, timeout=timeout, + limit, + since_token, + search_filter, + network_tuple=network_tuple, + timeout=timeout, ) key = (limit, since_token, network_tuple) return self.response_cache.wrap( key, self._get_public_room_list, - limit, since_token, - network_tuple=network_tuple, from_federation=from_federation, + limit, + since_token, + network_tuple=network_tuple, + from_federation=from_federation, ) @defer.inlineCallbacks - def _get_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - timeout=None,): + def _get_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + timeout=None, + ): """Generate a public room list. Args: limit (int|None): Maximum amount of rooms to return. @@ -135,15 +149,14 @@ class RoomListHandler(BaseHandler): current_public_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = since_token.public_room_stream_id newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, - network_tuple=network_tuple, + public_room_stream_id, current_public_id, network_tuple=network_tuple ) else: stream_token = yield self.store.get_room_max_stream_ordering() public_room_stream_id = yield self.store.get_current_public_room_stream_id() room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple, + public_room_stream_id, network_tuple=network_tuple ) # We want to return rooms in a particular order: the number of joined @@ -168,7 +181,7 @@ class RoomListHandler(BaseHandler): return joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids, + room_id, latest_event_ids ) num_joined_users = len(joined_users) @@ -180,8 +193,9 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) - logger.info("Getting ordering for %i rooms since %s", - len(room_ids), stream_token) + logger.info( + "Getting ordering for %i rooms since %s", len(room_ids), stream_token + ) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -193,7 +207,8 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ - r for r in sorted_rooms + r + for r in sorted_rooms if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] @@ -204,13 +219,12 @@ class RoomListHandler(BaseHandler): # `since_token.current_limit` is the index of the last room we # sent down, so we exclude it and everything before/after it. if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] else: - rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan = rooms_to_scan[: since_token.current_limit] rooms_to_scan.reverse() - logger.info("After sorting and filtering, %i rooms remain", - len(rooms_to_scan)) + logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) # _append_room_entry_to_chunk will append to chunk but will stop if # len(chunk) > limit @@ -237,15 +251,19 @@ class RoomListHandler(BaseHandler): if timeout and self.clock.time() > timeout: raise Exception("Timed out searching room directory") - batch = rooms_to_scan[i:i + step] + batch = rooms_to_scan[i : i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter, + r, + rooms_to_num_joined[r], + chunk, + limit, + search_filter, from_federation=from_federation, ), - batch, 5, + batch, + 5, ) logger.info("Now %i rooms in result", len(chunk)) if len(chunk) >= limit + 1: @@ -273,10 +291,7 @@ class RoomListHandler(BaseHandler): new_limit = sorted_rooms.index(last_room_id) - results = { - "chunk": chunk, - "total_room_count_estimate": total_room_count, - } + results = {"chunk": chunk, "total_room_count_estimate": total_room_count} if since_token: results["new_rooms"] = bool(newly_visible) @@ -313,8 +328,15 @@ class RoomListHandler(BaseHandler): defer.returnValue(results) @defer.inlineCallbacks - def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, - search_filter, from_federation=False): + def _append_room_entry_to_chunk( + self, + room_id, + num_joined_users, + chunk, + limit, + search_filter, + from_federation=False, + ): """Generate the entry for a room in the public room list and append it to the `chunk` if it matches the search filter @@ -345,8 +367,14 @@ class RoomListHandler(BaseHandler): chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry(self, room_id, num_joined_users, cache_context, - with_alias=True, allow_private=False): + def generate_room_entry( + self, + room_id, + num_joined_users, + cache_context, + with_alias=True, + allow_private=False, + ): """Returns the entry for a room Args: @@ -360,33 +388,31 @@ class RoomListHandler(BaseHandler): Deferred[dict|None]: Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ - result = { - "room_id": room_id, - "num_joined_members": num_joined_users, - } + result = {"room_id": room_id, "num_joined_members": num_joined_users} current_state_ids = yield self.store.get_current_state_ids( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) - event_map = yield self.store.get_events([ - event_id for key, event_id in iteritems(current_state_ids) - if key[0] in ( - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.Name, - EventTypes.Topic, - EventTypes.CanonicalAlias, - EventTypes.RoomHistoryVisibility, - EventTypes.GuestAccess, - "m.room.avatar", - ) - ]) + event_map = yield self.store.get_events( + [ + event_id + for key, event_id in iteritems(current_state_ids) + if key[0] + in ( + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ] + ) - current_state = { - (ev.type, ev.state_key): ev - for ev in event_map.values() - } + current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()} # Double check that this is actually a public room. @@ -446,14 +472,17 @@ class RoomListHandler(BaseHandler): defer.returnValue(result) @defer.inlineCallbacks - def get_remote_public_room_list(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def get_remote_public_room_list( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if not self.enable_room_list_search: - defer.returnValue({ - "chunk": [], - "total_room_count_estimate": 0, - }) + defer.returnValue({"chunk": [], "total_room_count_estimate": 0}) if search_filter: # We currently don't support searching across federation, so we have @@ -462,52 +491,75 @@ class RoomListHandler(BaseHandler): since_token = None res = yield self._get_remote_list_cached( - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) if search_filter: - res = {"chunk": [ - entry - for entry in list(res.get("chunk", [])) - if _matches_room_entry(entry, search_filter) - ]} + res = { + "chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ] + } defer.returnValue(res) - def _get_remote_list_cached(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def _get_remote_list_cached( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, include_all_networks=include_all_networks, + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) key = ( - server_name, limit, since_token, include_all_networks, + server_name, + limit, + since_token, + include_all_networks, third_party_instance_id, ) return self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, search_filter=search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) -class RoomListNextBatch(namedtuple("RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch -))): +class RoomListNextBatch( + namedtuple( + "RoomListNextBatch", + ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch + ), + ) +): KEY_DICT = { "stream_ordering": "s", @@ -527,21 +579,19 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", ( decoded = msgpack.loads(decode_base64(token), raw=False) else: decoded = msgpack.loads(decode_base64(token)) - return RoomListNextBatch(**{ - cls.REVERSE_KEY_DICT[key]: val - for key, val in decoded.items() - }) + return RoomListNextBatch( + **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} + ) def to_token(self): - return encode_base64(msgpack.dumps({ - self.KEY_DICT[key]: val - for key, val in self._asdict().items() - })) + return encode_base64( + msgpack.dumps( + {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + ) + ) def copy_and_replace(self, **kwds): - return self._replace( - **kwds - ) + return self._replace(**kwds) def _matches_room_entry(room_entry, search_filter): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a65832efed..cef520392f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -170,7 +170,11 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( - self, requester, target, room_id, membership, + self, + requester, + target, + room_id, + membership, prev_events_and_hashes, txn_id=None, ratelimit=True, @@ -194,7 +198,6 @@ class RoomMemberHandler(object): "room_id": room_id, "sender": requester.user.to_string(), "state_key": user_id, - # For backwards compatibility: "membership": membership, }, @@ -206,26 +209,19 @@ class RoomMemberHandler(object): # Check if this event matches the previous membership event for the user. duplicate = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target], - ratelimit=ratelimit, + requester, event, context, extra_users=[target], ratelimit=ratelimit ) prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, user_id), - None - ) + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has actually joined the @@ -247,11 +243,11 @@ class RoomMemberHandler(object): if predecessor: # It is an upgraded room. Copy over old tags self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) # Move over old push rules self.store.move_push_rules_from_room_to_room_for_user( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -262,12 +258,7 @@ class RoomMemberHandler(object): defer.returnValue(event) @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room( - self, - old_room_id, - new_room_id, - user_id, - ): + def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): """Copies the tags and direct room state from one room to another. Args: @@ -279,9 +270,7 @@ class RoomMemberHandler(object): Deferred[None] """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user( - user_id, - ) + user_account_data, _ = yield self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -295,34 +284,30 @@ class RoomMemberHandler(object): # Save back to user's m.direct account data yield self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms, + user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room( - user_id, old_room_id, - ) + room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room( - user_id, new_room_id, tag, tag_content - ) + yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) @defer.inlineCallbacks def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): key = (room_id,) @@ -344,17 +329,17 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): content_specified = bool(content) if content is None: @@ -388,7 +373,7 @@ class RoomMemberHandler(object): if not remote_room_hosts: remote_room_hosts = [] - if effective_membership_state not in ("leave", "ban",): + if effective_membership_state not in ("leave", "ban"): is_blocked = yield self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -396,22 +381,19 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") block_invite = False - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to send invites is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -422,25 +404,19 @@ class RoomMemberHandler(object): block_invite = True if not self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id, + requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") block_invite = True if block_invite: - raise SynapseError( - 403, "Invites have been disabled on this server", - ) + raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions @@ -455,13 +431,13 @@ class RoomMemberHandler(object): 403, "Cannot unban user who was not banned" " (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_membership == "ban" and action != "unban": raise SynapseError( 403, "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_state: @@ -477,8 +453,8 @@ class RoomMemberHandler(object): # we don't allow people to reject invites to the server notice # room, but they can leave it once they are joined. if ( - old_membership == Membership.INVITE and - effective_membership_state == Membership.LEAVE + old_membership == Membership.INVITE + and effective_membership_state == Membership.LEAVE ): is_blocked = yield self._is_server_notice_room(room_id) if is_blocked: @@ -539,7 +515,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target, + requester, remote_room_hosts, room_id, target ) defer.returnValue(res) @@ -558,12 +534,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, + self, requester, event, context, remote_room_hosts=None, ratelimit=True ): """ Change the membership status of a user in a room. @@ -589,16 +560,15 @@ class RoomMemberHandler(object): if requester is not None: sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) + assert ( + sender == requester.user + ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = types.create_requester(target_user) prev_event = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if prev_event is not None: return @@ -618,16 +588,11 @@ class RoomMemberHandler(object): raise SynapseError(403, "This room has been blocked on this server") yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, + requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), - None + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: @@ -697,31 +662,20 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( - user_id=user_id, - room_id=room_id, + user_id=user_id, room_id=room_id ) if invite: defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id + self, room_id, inviter, medium, address, id_server, requester, txn_id ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( - 403, "Invites have been disabled on this server", - Codes.FORBIDDEN, + 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) # We need to rate limit *before* we send out any 3PID invites, so we @@ -729,35 +683,24 @@ class RoomMemberHandler(object): yield self.base_handler.ratelimit(requester) can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id, + medium, address, room_id ) if not can_invite: raise SynapseError( - 403, "This third-party identifier can not be invited in this room", + 403, + "This third-party identifier can not be invited in this room", Codes.FORBIDDEN, ) - invitee = yield self._lookup_3pid( - id_server, medium, address - ) + invitee = yield self._lookup_3pid(id_server, medium, address) if invitee: yield self.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, + requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id + requester, id_server, medium, address, room_id, inviter, txn_id=txn_id ) @defer.inlineCallbacks @@ -775,15 +718,12 @@ class RoomMemberHandler(object): """ if not self._enable_lookup: raise SynapseError( - 403, "Looking up third-party identifiers is denied from this server", + 403, "Looking up third-party identifiers is denied from this server" ) try: data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + {"medium": medium, "address": address}, ) if "mxid" in data: @@ -802,29 +742,25 @@ class RoomMemberHandler(object): raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): key_data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), + "%s%s/_matrix/identity/api/v1/pubkey/%s" + % (id_server_scheme, server_hostname, key_name) ) if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) verify_signed_json( data, server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), ) return @defer.inlineCallbacks def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id + self, requester, id_server, medium, address, room_id, user, txn_id ): room_state = yield self.state_handler.get_current_state(room_id) @@ -872,7 +808,7 @@ class RoomMemberHandler(object): room_join_rules=room_join_rules, room_name=room_name, inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url + inviter_avatar_url=inviter_avatar_url, ) ) @@ -883,7 +819,6 @@ class RoomMemberHandler(object): "content": { "display_name": display_name, "public_keys": public_keys, - # For backwards compatibility: "key_validity_url": fallback_public_key["key_validity_url"], "public_key": fallback_public_key["public_key"], @@ -892,24 +827,25 @@ class RoomMemberHandler(object): "sender": user.to_string(), "state_key": token, }, + ratelimit=False, txn_id=txn_id, ) @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( - self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url + self, + requester, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url, ): """ Asks an identity server for a third party invite. @@ -941,7 +877,8 @@ class RoomMemberHandler(object): """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, + id_server_scheme, + id_server, ) invite_config = { @@ -965,14 +902,15 @@ class RoomMemberHandler(object): inviter_user_id=inviter_user_id, ) - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_id, - }) + invite_config.update( + { + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_id, + } + ) data = yield self.simple_http_client.post_urlencoded_get_json( - is_url, - invite_config + is_url, invite_config ) # TODO: Check for success token = data["token"] @@ -980,9 +918,8 @@ class RoomMemberHandler(object): if "public_key" in data: fallback_public_key = { "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" + % (id_server_scheme, id_server), } else: fallback_public_key = public_keys[0] @@ -1102,10 +1039,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, + remote_room_hosts, room_id, user.to_string(), content ) yield self._user_joined_room(user, room_id) @@ -1145,9 +1079,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), + remote_room_hosts, room_id, target.to_string() ) defer.returnValue(ret) except Exception as e: @@ -1159,9 +1091,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) + yield self.store.locally_reject_invite(target.to_string(), room_id) defer.returnValue({}) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): @@ -1185,18 +1115,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): user_id = user.to_string() member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN + Membership.LEAVE, + Membership.BAN, ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) + raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: yield self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index acc6eb8099..da501f38c0 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -71,18 +71,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): """Implements RoomMemberHandler._user_joined_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="joined", + user_id=target.to_string(), room_id=room_id, change="joined" ) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="left", + user_id=target.to_string(), room_id=room_id, change="left" ) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bba74d6c9..ddc4430d03 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() @@ -93,7 +92,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch).decode('ascii') + b = decode_base64(batch).decode("ascii") batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -104,7 +103,9 @@ class SearchHandler(BaseHandler): logger.info( "Search batch properties: %r, %r, %r", - batch_group, batch_group_key, batch_token, + batch_group, + batch_group_key, + batch_token, ) logger.info("Search content: %s", content) @@ -116,9 +117,9 @@ class SearchHandler(BaseHandler): search_term = room_cat["search_term"] # Which "keys" to search over in FTS query - keys = room_cat.get("keys", [ - "content.body", "content.name", "content.topic", - ]) + keys = room_cat.get( + "keys", ["content.body", "content.name", "content.topic"] + ) # Filter to apply to results filter_dict = room_cat.get("filter", {}) @@ -130,9 +131,7 @@ class SearchHandler(BaseHandler): include_state = room_cat.get("include_state", False) # Include context around each event? - event_context = room_cat.get( - "event_context", None - ) + event_context = room_cat.get("event_context", None) # Group results together? May allow clients to paginate within a # group @@ -140,12 +139,8 @@ class SearchHandler(BaseHandler): group_keys = [g["key"] for g in group_by] if event_context is not None: - before_limit = int(event_context.get( - "before_limit", 5 - )) - after_limit = int(event_context.get( - "after_limit", 5 - )) + before_limit = int(event_context.get("before_limit", 5)) + after_limit = int(event_context.get("after_limit", 5)) # Return the historic display name and avatar for the senders # of the events? @@ -159,7 +154,8 @@ class SearchHandler(BaseHandler): if set(group_keys) - {"room_id", "sender"}: raise SynapseError( 400, - "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + "Invalid group by keys: %r" + % (set(group_keys) - {"room_id", "sender"},), ) search_filter = Filter(filter_dict) @@ -190,15 +186,13 @@ class SearchHandler(BaseHandler): room_ids.intersection_update({batch_group_key}) if not room_ids: - defer.returnValue({ - "search_categories": { - "room_events": { - "results": [], - "count": 0, - "highlights": [], + defer.returnValue( + { + "search_categories": { + "room_events": {"results": [], "count": 0, "highlights": []} } } - }) + ) rank_map = {} # event_id -> rank of event allowed_events = [] @@ -213,9 +207,7 @@ class SearchHandler(BaseHandler): count = None if order_by == "rank": - search_result = yield self.store.search_msgs( - room_ids, search_term, keys - ) + search_result = yield self.store.search_msgs(room_ids, search_term, keys) count = search_result["count"] @@ -235,19 +227,17 @@ class SearchHandler(BaseHandler): ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[:search_filter.limit()] + allowed_events = events[: search_filter.limit()] for e in allowed_events: - rm = room_groups.setdefault(e.room_id, { - "results": [], - "order": rank_map[e.event_id], - }) + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) rm["results"].append(e.event_id) - s = sender_group.setdefault(e.sender, { - "results": [], - "order": rank_map[e.event_id], - }) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) s["results"].append(e.event_id) elif order_by == "recent": @@ -262,7 +252,10 @@ class SearchHandler(BaseHandler): while len(room_events) < search_filter.limit() and i < 5: i += 1 search_result = yield self.store.search_rooms( - room_ids, search_term, keys, search_filter.limit() * 2, + room_ids, + search_term, + keys, + search_filter.limit() * 2, pagination_token=pagination_token, ) @@ -277,16 +270,14 @@ class SearchHandler(BaseHandler): rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) + filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( self.store, user.to_string(), filtered_events ) room_events.extend(events) - room_events = room_events[:search_filter.limit()] + room_events = room_events[: search_filter.limit()] if len(results) < search_filter.limit() * 2: pagination_token = None @@ -295,9 +286,7 @@ class SearchHandler(BaseHandler): pagination_token = results[-1]["pagination_token"] for event in room_events: - group = room_groups.setdefault(event.room_id, { - "results": [], - }) + group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) if room_events and len(room_events) >= search_filter.limit(): @@ -309,18 +298,23 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - batch_group, batch_group_key, pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" + % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) else: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - "all", "", pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )).encode('ascii')) + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) allowed_events.extend(room_events) @@ -338,12 +332,13 @@ class SearchHandler(BaseHandler): contexts = {} for event in allowed_events: res = yield self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit, + event.room_id, event.event_id, before_limit, after_limit ) logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), len(res["events_after"]), + len(res["events_before"]), + len(res["events_after"]), ) res["events_before"] = yield filter_events_for_client( @@ -403,12 +398,12 @@ class SearchHandler(BaseHandler): for context in contexts.values(): context["events_before"] = ( yield self._event_serializer.serialize_events( - context["events_before"], time_now, + context["events_before"], time_now ) ) context["events_after"] = ( yield self._event_serializer.serialize_events( - context["events_after"], time_now, + context["events_after"], time_now ) ) @@ -426,11 +421,15 @@ class SearchHandler(BaseHandler): results = [] for e in allowed_events: - results.append({ - "rank": rank_map[e.event_id], - "result": (yield self._event_serializer.serialize_event(e, time_now)), - "context": contexts.get(e.event_id, {}), - }) + results.append( + { + "rank": rank_map[e.event_id], + "result": ( + yield self._event_serializer.serialize_event(e, time_now) + ), + "context": contexts.get(e.event_id, {}), + } + ) rooms_cat_res = { "results": results, @@ -442,7 +441,7 @@ class SearchHandler(BaseHandler): s = {} for room_id, state in state_results.items(): s[room_id] = yield self._event_serializer.serialize_events( - state, time_now, + state, time_now ) rooms_cat_res["state"] = s @@ -456,8 +455,4 @@ class SearchHandler(BaseHandler): if global_next_batch: rooms_cat_res["next_batch"] = global_next_batch - defer.returnValue({ - "search_categories": { - "room_events": rooms_cat_res - } - }) + defer.returnValue({"search_categories": {"room_events": rooms_cat_res}}) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7ecdede4dc..d90c9e0108 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" + def __init__(self, hs): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -32,6 +33,9 @@ class SetPasswordHandler(BaseHandler): @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): + if not self.hs.config.password_localdb_enabled: + raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) + password_hash = yield self._auth_handler.hash(newpassword) except_device_id = requester.device_id if requester else None @@ -47,11 +51,11 @@ class SetPasswordHandler(BaseHandler): # we want to log out all of the user's other sessions. First delete # all his other devices. yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id, + user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id, + user_id, except_token_id=except_access_token_id ) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index b268bbcb2c..6b364befd5 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -21,7 +21,6 @@ logger = logging.getLogger(__name__) class StateDeltasHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 7ad16c8566..a0ee8db988 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -156,7 +156,7 @@ class StatsHandler(StateDeltasHandler): prev_event_content = {} if prev_event_id is not None: prev_event = yield self.store.get_event( - prev_event_id, allow_none=True, + prev_event_id, allow_none=True ) if prev_event: prev_event_content = prev_event.content diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 62fda0c664..a3f550554f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple("SyncConfig", [ - "user", - "filter_collection", - "is_guest", - "request_key", - "device_id", -]) - - -class TimelineBatch(collections.namedtuple("TimelineBatch", [ - "prev_batch", - "events", - "limited", -])): +SyncConfig = collections.namedtuple( + "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] +) + + +class TimelineBatch( + collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) +): __slots__ = [] def __nonzero__(self): @@ -85,18 +79,24 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 -class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", -])): +class JoinedSyncResult( + collections.namedtuple( + "JoinedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "ephemeral", + "account_data", + "unread_notifications", + "summary", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -111,77 +111,93 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 -class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", -])): +class ArchivedSyncResult( + collections.namedtuple( + "ArchivedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "account_data", + ], + ) +): __slots__ = [] def __nonzero__(self): """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ - return bool( - self.timeline - or self.state - or self.account_data - ) + return bool(self.timeline or self.state or self.account_data) + __bool__ = __nonzero__ # python3 -class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ - "room_id", # str - "invite", # FrozenEvent: the invite event -])): +class InvitedSyncResult( + collections.namedtuple( + "InvitedSyncResult", + ["room_id", "invite"], # str # FrozenEvent: the invite event + ) +): __slots__ = [] def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 -class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ - "join", - "invite", - "leave", -])): +class GroupsSyncResult( + collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) +): __slots__ = [] def __nonzero__(self): return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 -class DeviceLists(collections.namedtuple("DeviceLists", [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track -])): +class DeviceLists( + collections.namedtuple( + "DeviceLists", + [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track + ], + ) +): __slots__ = [] def __nonzero__(self): return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 -class SyncResult(collections.namedtuple("SyncResult", [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. - "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", -])): +class SyncResult( + collections.namedtuple( + "SyncResult", + [ + "next_batch", # Token for the next sync + "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. + "joined", # JoinedSyncResult for each joined room. + "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have changed + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device + "groups", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -190,20 +206,20 @@ class SyncResult(collections.namedtuple("SyncResult", [ events. """ return bool( - self.presence or - self.joined or - self.invited or - self.archived or - self.account_data or - self.to_device or - self.device_lists or - self.groups + self.presence + or self.joined + or self.invited + or self.archived + or self.account_data + or self.to_device + or self.device_lists + or self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): - def __init__(self, hs): self.hs_config = hs.config self.store = hs.get_datastore() @@ -217,13 +233,16 @@ class SyncHandler(object): # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( - "lazy_loaded_members_cache", self.clock, - max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, + "lazy_loaded_members_cache", + self.clock, + max_len=0, + expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) @defer.inlineCallbacks - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, - full_state=False): + def wait_for_sync_for_user( + self, sync_config, since_token=None, timeout=0, full_state=False + ): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. @@ -239,13 +258,15 @@ class SyncHandler(object): res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + sync_config, + since_token, + timeout, + full_state, ) defer.returnValue(res) @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, - full_state): + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -261,14 +282,17 @@ class SyncHandler(object): # we are going to return immediately, so don't bother calling # notifier.wait_for_events. result = yield self.current_sync_for_user( - sync_config, since_token, full_state=full_state, + sync_config, since_token, full_state=full_state ) else: + def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) result = yield self.notifier.wait_for_events( - sync_config.user.to_string(), timeout, current_sync_callback, + sync_config.user.to_string(), + timeout, + current_sync_callback, from_token=since_token, ) @@ -281,8 +305,7 @@ class SyncHandler(object): defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None, - full_state=False): + def current_sync_for_user(self, sync_config, since_token=None, full_state=False): """Get the sync for client needed to match what the server has now. Returns: A Deferred SyncResult. @@ -334,8 +357,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -353,22 +375,30 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) defer.returnValue((now_token, ephemeral_by_room)) @defer.inlineCallbacks - def _load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents( + self, + room_id, + sync_config, + now_token, + since_token=None, + recents=None, + newly_joined_room=False, + ): """ Returns: a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() - block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() + block_all_timeline = ( + sync_config.filter_collection.blocks_all_room_timeline() + ) if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -396,11 +426,9 @@ class SyncHandler(object): recents = [] if not limited or block_all_timeline: - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=now_token, - limited=False - )) + defer.returnValue( + TimelineBatch(events=recents, prev_batch=now_token, limited=False) + ) filtering_factor = 2 load_limit = max(timeline_limit * filtering_factor, 10) @@ -427,9 +455,7 @@ class SyncHandler(object): ) else: events, end_key = yield self.store.get_recent_events_for_room( - room_id, - limit=load_limit + 1, - end_token=end_key, + room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( events @@ -462,15 +488,15 @@ class SyncHandler(object): recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace( - "room_key", room_key - ) + prev_batch_token = now_token.copy_and_replace("room_key", room_key) - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=prev_batch_token, - limited=limited or newly_joined_room - )) + defer.returnValue( + TimelineBatch( + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room, + ) + ) @defer.inlineCallbacks def get_state_after_event(self, event, state_filter=StateFilter.all()): @@ -486,7 +512,7 @@ class SyncHandler(object): A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, state_filter=state_filter, + event.event_id, state_filter=state_filter ) if event.is_state(): state_ids = state_ids.copy() @@ -511,13 +537,13 @@ class SyncHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=stream_position.room_key, limit=1, + room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, state_filter=state_filter, + last_event, state_filter=state_filter ) else: @@ -549,7 +575,7 @@ class SyncHandler(object): # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( - room_id, end_token=now_token.room_key, limit=1, + room_id, end_token=now_token.room_key, limit=1 ) if not last_events: @@ -559,28 +585,25 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, - state_filter=StateFilter.from_types([ - (EventTypes.Name, ''), - (EventTypes.CanonicalAlias, ''), - ]), + state_filter=StateFilter.from_types( + [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] + ), ) # this is heavily cached, thus: fast. details = yield self.store.get_room_summary(room_id) - name_id = state_ids.get((EventTypes.Name, '')) - canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + name_id = state_ids.get((EventTypes.Name, "")) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) summary = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = ( - details.get(Membership.JOIN, empty_ms).count - ) - summary["m.invited_member_count"] = ( - details.get(Membership.INVITE, empty_ms).count - ) + summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count + summary["m.invited_member_count"] = details.get( + Membership.INVITE, empty_ms + ).count # if the room has a name or canonical_alias set, we can skip # calculating heroes. Empty strings are falsey, so we check @@ -592,7 +615,7 @@ class SyncHandler(object): if canonical_alias_id: canonical_alias = yield self.store.get_event( - canonical_alias_id, allow_none=True, + canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): defer.returnValue(summary) @@ -600,26 +623,14 @@ class SyncHandler(object): me = sync_config.user.to_string() joined_user_ids = [ - r[0] - for r in details.get(Membership.JOIN, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me ] invited_user_ids = [ - r[0] - for r in details.get(Membership.INVITE, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me ] - gone_user_ids = ( - [ - r[0] - for r in details.get(Membership.LEAVE, empty_ms).members - if r[0] != me - ] + [ - r[0] - for r in details.get(Membership.BAN, empty_ms).members - if r[0] != me - ] - ) + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] # FIXME: only build up a member_ids list for our heroes member_ids = {} @@ -627,20 +638,18 @@ class SyncHandler(object): Membership.JOIN, Membership.INVITE, Membership.LEAVE, - Membership.BAN + Membership.BAN, ): for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id # FIXME: order by stream ordering rather than as returned by SQL - if (joined_user_ids or invited_user_ids): - summary['m.heroes'] = sorted( + if joined_user_ids or invited_user_ids: + summary["m.heroes"] = sorted( [user_id for user_id in (joined_user_ids + invited_user_ids)] )[0:5] else: - summary['m.heroes'] = sorted( - [user_id for user_id in gone_user_ids] - )[0:5] + summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] if not sync_config.filter_collection.lazy_load_members(): defer.returnValue(summary) @@ -652,8 +661,7 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... existing_members = set( - user_id for (typ, user_id) in state.keys() - if typ == EventTypes.Member + user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member ) # ...or ones which are in the timeline... @@ -664,10 +672,10 @@ class SyncHandler(object): # ...and then ensure any missing ones get included in state. missing_hero_event_ids = [ member_ids[hero_id] - for hero_id in summary['m.heroes'] + for hero_id in summary["m.heroes"] if ( - cache.get(hero_id) != member_ids[hero_id] and - hero_id not in existing_members + cache.get(hero_id) != member_ids[hero_id] + and hero_id not in existing_members ) ] @@ -691,8 +699,9 @@ class SyncHandler(object): return cache @defer.inlineCallbacks - def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, - full_state): + def compute_state_delta( + self, room_id, batch, sync_config, since_token, now_token, full_state + ): """ Works out the difference in state between the start of the timeline and the previous sync. @@ -745,23 +754,23 @@ class SyncHandler(object): timeline_state = { (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() + for event in batch.events + if event.is_state() } if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, - state_filter=state_filter, + room_id, stream_position=now_token, state_filter=state_filter ) state_ids = current_state_ids @@ -775,7 +784,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) # for now, we disable LL for gappy syncs - see @@ -793,12 +802,11 @@ class SyncHandler(object): state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, - state_filter=state_filter, + room_id, stream_position=since_token, state_filter=state_filter ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = _calculate_state( @@ -854,8 +862,7 @@ class SyncHandler(object): # add any member IDs we are about to send into our LruCache for t, event_id in itertools.chain( - state_ids.items(), - timeline_state.items(), + state_ids.items(), timeline_state.items() ): if t[0] == EventTypes.Member: cache.set(t[1], event_id) @@ -864,10 +871,14 @@ class SyncHandler(object): if state_ids: state = yield self.store.get_events(list(state_ids.values())) - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(list(state.values())) - }) + defer.returnValue( + { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + list(state.values()) + ) + } + ) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -875,7 +886,7 @@ class SyncHandler(object): last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read" + receipt_type="m.read", ) notifs = [] @@ -909,7 +920,9 @@ class SyncHandler(object): logger.info( "Calculating sync response for %r between %s and %s", - sync_config.user, since_token, now_token, + sync_config.user, + since_token, + now_token, ) user_id = sync_config.user.to_string() @@ -920,11 +933,12 @@ class SyncHandler(object): raise NotImplementedError() else: joined_room_ids = yield self.get_rooms_for_user_at( - user_id, now_token.room_stream_id, + user_id, now_token.room_stream_id ) sync_result_builder = SyncResultBuilder( - sync_config, full_state, + sync_config, + full_state, since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, @@ -941,8 +955,7 @@ class SyncHandler(object): _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( - since_token is None and - sync_config.filter_collection.blocks_all_presence() + since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( @@ -973,22 +986,23 @@ class SyncHandler(object): room_id = joined_room.room_id if room_id in newly_joined_rooms: issue4422_logger.debug( - "Sync result for newly joined room %s: %r", - room_id, joined_room, + "Sync result for newly joined room %s: %r", room_id, joined_room ) - defer.returnValue(SyncResult( - presence=sync_result_builder.presence, - account_data=sync_result_builder.account_data, - joined=sync_result_builder.joined, - invited=sync_result_builder.invited, - archived=sync_result_builder.archived, - to_device=sync_result_builder.to_device, - device_lists=device_lists, - groups=sync_result_builder.groups, - device_one_time_keys_count=one_time_key_counts, - next_batch=sync_result_builder.now_token, - )) + defer.returnValue( + SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, + next_batch=sync_result_builder.now_token, + ) + ) @measure_func("_generate_sync_entry_for_groups") @defer.inlineCallbacks @@ -999,11 +1013,11 @@ class SyncHandler(object): if since_token and since_token.groups_key: results = yield self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key, + user_id, since_token.groups_key, now_token.groups_key ) else: results = yield self.store.get_all_groups_for_user( - user_id, now_token.groups_key, + user_id, now_token.groups_key ) invited = {} @@ -1031,58 +1045,90 @@ class SyncHandler(object): left[group_id] = content["content"] sync_result_builder.groups = GroupsSyncResult( - join=joined, - invite=invited, - leave=left, + join=joined, invite=invited, leave=left ) @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def _generate_sync_entry_for_device_list(self, sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, newly_left_users): + def _generate_sync_entry_for_device_list( + self, + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ): + """Generate the DeviceLists section of sync + + Args: + sync_result_builder (SyncResultBuilder) + newly_joined_rooms (set[str]): Set of rooms user has joined since + previous sync + newly_joined_or_invited_users (set[str]): Set of users that have + joined or been invited to a room since previous sync. + newly_left_rooms (set[str]): Set of rooms user has left since + previous sync + newly_left_users (set[str]): Set of users that have left a room + we're in since previous sync + + Returns: + Deferred[DeviceLists] + """ + user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token + # We're going to mutate these fields, so lets copy them rather than + # assume they won't get used later. + newly_joined_or_invited_users = set(newly_joined_or_invited_users) + newly_left_users = set(newly_left_users) + if since_token and since_token.device_list_key: - changed = yield self.store.get_user_whose_devices_changed( - since_token.device_list_key + # We want to figure out what user IDs the client should refetch + # device keys for, and which users we aren't going to track changes + # for anymore. + # + # For the first step we check: + # a. if any users we share a room with have updated their devices, + # and + # b. we also check if we've joined any new rooms, or if a user has + # joined a room we're in. + # + # For the second step we just find any users we no longer share a + # room with by looking at all users that have left a room plus users + # that were in a room we've left. + + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) + + # Step 1a, check for changes in devices of users we share a room with + users_that_have_changed = yield self.store.get_users_whose_devices_changed( + since_token.device_list_key, users_who_share_room ) - # TODO: Be more clever than this, i.e. remove users who we already - # share a room with? + # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: joined_users = yield self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. + users_that_have_changed.update(newly_joined_or_invited_users) + + # Now find users that we no longer track for room_id in newly_left_rooms: left_users = yield self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) - # TODO: Check that these users are actually new, i.e. either they - # weren't in the previous sync *or* they left and rejoined. - changed.update(newly_joined_or_invited_users) - - if not changed and not newly_left_users: - defer.returnValue(DeviceLists( - changed=[], - left=newly_left_users, - )) + # Remove any users that we still share a room with. + newly_left_users -= users_who_share_room - users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id + defer.returnValue( + DeviceLists(changed=users_that_have_changed, left=newly_left_users) ) - - defer.returnValue(DeviceLists( - changed=users_who_share_room & changed, - left=set(newly_left_users) - users_who_share_room, - )) else: - defer.returnValue(DeviceLists( - changed=[], - left=[], - )) + defer.returnValue(DeviceLists(changed=[], left=[])) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -1109,8 +1155,9 @@ class SyncHandler(object): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.debug("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug( + "Deleted %d to-device messages up to %d", deleted, since_stream_id + ) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key @@ -1118,7 +1165,10 @@ class SyncHandler(object): logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), since_stream_id, stream_id, now_token.to_device_key + len(messages), + since_stream_id, + stream_id, + now_token.to_device_key, ) sync_result_builder.now_token = now_token.copy_and_replace( "to_device_key", stream_id @@ -1145,8 +1195,7 @@ class SyncHandler(object): if since_token and not sync_result_builder.full_state: account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) ) @@ -1160,27 +1209,28 @@ class SyncHandler(object): ) else: account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) + yield self.store.get_account_data_for_user(sync_config.user.to_string()) ) - account_data['m.push_rules'] = yield self.push_rules_for_user( + account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data([ - {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() - ]) + account_data_for_user = sync_config.filter_collection.filter_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ] + ) sync_result_builder.account_data = account_data_for_user defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, - newly_joined_or_invited_users): + def _generate_sync_entry_for_presence( + self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + ): """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1223,17 +1273,13 @@ class SyncHandler(object): extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states( - extra_users_ids, - ) + states = yield self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence( - presence - ) + presence = sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -1253,8 +1299,8 @@ class SyncHandler(object): """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None and - sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + sync_result_builder.since_token is None + and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -1275,15 +1321,14 @@ class SyncHandler(object): have_changed = yield self._have_rooms_changed(sync_result_builder) if not have_changed: tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") defer.returnValue(([], [], [], [])) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, + "m.ignored_user_list", user_id=user_id ) if ignored_account_data: @@ -1296,7 +1341,7 @@ class SyncHandler(object): room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( - user_id, since_token.account_data_key, + user_id, since_token.account_data_key ) else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) @@ -1331,8 +1376,8 @@ class SyncHandler(object): for event in it: if event.type == EventTypes.Member: if ( - event.membership == Membership.JOIN or - event.membership == Membership.INVITE + event.membership == Membership.JOIN + or event.membership == Membership.INVITE ): newly_joined_or_invited_users.add(event.state_key) else: @@ -1343,12 +1388,14 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users - defer.returnValue(( - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - )) + defer.returnValue( + ( + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ) + ) @defer.inlineCallbacks def _have_rooms_changed(self, sync_result_builder): @@ -1454,7 +1501,9 @@ class SyncHandler(object): prev_membership = old_mem_ev.membership issue4422_logger.debug( "Previous membership for room %s with join: %s (event %s)", - room_id, prev_membership, old_mem_ev_id, + room_id, + prev_membership, + old_mem_ev_id, ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: @@ -1476,8 +1525,7 @@ class SyncHandler(object): if not old_state_ids: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( - (EventTypes.Member, user_id), - None, + (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: @@ -1498,7 +1546,8 @@ class SyncHandler(object): # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? leave_events = [ - e for e in non_joins + e + for e in non_joins if e.membership in (Membership.LEAVE, Membership.BAN) ] @@ -1526,15 +1575,17 @@ class SyncHandler(object): else: batch_events = None - room_entries.append(RoomSyncResultBuilder( - room_id=room_id, - rtype="archived", - events=batch_events, - newly_joined=room_id in newly_joined_rooms, - full_state=False, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=batch_events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + ) + ) timeline_limit = sync_config.filter_collection.timeline_limit() @@ -1581,7 +1632,8 @@ class SyncHandler(object): # debugging for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, entry.events, + room_id, + entry.events, ) room_entries.append(entry) @@ -1606,12 +1658,14 @@ class SyncHandler(object): sync_config = sync_result_builder.sync_config membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, ) room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=membership_list + user_id=user_id, membership_list=membership_list ) room_entries = [] @@ -1619,23 +1673,22 @@ class SyncHandler(object): for event in room_list: if event.membership == Membership.JOIN: - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="joined", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=now_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + ) + ) elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) + invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. if not sync_config.filter_collection.include_leave: @@ -1646,22 +1699,31 @@ class SyncHandler(object): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="archived", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + ) + ) defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, sync_result_builder, ignored_users, - room_builder, ephemeral, tags, account_data, - always_include=False): + def _generate_room_entry( + self, + sync_result_builder, + ignored_users, + room_builder, + ephemeral, + tags, + account_data, + always_include=False, + ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. @@ -1678,9 +1740,7 @@ class SyncHandler(object): """ newly_joined = room_builder.newly_joined full_state = ( - room_builder.full_state - or newly_joined - or sync_result_builder.full_state + room_builder.full_state or newly_joined or sync_result_builder.full_state ) events = room_builder.events @@ -1697,7 +1757,8 @@ class SyncHandler(object): upto_token = room_builder.upto_token batch = yield self._load_filtered_recents( - room_id, sync_config, + room_id, + sync_config, now_token=upto_token, since_token=since_token, recents=events, @@ -1708,7 +1769,8 @@ class SyncHandler(object): # debug for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "Timeline events after filtering in newly-joined room %s: %r", - room_id, batch, + room_id, + batch, ) # When we join the room (or the client requests full_state), we should @@ -1726,16 +1788,10 @@ class SyncHandler(object): account_data_events = [] if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) account_data_events = sync_config.filter_collection.filter_room_account_data( account_data_events @@ -1743,16 +1799,13 @@ class SyncHandler(object): ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) - if not (always_include - or batch - or account_data_events - or ephemeral - or full_state): + if not ( + always_include or batch or account_data_events or ephemeral or full_state + ): return state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state + room_id, batch, sync_config, since_token, now_token, full_state=full_state ) summary = {} @@ -1760,22 +1813,19 @@ class SyncHandler(object): # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form # the name itself). - if ( - sync_config.filter_collection.lazy_load_members() and - ( - # we recalulate the summary: - # if there are membership changes in the timeline, or - # if membership has changed during a gappy sync, or - # if this is an initial sync. - any(ev.type == EventTypes.Member for ev in batch.events) or - ( - # XXX: this may include false positives in the form of LL - # members which have snuck into state - batch.limited and - any(t == EventTypes.Member for (t, k) in state) - ) or - since_token is None + if sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. + any(ev.type == EventTypes.Member for ev in batch.events) + or ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited + and any(t == EventTypes.Member for (t, k) in state) ) + or since_token is None ): summary = yield self.compute_summary( room_id, sync_config, batch, state, now_token @@ -1794,9 +1844,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) + notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1807,11 +1855,8 @@ class SyncHandler(object): if batch.limited and since_token: user_id = sync_result_builder.sync_config.user.to_string() logger.info( - "Incremental gappy sync of %s for user %s with %d state events" % ( - room_id, - user_id, - len(state), - ) + "Incremental gappy sync of %s for user %s with %d state events" + % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( @@ -1841,9 +1886,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( - user_id, - ) + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1862,11 +1905,9 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) extrems = yield self.store.get_forward_extremeties_for_room( - room_id, stream_ordering, - ) - users_in_room = yield self.state.get_current_users_in_room( - room_id, extrems, + room_id, stream_ordering ) + users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) @@ -1886,7 +1927,7 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members, + timeline_contains, timeline_start, previous, current, lazy_load_members ): """Works out what state to include in a sync response. @@ -1930,15 +1971,12 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) - if t[0] == EventTypes.Member + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - return { - event_id_to_key[e]: e for e in state_ids - } + return {event_id_to_key[e]: e for e in state_ids} class SyncResultBuilder(object): @@ -1961,8 +1999,10 @@ class SyncResultBuilder(object): groups (GroupsSyncResult|None) to_device (list) """ - def __init__(self, sync_config, full_state, since_token, now_token, - joined_room_ids): + + def __init__( + self, sync_config, full_state, since_token, now_token, joined_room_ids + ): """ Args: sync_config (SyncConfig) @@ -1991,8 +2031,10 @@ class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. """ - def __init__(self, room_id, rtype, events, newly_joined, full_state, - since_token, upto_token): + + def __init__( + self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token + ): """ Args: room_id(str) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 972662eb48..f8062c8671 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,13 +68,10 @@ class TypingHandler(object): # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial, + "TypingStreamChangeCache", self._latest_room_serial ) - self.clock.looping_call( - self._handle_timeouts, - 5000, - ) + self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): """ @@ -108,19 +105,11 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background( - self._push_remote, - member=member, - typing=True - ) + run_in_background(self._push_remote, member=member, typing=True) # Add a paranoia timer to ensure that we always have a timer for # each person typing. - self.wheel_timer.insert( - now=now, - obj=member, - then=now + 60 * 1000, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) @@ -138,9 +127,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has started typing in %s", target_user_id, room_id - ) + logger.debug("%s has started typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -149,20 +136,13 @@ class TypingHandler(object): now = self.clock.time_msec() self._member_typing_until[member] = now + timeout - self.wheel_timer.insert( - now=now, - obj=member, - then=now + timeout, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + timeout) if was_present: # No point sending another notification defer.returnValue(None) - self._push_update( - member=member, - typing=True, - ) + self._push_update(member=member, typing=True) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): @@ -177,9 +157,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has stopped typing in %s", target_user_id, room_id - ) + logger.debug("%s has stopped typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -200,20 +178,14 @@ class TypingHandler(object): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - self._push_update( - member=member, - typing=False, - ) + self._push_update(member=member, typing=False) def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_in_background(self._push_remote, member, typing) - self._push_update_local( - member=member, - typing=typing - ) + self._push_update_local(member=member, typing=typing) @defer.inlineCallbacks def _push_remote(self, member, typing): @@ -223,9 +195,7 @@ class TypingHandler(object): now = self.clock.time_msec() self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) for domain in set(get_domain_from_id(u) for u in users): @@ -256,8 +226,7 @@ class TypingHandler(object): if user.domain != origin: logger.info( - "Got typing update from %r with bad 'user_id': %r", - origin, user_id, + "Got typing update from %r with bad 'user_id': %r", origin, user_id ) return @@ -268,15 +237,8 @@ class TypingHandler(object): logger.info("Got typing update from %s: %r", user_id, content) now = self.clock.time_msec() self._member_typing_until[member] = now + FEDERATION_TIMEOUT - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_TIMEOUT, - ) - self._push_update_local( - member=member, - typing=content["typing"] - ) + self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) + self._push_update_local(member=member, typing=content["typing"]) def _push_update_local(self, member, typing): room_set = self._room_typing.setdefault(member.room_id, set()) @@ -288,7 +250,7 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial self._typing_stream_change_cache.entity_has_changed( - member.room_id, self._latest_room_serial, + member.room_id, self._latest_room_serial ) self.notifier.on_new_event( @@ -300,7 +262,7 @@ class TypingHandler(object): return [] changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id, + last_id ) if changed_rooms is None: @@ -334,9 +296,7 @@ class TypingNotificationEventSource(object): return { "type": "m.typing", "room_id": room_id, - "content": { - "user_ids": list(typing), - }, + "content": {"user_ids": list(typing)}, } def get_new_events(self, from_key, room_ids, **kwargs): |