diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index dca337ec61..c29c78bd65 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -94,14 +94,15 @@ class BaseHandler(object):
burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
- user_id, time_now,
+ user_id,
+ time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
@@ -139,7 +140,7 @@ class BaseHandler(object):
if member_event.content["membership"] not in {
Membership.JOIN,
- Membership.INVITE
+ Membership.INVITE,
}:
continue
@@ -156,8 +157,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(
- target_user, is_guest=True)
+ requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
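
For context on the first hunk above: can_do_action returns a pair of (allowed, time_allowed), and retry_after_ms is derived from the gap between now and time_allowed. A minimal, self-contained sketch of that leaky-bucket style calculation — an illustration under assumed internals, not Synapse's actual Ratelimiter — might look like this:

    import time

    def can_do_action(state, key, time_now, rate_hz, burst_count):
        """Sketch of a leaky-bucket limiter returning (allowed, time_allowed).

        state maps key -> (action_count, time_of_last_update); all names and
        internals here are illustrative assumptions.
        """
        action_count, time_start = state.get(key, (0.0, time_now))
        # Actions "leak" out of the bucket at rate_hz per second.
        performed = max(action_count - (time_now - time_start) * rate_hz, 0.0)
        if performed + 1.0 > burst_count:
            # Denied: report when the bucket will next have room.
            return False, time_now + (performed + 1.0 - burst_count) / rate_hz
        state[key] = (performed + 1.0, time_now)
        return True, time_now

    state = {}
    time_now = time.time()
    allowed, time_allowed = can_do_action(
        state, "@alice:example.com", time_now, rate_hz=0.17, burst_count=3
    )
    if not allowed:
        retry_after_ms = int(1000 * (time_allowed - time_now))
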
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 7fa5d44d29..e62e6cab77 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -20,7 +20,7 @@ class AccountDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks
@@ -34,29 +34,22 @@ class AccountDataEventSource(object):
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
- results.append({
- "type": "m.tag",
- "content": {"tags": room_tags},
- "room_id": room_id,
- })
+ results.append(
+ {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+ )
account_data, room_account_data = (
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
)
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- })
+ results.append({"type": account_data_type, "content": content})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- "room_id": room_id,
- })
+ results.append(
+ {"type": account_data_type, "content": content, "room_id": room_id}
+ )
defer.returnValue((results, current_stream_id))
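
The single-line dict literals above are pure Black reformatting; the event source's output shape is unchanged. For illustration, the returned pair looks roughly like this (values and the non-m.tag event types are examples, not taken from the diff):

    results = [
        {
            "type": "m.tag",
            "content": {"tags": {"u.work": {"order": 1}}},
            "room_id": "!abc123:example.com",
        },
        # Global account data carries no room_id...
        {"type": "m.push_rules", "content": {"global": {}}},
        # ...while per-room account data does.
        {
            "type": "m.fully_read",
            "content": {"event_id": "$event0:example.com"},
            "room_id": "!abc123:example.com",
        },
    ]
    current_stream_id = 42
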
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 261446517d..0719da3ab7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -49,12 +49,10 @@ class AccountValidityHandler(object):
app_name = self.hs.config.email_app_name
self._subject = self._account_validity.renew_email_subject % {
- "app": app_name,
+ "app": app_name
}
- self._from_string = self.hs.config.email_notif_from % {
- "app": app_name,
- }
+ self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
self._subject = self._account_validity.renew_email_subject
@@ -69,10 +67,7 @@ class AccountValidityHandler(object):
)
# Check the renewal emails to send and send them every 30min.
- self.clock.looping_call(
- self.send_renewal_emails,
- 30 * 60 * 1000,
- )
+ self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000)
@defer.inlineCallbacks
def send_renewal_emails(self):
@@ -86,8 +81,7 @@ class AccountValidityHandler(object):
if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
- user_id=user["user_id"],
- expiration_ts=user["expiration_ts_ms"],
+ user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
@defer.inlineCallbacks
@@ -110,6 +104,9 @@ class AccountValidityHandler(object):
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
# account manually.
+ # We don't need to do a specific check to make sure the account isn't
+ # deactivated, as a deactivated account isn't supposed to have any
+ # email address attached to it.
if not addresses:
return
@@ -143,32 +140,33 @@ class AccountValidityHandler(object):
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
- multipart_msg = MIMEMultipart('alternative')
- multipart_msg['Subject'] = self._subject
- multipart_msg['From'] = self._from_string
- multipart_msg['To'] = address
- multipart_msg['Date'] = email.utils.formatdate()
- multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = self._subject
+ multipart_msg["From"] = self._from_string
+ multipart_msg["To"] = address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(self.sendmail(
- self.hs.config.email_smtp_host,
- self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security
- ))
-
- yield self.store.set_renewal_mail_status(
- user_id=user_id,
- email_sent=True,
- )
+ yield make_deferred_yieldable(
+ self.sendmail(
+ self.hs.config.email_smtp_host,
+ self._raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security,
+ )
+ )
+
+ yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
@@ -245,9 +243,7 @@ class AccountValidityHandler(object):
expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user(
- user_id=user_id,
- expiration_ts=expiration_ts,
- email_sent=email_sent,
+ user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
defer.returnValue(expiration_ts)
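
The sendmail hunk above mostly re-wraps arguments; the message construction itself is plain stdlib email handling. A standalone sketch of the same multipart/alternative build (subject, sender, and body text are placeholders):

    import email.utils
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText

    address = "alice@example.com"  # placeholder recipient
    raw_to = email.utils.parseaddr(address)[1]

    text_part = MIMEText("Please renew your account.", "plain", "utf8")
    html_part = MIMEText("<p>Please renew your account.</p>", "html", "utf8")

    multipart_msg = MIMEMultipart("alternative")
    multipart_msg["Subject"] = "Renew your account"
    multipart_msg["From"] = "Example Homeserver <noreply@example.com>"
    multipart_msg["To"] = address
    multipart_msg["Date"] = email.utils.formatdate()
    multipart_msg["Message-ID"] = email.utils.make_msgid()
    multipart_msg.attach(text_part)
    multipart_msg.attach(html_part)

    # This byte string is what gets handed to sendmail above.
    payload = multipart_msg.as_string().encode("utf8")
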
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 813777bf18..fbef2f3d38 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -15,14 +15,9 @@
import logging
-import attr
-from zope.interface import implementer
-
import twisted
import twisted.internet.error
from twisted.internet import defer
-from twisted.python.filepath import FilePath
-from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
@@ -30,27 +25,6 @@ from synapse.app import check_bind_error
logger = logging.getLogger(__name__)
-try:
- from txacme.interfaces import ICertificateStore
-
- @attr.s
- @implementer(ICertificateStore)
- class ErsatzStore(object):
- """
- A store that only stores in memory.
- """
-
- certs = attr.ib(default=attr.Factory(dict))
-
- def store(self, server_name, pem_objects):
- self.certs[server_name] = [o.as_bytes() for o in pem_objects]
- return defer.succeed(None)
-
-
-except ImportError:
- # txacme is missing
- pass
-
class AcmeHandler(object):
def __init__(self, hs):
@@ -60,6 +34,7 @@ class AcmeHandler(object):
@defer.inlineCallbacks
def start_listening(self):
+ from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
@@ -67,50 +42,27 @@ class AcmeHandler(object):
#
# add_destinations(TwistedDestination())
- from txacme.challenges import HTTP01Responder
- from txacme.service import AcmeIssuingService
- from txacme.endpoint import load_or_create_client_key
- from txacme.client import Client
- from josepy.jwa import RS256
-
- self._store = ErsatzStore()
- responder = HTTP01Responder()
-
- self._issuer = AcmeIssuingService(
- cert_store=self._store,
- client_creator=(
- lambda: Client.from_url(
- reactor=self.reactor,
- url=URL.from_text(self.hs.config.acme_url),
- key=load_or_create_client_key(
- FilePath(self.hs.config.config_dir_path)
- ),
- alg=RS256,
- )
- ),
- clock=self.reactor,
- responders=[responder],
+ well_known = Resource()
+
+ self._issuer = acme_issuing_service.create_issuing_service(
+ self.reactor,
+ acme_url=self.hs.config.acme_url,
+ account_key_file=self.hs.config.acme_account_key_file,
+ well_known_resource=well_known,
)
- well_known = Resource()
- well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
- responder_resource.putChild(b'.well-known', well_known)
- responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
-
+ responder_resource.putChild(b".well-known", well_known)
+ responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
srv = server.Site(responder_resource)
bind_addresses = self.hs.config.acme_bind_addresses
for host in bind_addresses:
logger.info(
- "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
+ "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(
- self.hs.config.acme_port,
- srv,
- interface=host,
- )
+ self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
@@ -132,7 +84,7 @@ class AcmeHandler(object):
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
- cert_chain = self._store.certs[self._acme_domain]
+ cert_chain = self._issuer.cert_store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
new file mode 100644
index 0000000000..e1d4224e74
--- /dev/null
+++ b/synapse/handlers/acme_issuing_service.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility function to create an ACME issuing service.
+
+This file contains the unconditional imports of the acme and cryptography bits that we
+only need (and may only have available) if we are doing ACME, so it is designed to be
+imported conditionally.
+"""
+import logging
+
+import attr
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from josepy import JWKRSA
+from josepy.jwa import RS256
+from txacme.challenges import HTTP01Responder
+from txacme.client import Client
+from txacme.interfaces import ICertificateStore
+from txacme.service import AcmeIssuingService
+from txacme.util import generate_private_key
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+
+logger = logging.getLogger(__name__)
+
+
+def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+ """Create an ACME issuing service, and attach it to a web Resource
+
+ Args:
+ reactor: twisted reactor
+ acme_url (str): URL to use to request certificates
+ account_key_file (str): where to store the account key
+ well_known_resource (twisted.web.IResource): web resource for .well-known.
+            We will attach a child resource for "acme-challenge".
+
+ Returns:
+ AcmeIssuingService
+ """
+ responder = HTTP01Responder()
+
+ well_known_resource.putChild(b"acme-challenge", responder.resource)
+
+ store = ErsatzStore()
+
+ return AcmeIssuingService(
+ cert_store=store,
+ client_creator=(
+ lambda: Client.from_url(
+ reactor=reactor,
+ url=URL.from_text(acme_url),
+ key=load_or_create_client_key(account_key_file),
+ alg=RS256,
+ )
+ ),
+ clock=reactor,
+ responders=[responder],
+ )
+
+
+@attr.s
+@implementer(ICertificateStore)
+class ErsatzStore(object):
+ """
+ A store that only stores in memory.
+ """
+
+ certs = attr.ib(default=attr.Factory(dict))
+
+ def store(self, server_name, pem_objects):
+ self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+ return defer.succeed(None)
+
+
+def load_or_create_client_key(key_file):
+ """Load the ACME account key from a file, creating it if it does not exist.
+
+ Args:
+ key_file (str): name of the file to use as the account key
+ """
+ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
+ # hardcode the 'client.key' filename
+ acme_key_file = FilePath(key_file)
+ if acme_key_file.exists():
+ logger.info("Loading ACME account key from '%s'", acme_key_file)
+ key = serialization.load_pem_private_key(
+ acme_key_file.getContent(), password=None, backend=default_backend()
+ )
+ else:
+ logger.info("Saving new ACME account key to '%s'", acme_key_file)
+ key = generate_private_key("rsa")
+ acme_key_file.setContent(
+ key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+ )
+ return JWKRSA(key=key)
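
A sketch of how this new helper is wired up, mirroring the call site in AcmeHandler.start_listening above (the ACME URL and key path are example values):

    from twisted.internet import reactor
    from twisted.web.resource import Resource

    from synapse.handlers import acme_issuing_service

    well_known = Resource()
    issuer = acme_issuing_service.create_issuing_service(
        reactor,
        acme_url="https://acme-v01.api.letsencrypt.org/directory",  # example URL
        account_key_file="/data/acme_account.key",  # example path
        well_known_resource=well_known,
    )
    # well_known now has an "acme-challenge" child attached, and certificates
    # issued via issuer end up in issuer.cert_store.certs.
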
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5d629126fc..941ebfa107 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
-
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
@@ -33,23 +32,17 @@ class AdminHandler(BaseHandler):
sessions = yield self.store.get_user_ip_and_agents(user)
for session in sessions:
- connections.append({
- "ip": session["ip"],
- "last_seen": session["last_seen"],
- "user_agent": session["user_agent"],
- })
+ connections.append(
+ {
+ "ip": session["ip"],
+ "last_seen": session["last_seen"],
+ "user_agent": session["user_agent"],
+ }
+ )
ret = {
"user_id": user.to_string(),
- "devices": {
- "": {
- "sessions": [
- {
- "connections": connections,
- }
- ]
- },
- },
+ "devices": {"": {"sessions": [{"connections": connections}]}},
}
defer.returnValue(ret)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 17eedf4dbf..5cc89d43f6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
class ApplicationServicesHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
@@ -101,9 +100,10 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
+
def start_scheduler():
return self.scheduler.start().addErrback(
- log_failure, "Application Services Failure",
+ log_failure, "Application Services Failure"
)
run_as_background_process("as_scheduler", start_scheduler)
@@ -118,10 +118,15 @@ class ApplicationServicesHandler(object):
for event in events:
yield handle_event(event)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True,
+ )
+ )
yield self.store.set_appservice_last_pos(upper_bound)
@@ -129,20 +134,23 @@ class ApplicationServicesHandler(object):
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
- "appservice_sender").set(upper_bound)
+ "appservice_sender"
+ ).set(upper_bound)
events_processed_counter.inc(len(events))
- event_processing_loop_room_count.labels(
- "appservice_sender"
- ).inc(len(events_by_room))
+ event_processing_loop_room_count.labels("appservice_sender").inc(
+ len(events_by_room)
+ )
event_processing_loop_counter.labels("appservice_sender").inc()
synapse.metrics.event_processing_lag.labels(
- "appservice_sender").set(now - ts)
+ "appservice_sender"
+ ).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
- "appservice_sender").set(ts)
+ "appservice_sender"
+ ).set(ts)
finally:
self.is_processing = False
@@ -155,13 +163,9 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
- user_query_services = yield self._get_services_for_user(
- user_id=user_id
- )
+ user_query_services = yield self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
- is_known_user = yield self.appservice_api.query_user(
- user_service, user_id
- )
+ is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
defer.returnValue(True)
defer.returnValue(False)
@@ -179,9 +183,7 @@ class ApplicationServicesHandler(object):
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
alias_query_services = [
- s for s in services if (
- s.is_interested_in_alias(room_alias_str)
- )
+ s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
@@ -189,22 +191,24 @@ class ApplicationServicesHandler(object):
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.store.get_association_from_room_alias(room_alias)
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield make_deferred_yieldable(defer.DeferredList([
- run_in_background(
- self.appservice_api.query_3pe,
- service, kind, protocol, fields,
+ results = yield make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.query_3pe, service, kind, protocol, fields
+ )
+ for service in services
+ ],
+ consumeErrors=True,
)
- for service in services
- ], consumeErrors=True))
+ )
ret = []
for (success, result) in results:
@@ -276,18 +280,12 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- s.is_interested_in_user(user_id)
- )
- ]
+ interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return defer.succeed(interested_list)
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if s.is_interested_in_protocol(protocol)
- ]
+ interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return defer.succeed(interested_list)
@defer.inlineCallbacks
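
Several hunks above re-wrap the same fan-out idiom: start one background task per room, then wait for all of them. A self-contained Twisted sketch of that pattern, leaving out Synapse's logcontext wrappers (run_in_background / make_deferred_yieldable):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_room_events(room_id, events):
        # Placeholder per-room work.
        yield defer.succeed(None)

    @defer.inlineCallbacks
    def notify_interested_services(events_by_room):
        # consumeErrors=True stops a failure in one branch from also being
        # reported as an unhandled Deferred error; gatherResults still fails
        # with the first error it sees.
        yield defer.gatherResults(
            [
                handle_room_events(room_id, evs)
                for room_id, evs in events_by_room.items()
            ],
            consumeErrors=True,
        )

    d = notify_interested_services({"!room:example.com": ["$event1", "$event2"]})
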
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a0cf37a9f9..97b21c4093 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
"""
# build a list of supported flows
- flows = [
- [login_type] for login_type in self._supported_login_types
- ]
+ flows = [[login_type] for login_type in self._supported_login_types]
- result, params, _ = yield self.check_auth(
- flows, request_body, clientip,
- )
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
# find the completed login type
for login_type in self._supported_login_types:
@@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
break
else:
# this can't happen
- raise Exception(
- "check_auth returned True but no successful login type",
- )
+ raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
@@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):
authdict = None
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- del clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ del clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
session = self._get_session_info(sid)
if len(clientdict) > 0:
@@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
# on a home server.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
- session['clientdict'] = clientdict
+ session["clientdict"] = clientdict
self._save_session(session)
- elif 'clientdict' in session:
- clientdict = session['clientdict']
+ elif "clientdict" in session:
+ clientdict = session["clientdict"]
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session),
+ self._auth_dict_for_flows(flows, session)
)
- if 'creds' not in session:
- session['creds'] = {}
- creds = session['creds']
+ if "creds" not in session:
+ session["creds"] = {}
+ creds = session["creds"]
# check auth type currently being presented
errordict = {}
- if 'type' in authdict:
- login_type = authdict['type']
+ if "type" in authdict:
+ login_type = authdict["type"]
try:
result = yield self._check_auth_dict(
- authdict, clientip, password_servlet=password_servlet,
+ authdict, clientip, password_servlet=password_servlet
)
if result:
creds[login_type] = result
@@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, list(clientdict)
+ creds,
+ list(clientdict),
)
- defer.returnValue((creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session["id"]))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = list(creds)
+ ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(
- ret,
- )
+ raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
- if 'session' not in authdict:
+ if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(
- authdict['session']
- )
- if 'creds' not in sess:
- sess['creds'] = {}
- creds = sess['creds']
+ sess = self._get_session_info(authdict["session"])
+ if "creds" not in sess:
+ sess["creds"] = {}
+ creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
if result:
@@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
not send a session ID, returns None.
"""
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
return sid
def set_session_data(self, session_id, key, value):
@@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
value (any): The data to store
"""
sess = self._get_session_info(session_id)
- sess.setdefault('serverdict', {})[key] = value
+ sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
@@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
- return sess.setdefault('serverdict', {}).get(key, default)
+ return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip, password_servlet=False):
@@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
- login_type = authdict['type']
+ login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
# XXX: Temporary workaround for having Synapse handle password resets
# See AuthHandler.check_auth for further details
res = yield checker(
- authdict,
- clientip=clientip,
- password_servlet=password_servlet,
+ authdict, clientip=clientip, password_servlet=password_servlet
)
defer.returnValue(res)
@@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
- 400, "Captcha response is required",
- errcode=Codes.CAPTCHA_NEEDED
+ 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
- "Submitting recaptcha response %s with remoteip %s",
- user_response, clientip
+ "Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
@@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
- 'secret': self.hs.config.recaptcha_private_key,
- 'response': user_response,
- 'remoteip': clientip,
- }
+ "secret": self.hs.config.recaptcha_private_key,
+ "response": user_response,
+ "remoteip": clientip,
+ },
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
- if 'success' in resp_body:
+ if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
- "Successful" if resp_body['success'] else "Failed",
- resp_body.get('hostname')
+ "Successful" if resp_body["success"] else "Failed",
+ resp_body.get("hostname"),
)
- if resp_body['success']:
+ if resp_body["success"]:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
- return self._check_threepid('email', authdict, **kwargs)
+ return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
- return self._check_threepid('msisdn', authdict)
+ return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
@@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
- if 'threepid_creds' not in authdict:
+ if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
- threepid_creds = authdict['threepid_creds']
+ threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
@@ -482,31 +469,36 @@ class AuthHandler(BaseHandler):
validated=True,
)
- threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
- } if row else None
+ threepid = (
+ {
+ "medium": row["medium"],
+ "address": row["address"],
+ "validated_at": row["validated_at"],
+ }
+ if row
+ else None
+ )
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
- raise SynapseError(400, "Password resets are not enabled on this homeserver")
+ raise SynapseError(
+ 400, "Password resets are not enabled on this homeserver"
+ )
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
- if threepid['medium'] != medium:
+ if threepid["medium"] != medium:
raise LoginError(
401,
- "Expecting threepid of type '%s', got '%s'" % (
- medium, threepid['medium'],
- ),
- errcode=Codes.UNAUTHORIZED
+ "Expecting threepid of type '%s', got '%s'"
+ % (medium, threepid["medium"]),
+ errcode=Codes.UNAUTHORIZED,
)
- threepid['threepid_creds'] = authdict['threepid_creds']
+ threepid["threepid_creds"] = authdict["threepid_creds"]
defer.returnValue(threepid)
@@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
- "url": "%s_matrix/consent?v=%s" % (
+ "url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
- },
- },
+ }
+ }
}
def _auth_dict_for_flows(self, flows, session):
@@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
- "session": session['id'],
+ "session": session["id"],
"flows": [{"stages": f} for f in public_flows],
- "params": params
+ "params": params,
}
def _get_session_info(self, session_id):
@@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
- self.sessions[session_id] = {
- "id": session_id,
- }
+ self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
@@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
- user_id, user_infos.keys()
+ user_id,
+ user_infos.keys(),
)
defer.returnValue(result)
@@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
user is too high to proceed.
"""
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
- qualified_user_id = UserID(
- username, self.hs.hostname
- ).to_string()
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
self.ratelimit_login_per_account(qualified_user_id)
@@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
- if (hasattr(provider, "check_password")
- and login_type == LoginType.PASSWORD):
+ if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(
- qualified_user_id, password,
- )
+ is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
defer.returnValue((qualified_user_id, None))
- if (not hasattr(provider, "get_supported_login_types")
- or not hasattr(provider, "check_auth")):
+ if not hasattr(provider, "get_supported_login_types") or not hasattr(
+ provider, "check_auth"
+ ):
# this password provider doesn't understand custom login types
continue
@@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
- 400, "Missing parameters for login type %s: %s" % (
- login_type,
- missing_fields,
- ),
+ 400,
+ "Missing parameters for login type %s: %s"
+ % (login_type, missing_fields),
)
- result = yield provider.check_auth(
- username, login_type, login_dict,
- )
+ result = yield provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
canonical_user_id = yield self._check_local_password(
- qualified_user_id, password,
+ qualified_user_id, password
)
if canonical_user_id:
@@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):
# unknown username or invalid password.
self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), time_now_s=self._clock.time(),
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
@@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
- raise LoginError(
- 403, "Invalid password",
- errcode=Codes.FORBIDDEN
- )
+ raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
@@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(
- medium, address, password,
- )
+ result = yield provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token,
- device_id)
+ yield self.store.add_access_token_to_user(user_id, access_token, device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
@@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"], )
+ str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
- def delete_access_tokens_for_user(self, user_id, except_token_id=None,
- device_id=None):
+ def delete_access_tokens_for_user(
+ self, user_id, except_token_id=None, device_id=None
+ ):
"""Invalidate access tokens belonging to a user
Args:
@@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
- user_id, except_token_id=except_token_id, device_id=device_id,
+ user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
@@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
- user_id=user_id,
- device_id=device_id,
- access_token=token,
+ user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_id, (token_id for _, token_id, _ in tokens_and_devices),
+ user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
@@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
# of specific types of threepid (and fixes the fact that checking
# for the presence of an email address during password reset was
# case sensitive).
- if medium == 'email':
+ if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
- user_id, medium, address, validated_at,
- self.hs.get_clock().time_msec()
+ user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
@@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
"""
# 'Canonicalise' email addresses as per above
- if medium == 'email':
+ if medium == "email":
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
- user_id,
- {
- 'medium': medium,
- 'address': address,
- 'id_server': id_server,
- },
+ user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(
- user_id, medium, address,
- )
+ yield self.store.user_delete_threepid(user_id, medium, address)
defer.returnValue(result)
def _save_session(self, session):
@@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(unicode): Hashed password.
"""
+
def _do_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
- ).decode('ascii')
+ ).decode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
@@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
- stored_hash
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ stored_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
- stored_hash = stored_hash.encode('ascii')
+ stored_hash = stored_hash.encode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
@@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
for this user is too high to proceed.
"""
self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._account_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
@@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
- macaroon.add_first_party_caveat("nonce = %s" % (
- stringutils.random_string_with_symbols(16),
- ))
+ macaroon.add_first_party_caveat(
+ "nonce = %s" % (stringutils.random_string_with_symbols(16),)
+ )
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
@@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key,
+ )
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
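
The MacaroonGenerator hunks above are formatting-only. For readers new to the scheme, a hedged sketch of minting and verifying such a macaroon with pymacaroons (the secret key and caveat values are illustrative):

    import secrets

    import pymacaroons

    secret_key = "macaroon-secret-key"  # stand-in for hs.config.macaroon_secret_key

    macaroon = pymacaroons.Macaroon(
        location="example.com", identifier="key", key=secret_key
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")
    macaroon.add_first_party_caveat("type = access")
    macaroon.add_first_party_caveat("nonce = %s" % (secrets.token_urlsafe(16),))
    serialized = macaroon.serialize()

    # Verification re-derives the signature from the same secret key and
    # requires every caveat to be satisfied.
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("user_id = @alice:example.com")
    v.satisfy_exact("type = access")
    v.satisfy_general(lambda caveat: caveat.startswith("nonce = "))
    assert v.verify(pymacaroons.Macaroon.deserialize(serialized), secret_key)
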
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6a91f7698e..e8f9da6098 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
+
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -42,6 +44,8 @@ class DeactivateAccountHandler(BaseHandler):
# it left off (if it has work left to do).
hs.get_reactor().callWhenRunning(self._start_user_parting)
+ self._account_validity_enabled = hs.config.account_validity.enabled
+
@defer.inlineCallbacks
def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
@@ -75,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler):
result = yield self._identity_handler.try_unbind_threepid(
user_id,
{
- 'medium': threepid['medium'],
- 'address': threepid['address'],
- 'id_server': id_server,
+ "medium": threepid["medium"],
+ "address": threepid["address"],
+ "id_server": id_server,
},
)
identity_server_supports_unbinding &= result
@@ -86,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
yield self.store.user_delete_threepid(
- user_id, threepid['medium'], threepid['address'],
+ user_id, threepid["medium"], threepid["address"]
)
# delete any devices belonging to the user, which will also
@@ -114,6 +118,13 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ # Remove all information on the user from the account_validity table.
+ if self._account_validity_enabled:
+ yield self.store.delete_account_validity_for_user(user_id)
+
+ # Mark the user as deactivated.
+ yield self.store.set_user_deactivated_status(user_id, True)
+
defer.returnValue(identity_server_supports_unbinding)
def _start_user_parting(self):
@@ -173,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler):
except Exception:
logger.exception(
"Failed to part user %r from room %r: ignoring and continuing",
- user_id, room_id,
+ user_id,
+ room_id,
)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b398848079..f59d0479b5 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id=None
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler):
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id,
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler):
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
- user_id, from_token.room_key, now_room_key,
+ user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
- stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key
- ).stream
+ stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
possibly_changed = set(changed)
possibly_left = set()
@@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
- defer.returnValue({
- "changed": list(possibly_joined),
- "left": list(possibly_left),
- })
+ defer.returnValue(
+ {"changed": list(possibly_joined), "left": list(possibly_left)}
+ )
class DeviceHandler(DeviceWorkerHandler):
@@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
- "m.device_list_update", self._edu_updater.incoming_device_list_update,
+ "m.device_list_update", self._edu_updater.incoming_device_list_update
)
federation_registry.register_query_handler(
- "user_devices", self.on_federation_query_user_devices,
+ "user_devices", self.on_federation_query_user_devices
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
- def check_device_registered(self, user_id, device_id,
- initial_device_display_name=None):
+ def check_device_registered(
+ self, user_id, device_id, initial_device_display_name=None
+ ):
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler):
raise
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(
- user_id=user_id, device_id=device_id
- )
+ yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
@@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler):
try:
yield self.store.update_device(
- user_id,
- device_id,
- new_display_name=content.get("display_name")
+ user_id, device_id, new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
@@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler):
for device_id in device_ids:
logger.debug(
- "Notifying about update %r/%r, ID: %r", user_id, device_id,
- position,
+ "Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
- yield self.notifier.on_new_event(
- "device_list_key", position, rooms=room_ids,
- )
+ yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
if hosts:
- logger.info("Sending device list update notif for %r to: %r", user_id, hosts)
+ logger.info(
+ "Sending device list update notif for %r to: %r", user_id, hosts
+ )
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue({
- "user_id": user_id,
- "stream_id": stream_id,
- "devices": devices,
- })
+ defer.returnValue(
+ {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ )
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
- })
+ device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
class DeviceListEduUpdater(object):
@@ -481,13 +465,15 @@ class DeviceListEduUpdater(object):
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
- prev_ids = [str(p) for p in prev_ids] # They may come as ints
+ prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning(
"Got device list update edu for %r/%r from %r",
- user_id, device_id, origin,
+ user_id,
+ device_id,
+ origin,
)
return
@@ -497,13 +483,12 @@ class DeviceListEduUpdater(object):
# probably won't get any further updates.
logger.warning(
"Got device list update edu for %r/%r, but don't share a room",
- user_id, device_id,
+ user_id,
+ device_id,
)
return
- logger.debug(
- "Received device list update for %r/%r", user_id, device_id,
- )
+ logger.debug("Received device list update for %r/%r", user_id, device_id)
self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
@@ -525,7 +510,10 @@ class DeviceListEduUpdater(object):
for device_id, stream_id, prev_ids, content in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
- user_id, device_id, stream_id, prev_ids,
+ user_id,
+ device_id,
+ stream_id,
+ prev_ids,
)
# Given a list of updates we check if we need to resync. This
@@ -540,13 +528,13 @@ class DeviceListEduUpdater(object):
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (
- NotRetryingDestination, RequestSendFailed, HttpResponseException,
+ NotRetryingDestination,
+ RequestSendFailed,
+ HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
- logger.warn(
- "Failed to handle device list update for %s", user_id,
- )
+ logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -582,18 +570,21 @@ class DeviceListEduUpdater(object):
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
- user_id, len(devices)
+ user_id,
+ len(devices),
)
devices = []
for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
- user_id, device["device_id"], stream_id,
+ user_id,
+ device["device_id"],
+ stream_id,
)
yield self.store.update_remote_device_list_cache(
- user_id, devices, stream_id,
+ user_id, devices, stream_id
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
@@ -606,7 +597,7 @@ class DeviceListEduUpdater(object):
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
- user_id, device_id, content, stream_id,
+ user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
@@ -624,14 +615,9 @@ class DeviceListEduUpdater(object):
"""
seen_updates = self._seen_updates.get(user_id, set())
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(
- user_id
- )
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
- logger.debug(
- "Current extremity for %r: %r",
- user_id, extremity,
- )
+ logger.debug("Current extremity for %r: %r", user_id, extremity)
stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 2e2e5261de..e1ebb6346c 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
-
def __init__(self, hs):
"""
Args:
@@ -47,15 +46,15 @@ class DeviceMessageHandler(object):
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
- origin, sender_user_id
+ origin,
+ sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = {
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a12f9508d8..42d5b3db30 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler):
-
def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)
@@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association(
- room_alias,
- room_id,
- servers,
- creator=creator,
+ room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
- def create_association(self, requester, room_alias, room_id, servers=None,
- send_event=True, check_membership=True):
+ def create_association(
+ self,
+ requester,
+ room_alias,
+ room_id,
+ servers=None,
+ send_event=True,
+ check_membership=True,
+ ):
"""Attempt to create a new alias
Args:
@@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler):
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
- 400, "This application service has not reserved"
- " this kind of alias.", errcode=Codes.EXCLUSIVE
+ 400,
+ "This application service has not reserved" " this kind of alias.",
+ errcode=Codes.EXCLUSIVE,
)
else:
if self.require_membership and check_membership:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
- raise AuthError(
- 403, "This user is not permitted to create this alias",
- )
+ raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
- user_id, room_id, room_alias.to_string(),
+ user_id, room_id, room_alias.to_string()
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to create alias",
- )
+ raise SynapseError(403, "Not allowed to create alias")
- can_create = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
+ yield self.send_room_alias_update_event(requester, room_id)
@defer.inlineCallbacks
def delete_association(self, requester, room_alias, send_event=True):
@@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler):
raise
if not can_delete:
- raise AuthError(
- 403, "You don't have permission to delete the alias.",
- )
+ raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
try:
if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
+ yield self.send_room_alias_update_event(requester, room_id)
yield self._update_canonical_alias(
- requester,
- requester.user.to_string(),
- room_id,
- room_alias,
+ requester, requester.user.to_string(), room_id, room_alias
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
- errcode=Codes.EXCLUSIVE
+ errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)
@@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler):
def get_association(self, room_alias):
room_id = None
if self.hs.is_mine(room_alias):
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
- args={
- "room_alias": room_alias.to_string(),
- },
+ args={"room_alias": room_alias.to_string()},
retry_on_dns_fail=False,
ignore_backoff=True,
)
@@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
404,
"Room alias %s not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
users = yield self.state.get_current_users_in_room(room_id)
@@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
- servers = (
- [self.server_name] +
- [s for s in servers if s != self.server_name]
- )
+ servers = [self.server_name] + [s for s in servers if s != self.server_name]
else:
servers = list(servers)
- defer.returnValue({
- "room_id": room_id,
- "servers": servers,
- })
+ defer.returnValue({"room_id": room_id, "servers": servers})
return
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
- raise SynapseError(
- 400, "Room Alias is not hosted on this Home Server"
- )
+ raise SynapseError(400, "Room Alias is not hosted on this Home Server")
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
- defer.returnValue({
- "room_id": result.room_id,
- "servers": result.servers,
- })
+ defer.returnValue({"room_id": result.room_id, "servers": result.servers})
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
@defer.inlineCallbacks
@@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler):
"sender": requester.user.to_string(),
"content": {"aliases": aliases},
},
- ratelimit=False
+ ratelimit=False,
)
@defer.inlineCallbacks
@@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler):
"sender": user_id,
"content": {},
},
- ratelimit=False
+ ratelimit=False,
)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
@@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler):
if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
if requester.is_guest:
@@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler):
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
@@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler):
room_aliases.append(canonical_alias)
if not self.config.is_publishing_room_allowed(
- user_id, room_id, room_aliases,
+ user_id, room_id, room_aliases
):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to publish room",
- )
+ raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
- def edit_published_appservice_room_list(self, appservice_id, network_id,
- room_id, visibility):
+ def edit_published_appservice_room_list(
+ self, appservice_id, network_id, room_id, visibility
+ ):
"""Add or remove a room from the appservice/network specific public
room list.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9dc46aa15f..807900fe52 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -99,9 +99,7 @@ class E2eKeysHandler(object):
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(
- query_list
- )
+ yield self.store.get_user_devices_from_cache(query_list)
)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
@@ -126,9 +124,7 @@ class E2eKeysHandler(object):
destination_query = remote_queries_not_in_cache[destination]
try:
remote_result = yield self.federation.query_client_keys(
- destination,
- {"device_keys": destination_query},
- timeout=timeout
+ destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
@@ -138,14 +134,17 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ )
+ )
- defer.returnValue({
- "device_keys": results, "failures": failures,
- })
+ defer.returnValue({"device_keys": results, "failures": failures})
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -165,8 +164,7 @@ class E2eKeysHandler(object):
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
if not device_ids:
@@ -231,9 +229,7 @@ class E2eKeysHandler(object):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
+ destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
@@ -241,25 +237,29 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(claim_client_keys, destination)
- for destination in remote_queries
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(claim_client_keys, destination)
+ for destination in remote_queries
+ ],
+ consumeErrors=True,
+ )
+ )
logger.info(
"Claimed one-time-keys: %s",
- ",".join((
- "%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
- )),
+ ",".join(
+ (
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )
+ ),
)
- defer.returnValue({
- "one_time_keys": json_result,
- "failures": failures
- })
+ defer.returnValue({"one_time_keys": json_result, "failures": failures})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@@ -270,11 +270,13 @@ class E2eKeysHandler(object):
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
- device_id, user_id, time_now
+ device_id,
+ user_id,
+ time_now,
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys,
+ user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
@@ -283,7 +285,7 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
yield self._upload_one_time_keys_for_user(
- user_id, device_id, time_now, one_time_keys,
+ user_id, device_id, time_now, one_time_keys
)
# the device should have been registered already, but it may have been
@@ -298,20 +300,22 @@ class E2eKeysHandler(object):
defer.returnValue({"one_time_key_counts": result})
@defer.inlineCallbacks
- def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
- one_time_keys):
+ def _upload_one_time_keys_for_user(
+ self, user_id, device_id, time_now, one_time_keys
+ ):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
- one_time_keys.keys(), device_id, user_id, time_now,
+ one_time_keys.keys(),
+ device_id,
+ user_id,
+ time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, key_obj
- ))
+ key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
@@ -325,42 +329,35 @@ class E2eKeysHandler(object):
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
- ("One time key %s:%s already exists. "
- "Old key: %s; new key: %r") %
- (algorithm, key_id, ex_json, key)
+ (
+ "One time key %s:%s already exists. "
+ "Old key: %s; new key: %r"
+ )
+ % (algorithm, key_id, ex_json, key),
)
else:
- new_keys.append((
- algorithm, key_id, encode_canonical_json(key).decode('ascii')))
+ new_keys.append(
+ (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
+ )
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, new_keys
- )
+ yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
- return {
- "status": e.code, "message": str(e),
- }
+ return {"status": e.code, "message": str(e)}
if isinstance(e, NotRetryingDestination):
- return {
- "status": 503, "message": "Not ready for retry",
- }
+ return {"status": 503, "message": "Not ready for retry"}
if isinstance(e, FederationDeniedError):
- return {
- "status": 403, "message": "Federation Denied",
- }
+ return {"status": 403, "message": "Federation Denied"}
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
- return {
- "status": 503, "message": str(e),
- }
+ return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key):
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 7bc174070e..ebd807bca6 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object):
else:
raise
- if version_info['version'] != version:
+ if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
- user_id, version,
+ user_id, version
)
# if we get this far, the version must exist
- raise RoomKeysVersionError(current_version=version_info['version'])
+ raise RoomKeysVersionError(current_version=version_info["version"])
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object):
# go through the room_keys.
# XXX: this should/could be done concurrently, given we're in a lock.
- for room_id, room in iteritems(room_keys['rooms']):
- for session_id, session in iteritems(room['sessions']):
+ for room_id, room in iteritems(room_keys["rooms"]):
+ for session_id, session in iteritems(room["sessions"]):
yield self._upload_room_key(
user_id, version, room_id, session_id, session
)
@@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object):
# spelt out with if/elifs rather than nested boolean expressions
# purely for legibility.
- if room_key['is_verified'] and not current_room_key['is_verified']:
+ if room_key["is_verified"] and not current_room_key["is_verified"]:
return True
elif (
- room_key['first_message_index'] <
- current_room_key['first_message_index']
+ room_key["first_message_index"]
+ < current_room_key["first_message_index"]
):
return True
- elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
+ elif room_key["forwarded_count"] < current_room_key["forwarded_count"]:
return True
else:
return False
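As the comment above notes, the if/elif ladder is purely for legibility; the decision it encodes is "prefer verified keys, then keys that cover earlier messages, then keys forwarded fewer times". A standalone sketch with the same field names, plus a quick check that a verified key wins even when the other criteria are worse:

    def should_replace(current, candidate):
        # Each branch returns, so plain ifs are equivalent to the elif chain.
        if candidate["is_verified"] and not current["is_verified"]:
            return True
        if candidate["first_message_index"] < current["first_message_index"]:
            return True
        return candidate["forwarded_count"] < current["forwarded_count"]

    assert should_replace(
        {"is_verified": False, "first_message_index": 0, "forwarded_count": 0},
        {"is_verified": True, "first_message_index": 9, "forwarded_count": 9},
    )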
@@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
- raise SynapseError(
- 400,
- "Missing version in body",
- Codes.MISSING_PARAM
- )
+ raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM)
if version_info["version"] != version:
raise SynapseError(
- 400,
- "Version in body does not match",
- Codes.INVALID_PARAM
+ 400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
@@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object):
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
- raise SynapseError(
- 400,
- "Algorithm does not match",
- Codes.INVALID_PARAM
- )
+ raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
- yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, version_info
+ )
defer.returnValue({})
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index eb525070cf..5836d3c639 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
-
def __init__(self, hs):
super(EventStreamHandler, self).__init__(hs)
@@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def get_stream(self, auth_user_id, pagin_config, timeout=0,
- as_client_event=True, affect_presence=True,
- only_keys=None, room_id=None, is_guest=False):
+ def get_stream(
+ self,
+ auth_user_id,
+ pagin_config,
+ timeout=0,
+ as_client_event=True,
+ affect_presence=True,
+ only_keys=None,
+ room_id=None,
+ is_guest=False,
+ ):
"""Fetches the events stream for a given user.
If `only_keys` is not None, only events from those stream keys will be sent down.
@@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
- auth_user_id, affect_presence=affect_presence,
+ auth_user_id, affect_presence=affect_presence
)
with context:
if timeout:
@@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
- auth_user, pagin_config, timeout,
+ auth_user,
+ pagin_config,
+ timeout,
only_keys=only_keys,
- is_guest=is_guest, explicit_room_id=room_id
+ is_guest=is_guest,
+ explicit_room_id=room_id,
)
# When the user joins a new room, or another user joins a currently
@@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = yield self.state.get_current_users_in_room(event.room_id)
- states = yield presence_handler.get_states(
- users,
- as_event=True,
+ users = yield self.state.get_current_users_in_room(
+ event.room_id
)
+ states = yield presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
- UserID.from_string(event.state_key),
- as_event=True,
+ UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
@@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events(
- events, time_now, as_client_event=as_client_event,
+ events,
+ time_now,
+ as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
-
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@@ -164,16 +173,10 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
- self.store,
- user.to_string(),
- [event],
- is_peeking=is_peeking
+ self.store, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
defer.returnValue(event)
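get_event treats a requester who is not currently in the room as "peeking" and lets the visibility filter make the call, turning an empty result into a 403. A compressed sketch of that gate; filter_events is a stand-in for synapse's filter_events_for_client, which additionally consults the room's history-visibility settings:

    from twisted.internet import defer

    from synapse.api.errors import AuthError

    @defer.inlineCallbacks
    def get_event_for_user(store, state, filter_events, user_id, room_id, event_id):
        event = yield store.get_event(event_id)
        users = yield state.get_current_users_in_room(room_id)
        # Non-members get the stricter "peeking" visibility rules.
        is_peeking = user_id not in users
        filtered = yield filter_events(store, user_id, [event], is_peeking=is_peeking)
        if not filtered:
            raise AuthError(403, "You don't have permission to access that event.")
        defer.returnValue(event)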
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac5ca79143..02d397c498 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +34,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
+ Codes,
FederationDeniedError,
FederationError,
RequestSendFailed,
@@ -80,7 +82,7 @@ def shortstr(iterable, maxitems=5):
items = list(itertools.islice(iterable, maxitems + 1))
if len(items) <= maxitems:
return str(items)
- return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]"
+ return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
class FederationHandler(BaseHandler):
@@ -113,24 +115,24 @@ class FederationHandler(BaseHandler):
self.config = hs.config
self.http_client = hs.get_simple_http_client()
- self._send_events_to_master = (
- ReplicationFederationSendEventsRestServlet.make_client(hs)
+ self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
+ hs
)
- self._notify_user_membership_change = (
- ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+ self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
+ hs
)
- self._clean_room_for_join_client = (
- ReplicationCleanRoomRestServlet.make_client(hs)
+ self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
+ hs
)
# When joining a room we need to queue any events for that room up
self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
@defer.inlineCallbacks
- def on_receive_pdu(
- self, origin, pdu, sent_to_us_directly=False,
- ):
+ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -147,26 +149,19 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info(
- "[%s %s] handling received PDU: %s",
- room_id, event_id, pdu,
- )
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ event_id, allow_none=True, allow_rejected=True
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
+ already_seen = existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
)
if already_seen:
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
@@ -178,20 +173,19 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id)
- raise FederationError(
- "ERROR",
- err.code,
- err.msg,
- affected=pdu.event_id,
+ logger.warn(
+ "[%s %s] Received event failed sanity checks", room_id, event_id
)
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
"[%s %s] Queuing PDU from %s for now: join in progress",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
self.room_queues[room_id].append((pdu, origin))
return
@@ -202,14 +196,13 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(
- room_id,
- self.server_name
- )
+ is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
defer.returnValue(None)
@@ -219,14 +212,9 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(
- pdu.room_id
- )
+ min_depth = yield self.get_min_depth_for_context(pdu.room_id)
- logger.debug(
- "[%s %s] min_depth: %d",
- room_id, event_id, min_depth,
- )
+ logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs)
@@ -244,12 +232,17 @@ class FederationHandler(BaseHandler):
# at a time.
logger.info(
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id, event_id, len(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
)
yield self._get_missing_events_for_pdu(
@@ -263,12 +256,16 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
"[%s %s] Found all missing prev_events",
- room_id, event_id,
+ room_id,
+ event_id,
)
elif missing_prevs:
logger.info(
"[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
if prevs - seen:
@@ -299,7 +296,10 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warn(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id, event_id, len(prevs - seen), shortstr(prevs - seen)
+ room_id,
+ event_id,
+ len(prevs - seen),
+ shortstr(prevs - seen),
)
raise FederationError(
"ERROR",
@@ -314,9 +314,7 @@ class FederationHandler(BaseHandler):
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
- event_map = {
- event_id: pdu,
- }
+ event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = yield self.store.get_state_groups_ids(room_id, seen)
@@ -333,7 +331,9 @@ class FederationHandler(BaseHandler):
for p in prevs - seen:
logger.info(
"[%s %s] Requesting state at missing prev_event %s",
- room_id, event_id, p,
+ room_id,
+ event_id,
+ p,
)
room_version = yield self.store.get_room_version(room_id)
@@ -344,19 +344,19 @@ class FederationHandler(BaseHandler):
# by the get_pdu_cache in federation_client.
remote_state, got_auth_chain = (
yield self.federation_client.get_state_for_room(
- origin, room_id, p,
+ origin, room_id, p
)
)
# we want the state *after* p; get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True,
+ [origin], p, room_version, outlier=True
)
if remote_event is None:
raise Exception(
- "Unable to get missing prev_event %s" % (p, )
+ "Unable to get missing prev_event %s" % (p,)
)
if remote_event.is_state():
@@ -376,7 +376,9 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
state_map = yield resolve_events_with_store(
- room_version, state_maps, event_map,
+ room_version,
+ state_maps,
+ event_map,
state_res_store=StateResolutionStore(self.store),
)
@@ -392,15 +394,15 @@ class FederationHandler(BaseHandler):
)
event_map.update(evs)
- state = [
- event_map[e] for e in six.itervalues(state_map)
- ]
+ state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warn(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
- room_id, event_id, exc_info=True,
+ room_id,
+ event_id,
+ exc_info=True,
)
raise FederationError(
"ERROR",
@@ -410,10 +412,7 @@ class FederationHandler(BaseHandler):
)
yield self._process_received_pdu(
- origin,
- pdu,
- state=state,
- auth_chain=auth_chain,
+ origin, pdu, state=state, auth_chain=auth_chain
)
@defer.inlineCallbacks
@@ -443,7 +442,10 @@ class FederationHandler(BaseHandler):
logger.info(
"[%s %s]: Requesting missing events between %s and %s",
- room_id, event_id, shortstr(latest), event_id,
+ room_id,
+ event_id,
+ shortstr(latest),
+ event_id,
)
# XXX: we set timeout to 10s to help workaround
@@ -494,19 +496,29 @@ class FederationHandler(BaseHandler):
#
# All that said: Let's try increasing the timeout to 60s and see what happens.
- missing_events = yield self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
+ try:
+ missing_events = yield self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except RequestSendFailed as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+ return
logger.info(
"[%s %s]: Got %d prev_events: %s",
- room_id, event_id, len(missing_events), shortstr(missing_events),
+ room_id,
+ event_id,
+ len(missing_events),
+ shortstr(missing_events),
)
# We want to sort these by depth so we process them and
@@ -516,20 +528,20 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
"[%s %s] Handling received prev_event %s",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
with logcontext.nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(
- origin,
- ev,
- sent_to_us_directly=False,
- )
+ yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warn(
"[%s %s] Received prev_event %s failed history check.",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
else:
raise
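The try/except newly wrapped around get_missing_events above (one of the few behavioural changes in this patch) makes the fetch best-effort: the code after it already copes with prev_events that never arrive, so a transport-level RequestSendFailed is logged and swallowed rather than propagated. The same pattern in isolation; everything except RequestSendFailed is a stand-in name:

    import logging

    from twisted.internet import defer

    from synapse.api.errors import RequestSendFailed

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def fetch_missing_best_effort(client, origin, room_id, event_id, **kwargs):
        try:
            events = yield client.get_missing_events(origin, room_id, **kwargs)
        except RequestSendFailed as e:
            # Callers already tolerate an incomplete set of prev_events, so a
            # failed fetch degrades to "got nothing" instead of an error.
            logger.warning(
                "[%s %s]: failed to get prev_events: %s", room_id, event_id, e
            )
            defer.returnValue([])
        defer.returnValue(events)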
@@ -542,10 +554,7 @@ class FederationHandler(BaseHandler):
room_id = event.room_id
event_id = event.event_id
- logger.debug(
- "[%s %s] Processing event: %s",
- room_id, event_id, event,
- )
+ logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
event_ids = set()
if state:
@@ -567,43 +576,32 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
- (e.type, e.state_key): e for e in auth_chain
+ (e.type, e.state_key): e
+ for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
- event_infos.append({
- "event": e,
- "auth_events": auth,
- })
+ event_infos.append({"event": e, "auth_events": auth})
seen_ids.add(e.event_id)
logger.info(
"[%s %s] persisting newly-received auth/state events %s",
- room_id, event_id, [e["event"].event_id for e in event_infos]
+ room_id,
+ event_id,
+ [e["event"].event_id for e in event_infos],
)
yield self._handle_new_events(origin, event_infos)
try:
- context = yield self._handle_new_event(
- origin,
- event,
- state=state,
- )
+ context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
room = yield self.store.get_room(room_id)
if not room:
try:
yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False,
+ room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
logger.exception("Failed to store room.")
@@ -617,12 +615,10 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
- prev_state_id = prev_state_ids.get(
- (event.type, event.state_key)
- )
+ prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
prev_state = yield self.store.get_event(
- prev_state_id, allow_none=True,
+ prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
@@ -653,10 +649,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(room_id)
events = yield self.federation_client.backfill(
- dest,
- room_id,
- limit=limit,
- extremities=extremities,
+ dest, room_id, limit=limit, extremities=extremities
)
# ideally we'd sanity check the events here for excess prev_events etc,
@@ -683,16 +676,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events)
- edges = [
- ev.event_id
- for ev in events
- if set(ev.prev_event_ids()) - event_ids
- ]
+ edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
- logger.info(
- "backfill: Got %d events with %d edges",
- len(events), len(edges),
- )
+ logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
# For each edge get the current state.
@@ -701,9 +687,7 @@ class FederationHandler(BaseHandler):
events_to_state = {}
for e_id in edges:
state, auth = yield self.federation_client.get_state_for_room(
- destination=dest,
- room_id=room_id,
- event_id=e_id
+ destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -712,12 +696,14 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
- for event in events + list(state_events.values()) + list(auth_events.values())
+ for event in events
+ + list(state_events.values())
+ + list(auth_events.values())
for a_id in event.auth_event_ids()
)
- auth_events.update({
- e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
- })
+ auth_events.update(
+ {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
+ )
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()
@@ -736,27 +722,30 @@ class FederationHandler(BaseHandler):
if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch
+ missing_auth - failed_to_fetch,
)
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(
+ self.federation_client.get_pdu,
+ [dest],
+ event_id,
+ room_version=room_version,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth - failed_to_fetch
+ ],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
- for event in results if event
+ for event in results
+ if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
@@ -788,15 +777,19 @@ class FederationHandler(BaseHandler):
continue
a.internal_metadata.outlier = True
- ev_infos.append({
- "event": a,
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
+ ev_infos.append(
+ {
+ "event": a,
+ "auth_events": {
+ (
+ auth_events[a_id].type,
+ auth_events[a_id].state_key,
+ ): auth_events[a_id]
+ for a_id in a.auth_event_ids()
+ if a_id in auth_events
+ },
}
- })
+ )
# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
@@ -804,23 +797,24 @@ class FederationHandler(BaseHandler):
# For paranoia we ensure that these events are marked as
# non-outliers
ev = event_map[e_id]
- assert(not ev.internal_metadata.is_outlier())
-
- ev_infos.append({
- "event": ev,
- "state": events_to_state[e_id],
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in ev.auth_event_ids()
- if a_id in auth_events
+ assert not ev.internal_metadata.is_outlier()
+
+ ev_infos.append(
+ {
+ "event": ev,
+ "state": events_to_state[e_id],
+ "auth_events": {
+ (
+ auth_events[a_id].type,
+ auth_events[a_id].state_key,
+ ): auth_events[a_id]
+ for a_id in ev.auth_event_ids()
+ if a_id in auth_events
+ },
}
- })
+ )
- yield self._handle_new_events(
- dest, ev_infos,
- backfilled=True,
- )
+ yield self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -831,14 +825,12 @@ class FederationHandler(BaseHandler):
# For paranoia we ensure that these events are marked as
# non-outliers
- assert(not event.internal_metadata.is_outlier())
+ assert not event.internal_metadata.is_outlier()
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(
- dest, event, backfilled=True,
- )
+ yield self._handle_new_event(dest, event, backfilled=True)
defer.returnValue(events)
@@ -847,9 +839,7 @@ class FederationHandler(BaseHandler):
"""Checks the database to see if we should backfill before paginating,
and if so, does so.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(
- room_id
- )
+ extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -881,31 +871,27 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(
- list(extremities),
- )
+ forward_events = yield self.store.get_successor_events(list(extremities))
extremities_events = yield self.store.get_events(
- forward_events,
- check_redacted=False,
- get_prev_content=False,
+ forward_events, check_redacted=False, get_prev_content=False
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
- self.store, self.server_name, list(extremities_events.values()),
- redact=False, check_history_visibility_only=True,
+ self.store,
+ self.server_name,
+ list(extremities_events.values()),
+ redact=False,
+ check_history_visibility_only=True,
)
if not filtered_extremities:
defer.returnValue(False)
# Check if we reached a point where we should start backfilling.
- sorted_extremeties_tuple = sorted(
- extremities.items(),
- key=lambda e: -int(e[1])
- )
+ sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
@@ -914,8 +900,7 @@ class FederationHandler(BaseHandler):
if current_depth > max_depth:
logger.debug(
- "Not backfilling as we don't need to. %d < %d",
- max_depth, current_depth,
+ "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
)
return
@@ -940,8 +925,7 @@ class FederationHandler(BaseHandler):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in iteritems(state)
- if e_type == EventTypes.Member
- and event.membership == Membership.JOIN
+ if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {}
@@ -961,8 +945,7 @@ class FederationHandler(BaseHandler):
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
- domain for domain, depth in curr_domains
- if domain != self.server_name
+ domain for domain, depth in curr_domains if domain != self.server_name
]
@defer.inlineCallbacks
@@ -971,28 +954,20 @@ class FederationHandler(BaseHandler):
for dom in domains:
try:
yield self.backfill(
- dom, room_id,
- limit=100,
- extremities=extremities,
+ dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
defer.returnValue(True)
except SynapseError as e:
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except NotRetryingDestination as e:
logger.info(str(e))
@@ -1001,10 +976,7 @@ class FederationHandler(BaseHandler):
logger.info(e)
continue
except Exception as e:
- logger.exception(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.exception("Failed to backfill from %s because %s", dom, e)
continue
defer.returnValue(False)
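try_backfill's exception handling above is a triage, slightly compressed in this sketch: a 4xx CodeMessageException aborts the whole loop because the request itself is at fault, while everything else (including SynapseError, NotRetryingDestination, and unexpected exceptions, which the real code logs at different levels) just moves on to the next candidate domain:

    from twisted.internet import defer

    from synapse.api.errors import CodeMessageException

    @defer.inlineCallbacks
    def try_each(domains, attempt, logger):
        for dom in domains:
            try:
                yield attempt(dom)
                defer.returnValue(True)  # first success wins
            except CodeMessageException as e:
                if 400 <= e.code < 500:
                    raise  # our request is malformed; other domains won't help
                logger.info("Failed to backfill from %s because %s", dom, e)
            except Exception as e:
                logger.exception("Failed to backfill from %s because %s", dom, e)
        defer.returnValue(False)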
@@ -1025,10 +997,11 @@ class FederationHandler(BaseHandler):
resolve = logcontext.preserve_fn(
self.state_handler.resolve_state_groups_for_events
)
- states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [resolve(room_id, [e]) for e in event_ids],
- consumeErrors=True,
- ))
+ states = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
+ )
+ )
# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
@@ -1036,23 +1009,23 @@ class FederationHandler(BaseHandler):
state_map = yield self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
- get_prev_content=False
+ get_prev_content=False,
)
states = {
key: {
k: state_map[e_id]
for k, e_id in iteritems(state_dict)
if e_id in state_map
- } for key, state_dict in iteritems(states)
+ }
+ for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill([
- dom for dom, _ in likely_domains
- if dom not in tried_domains
- ])
+ success = yield try_backfill(
+ [dom for dom, _ in likely_domains if dom not in tried_domains]
+ )
if success:
defer.returnValue(True)
@@ -1077,20 +1050,20 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster
"""
if len(ev.prev_event_ids()) > 20:
- logger.warn("Rejecting event %s which has %i prev_events",
- ev.event_id, len(ev.prev_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many prev_events",
+ logger.warn(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
- logger.warn("Rejecting event %s which has %i auth_events",
- ev.event_id, len(ev.auth_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many auth_events",
+ logger.warn(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
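_sanity_check_event, reformatted above, enforces two hard caps before any heavier processing: an event may cite at most 20 prev_events and at most 10 auth_events. The same checks as a standalone sketch (limits copied from the code; SynapseError swapped for ValueError so the snippet has no dependencies):

    MAX_PREV_EVENTS = 20
    MAX_AUTH_EVENTS = 10

    def sanity_check_event(ev):
        n_prev = len(ev.prev_event_ids())
        if n_prev > MAX_PREV_EVENTS:
            raise ValueError("Too many prev_events: %i" % n_prev)
        n_auth = len(ev.auth_event_ids())
        if n_auth > MAX_AUTH_EVENTS:
            raise ValueError("Too many auth_events: %i" % n_auth)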
@defer.inlineCallbacks
def send_invite(self, target_host, event):
@@ -1102,7 +1075,7 @@ class FederationHandler(BaseHandler):
destination=target_host,
room_id=event.room_id,
event_id=event.event_id,
- pdu=event
+ pdu=event,
)
defer.returnValue(pdu)
@@ -1111,8 +1084,7 @@ class FederationHandler(BaseHandler):
def on_event_auth(self, event_id):
event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ [auth_id for auth_id in event.auth_event_ids()], include_given=True
)
defer.returnValue([e for e in auth])
@@ -1138,15 +1110,13 @@ class FederationHandler(BaseHandler):
joinee,
"join",
content,
- params={
- "ver": KNOWN_ROOM_VERSIONS,
- },
+ params={"ver": KNOWN_ROOM_VERSIONS},
)
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert (room_id not in self.room_queues)
+ assert room_id not in self.room_queues
self.room_queues[room_id] = []
@@ -1163,7 +1133,7 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
ret = yield self.federation_client.send_join(
- target_hosts, event, event_format_version,
+ target_hosts, event, event_format_version
)
origin = ret["origin"]
@@ -1182,17 +1152,13 @@ class FederationHandler(BaseHandler):
try:
yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
+ room_id=room_id, room_creator_user_id="", is_public=False
)
except Exception:
# FIXME
pass
- yield self._persist_auth_tree(
- origin, auth_chain, state, event
- )
+ yield self._persist_auth_tree(origin, auth_chain, state, event)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
@@ -1219,14 +1185,18 @@ class FederationHandler(BaseHandler):
"""
for p, origin in room_queue:
try:
- logger.info("Processing queued PDU %s which was received "
- "while we were joining %s", p.event_id, p.room_id)
+ logger.info(
+ "Processing queued PDU %s which was received "
+ "while we were joining %s",
+ p.event_id,
+ p.room_id,
+ )
with logcontext.nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warn(
- "Error handling queued PDU %s from %s: %s",
- p.event_id, origin, e)
+ "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
+ )
@defer.inlineCallbacks
@log_function
@@ -1247,21 +1217,30 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
try:
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
except AuthError as e:
logger.warn("Failed to create join %r because %s", event, e)
raise e
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Creation of join %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ room_version, event, context, do_sig_check=False
)
defer.returnValue(event)
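The check_event_allowed calls threaded through this file are the other genuine behaviour change in the patch: before a join or leave event is created or accepted, a pluggable third-party rules object is consulted, and a falsy answer becomes a 403 with Codes.FORBIDDEN. A toy rules object of the shape the call sites imply; only the check_event_allowed(event, context) signature is taken from the code above, the rest is illustrative:

    from twisted.internet import defer

    class DenyJoinsToRoom(object):
        """Toy third-party event rules: forbid joins to one blocked room."""

        def __init__(self, blocked_room_id):
            self._blocked = blocked_room_id

        def check_event_allowed(self, event, context):
            allowed = not (
                event.type == "m.room.member"
                and event.content.get("membership") == "join"
                and event.room_id == self._blocked
            )
            # Returning False makes the handler raise 403 Codes.FORBIDDEN.
            return defer.succeed(allowed)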
@@ -1296,9 +1275,16 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
- context = yield self._handle_new_event(
- origin, event
+ context = yield self._handle_new_event(origin, event)
+
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
)
+ if not event_allowed:
+ logger.info("Sending of join %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -1318,10 +1304,7 @@ class FederationHandler(BaseHandler):
state = yield self.store.get_events(list(prev_state_ids.values()))
- defer.returnValue({
- "state": list(state.values()),
- "auth_chain": auth_chain,
- })
+ defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain})
@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
@@ -1342,7 +1325,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id,
+ event.sender, event.state_key, event.room_id
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1354,26 +1337,23 @@ class FederationHandler(BaseHandler):
sender_domain = get_domain_from_id(event.sender)
if sender_domain != origin:
- raise SynapseError(400, "The invite event was not from the server sending it")
+ raise SynapseError(
+ 400, "The invite event was not from the server sending it"
+ )
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
# block any attempts to invite the server notices mxid
if event.state_key == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
event.signatures.update(
compute_event_signature(
- event.get_pdu_json(),
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0]
)
)
@@ -1385,10 +1365,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts,
- room_id,
- user_id,
- "leave"
+ target_hosts, room_id, user_id, "leave"
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1403,10 +1380,7 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
- yield self.federation_client.send_leave(
- target_hosts,
- event
- )
+ yield self.federation_client.send_leave(target_hosts, event)
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
@@ -1414,25 +1388,21 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
- def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
- content={}, params=None):
+ def _make_and_verify_event(
+ self, target_hosts, room_id, user_id, membership, content={}, params=None
+ ):
origin, event, format_ver = yield self.federation_client.make_membership_event(
- target_hosts,
- room_id,
- user_id,
- membership,
- content,
- params=params,
+ target_hosts, room_id, user_id, membership, content, params=params
)
logger.debug("Got response to make_%s: %s", membership, event)
# We should assert some things.
# FIXME: Do this in a nicer way
- assert(event.type == EventTypes.Member)
- assert(event.user_id == user_id)
- assert(event.state_key == user_id)
- assert(event.room_id == room_id)
+ assert event.type == EventTypes.Member
+ assert event.user_id == user_id
+ assert event.state_key == user_id
+ assert event.room_id == room_id
defer.returnValue((origin, event, format_ver))
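The assert rewrites just above (assert(event.type == ...) becoming assert event.type == ...) are cosmetic here, but the unparenthesised form is also the safer habit: once a message is added, the parenthesised spelling silently becomes a tuple, which is always truthy. A two-line illustration:

    x = 1
    assert (x == 2, "x should be 2")  # always passes: a non-empty tuple is truthy
    # assert x == 2, "x should be 2"  # the correct form; would raise here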
@defer.inlineCallbacks
@@ -1451,18 +1421,27 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of leave %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ room_version, event, context, do_sig_check=False
)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
@@ -1484,9 +1463,16 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- yield self._handle_new_event(
- origin, event
+ context = yield self._handle_new_event(origin, event)
+
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
)
+ if not event_allowed:
+ logger.info("Sending of leave %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@@ -1502,18 +1488,14 @@ class FederationHandler(BaseHandler):
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups(
- room_id, [event_id]
- )
+ state_groups = yield self.store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
- results = {
- (e.type, e.state_key): e for e in state
- }
+ results = {(e.type, e.state_key): e for e in state}
if event.is_state():
# Get previous state
@@ -1535,12 +1517,10 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups_ids(
- room_id, [event_id]
- )
+ state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1566,11 +1546,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield self.store.get_backfill_events(
- room_id,
- pdu_list,
- limit
- )
+ events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.store, origin, events)
@@ -1594,22 +1570,15 @@ class FederationHandler(BaseHandler):
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ event_id, allow_none=True, allow_rejected=True
)
if event:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
+ in_room = yield self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(
- self.store, origin, [event],
- )
+ events = yield filter_events_for_server(self.store, origin, [event])
event = events[0]
defer.returnValue(event)
else:
@@ -1619,13 +1588,11 @@ class FederationHandler(BaseHandler):
return self.store.get_min_depth(context)
@defer.inlineCallbacks
- def _handle_new_event(self, origin, event, state=None, auth_events=None,
- backfilled=False):
+ def _handle_new_event(
+ self, origin, event, state=None, auth_events=None, backfilled=False
+ ):
context = yield self._prep_event(
- origin, event,
- state=state,
- auth_events=auth_events,
- backfilled=backfilled,
+ origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
@@ -1638,15 +1605,13 @@ class FederationHandler(BaseHandler):
)
yield self.persist_events_and_notify(
- [(event, context)],
- backfilled=backfilled,
+ [(event, context)], backfilled=backfilled
)
success = True
finally:
if not success:
logcontext.run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ self.store.remove_push_actions_from_staging, event.event_id
)
defer.returnValue(context)
@@ -1674,12 +1639,15 @@ class FederationHandler(BaseHandler):
)
defer.returnValue(res)
- contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(prep, ev_info)
- for ev_info in event_infos
- ], consumeErrors=True,
- ))
+ contexts = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(prep, ev_info)
+ for ev_info in event_infos
+ ],
+ consumeErrors=True,
+ )
+ )
yield self.persist_events_and_notify(
[
@@ -1714,8 +1682,7 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id] = ctx
event_map = {
- e.event_id: e
- for e in itertools.chain(auth_events, state, [event])
+ e.event_id: e for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1730,7 +1697,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(400, "No create event in state")
room_version = create_event.content.get(
- "room_version", RoomVersions.V1.identifier,
+ "room_version", RoomVersions.V1.identifier
)
missing_auth_events = set()
@@ -1741,11 +1708,7 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events:
m_ev = yield self.federation_client.get_pdu(
- [origin],
- e_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
+ [origin], e_id, room_version=room_version, outlier=True, timeout=10000
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@@ -1770,10 +1733,7 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
- logger.warn(
- "Rejecting %s because %s",
- e.event_id, err.msg
- )
+ logger.warn("Rejecting %s because %s", e.event_id, err.msg)
if e == event:
raise
@@ -1783,16 +1743,14 @@ class FederationHandler(BaseHandler):
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ],
+ ]
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state
)
- yield self.persist_events_and_notify(
- [(event, new_event_context)],
- )
+ yield self.persist_events_and_notify([(event, new_event_context)])
@defer.inlineCallbacks
def _prep_event(self, origin, event, state, auth_events, backfilled):
@@ -1808,40 +1766,30 @@ class FederationHandler(BaseHandler):
Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext
"""
- context = yield self.state_handler.compute_event_context(
- event, old_state=state,
- )
+ context = yield self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = yield self.store.get_event(
- event.prev_event_ids()[0],
- allow_none=True,
+ event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
- yield self.do_auth(
- origin, event, context, auth_events=auth_events
- )
+ yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e:
- logger.warn(
- "[%s %s] Rejecting: %s",
- event.room_id, event.event_id, e.msg
- )
+ logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
context.rejected = RejectedReason.AUTH_ERROR
@@ -1872,9 +1820,7 @@ class FederationHandler(BaseHandler):
# "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
if do_soft_fail_check:
- extrem_ids = yield self.store.get_latest_event_ids_in_room(
- event.room_id,
- )
+ extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
@@ -1902,31 +1848,31 @@ class FederationHandler(BaseHandler):
# like bans, especially with state res v2.
state_sets = yield self.store.get_state_groups(
- event.room_id, extrem_ids,
+ event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = yield self.state_handler.resolve_events(
- room_version, state_sets, event,
+ room_version, state_sets, event
)
current_state_ids = {
k: e.event_id for k, e in iteritems(current_state_ids)
}
else:
current_state_ids = yield self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids,
+ event.room_id, latest_event_ids=extrem_ids
)
logger.debug(
"Doing soft-fail check for %s: state %s",
- event.event_id, current_state_ids,
+ event.event_id,
+ current_state_ids,
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
- e for k, e in iteritems(current_state_ids)
- if k in auth_types
+ e for k, e in iteritems(current_state_ids) if k in auth_types
]
current_auth_events = yield self.store.get_events(current_state_ids)
@@ -1937,19 +1883,14 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
- logger.warn(
- "Soft-failing %r because %s",
- event, e,
- )
+ logger.warn("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
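For reference, the soft-fail path above reduces to: take the room's current state, keep only the entries whose (type, state_key) pairs matter for authing this event, fetch those events, and re-run the auth checks against them. A stand-alone sketch of the filtering step (the auth types and event ids below are invented stand-ins, not Synapse data):

auth_types = {
    ("m.room.create", ""),
    ("m.room.power_levels", ""),
    ("m.room.member", "@alice:example.org"),
}

current_state_ids = {
    ("m.room.create", ""): "$create",
    ("m.room.power_levels", ""): "$power_levels",
    ("m.room.topic", ""): "$topic",  # irrelevant for auth, filtered out
}

# The comprehension in the handler keeps only the auth-relevant entries.
auth_event_ids = [e for k, e in current_state_ids.items() if k in auth_types]
print(auth_event_ids)  # ['$create', '$power_levels']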
@defer.inlineCallbacks
- def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
- missing):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ def on_query_auth(
+ self, origin, event_id, room_id, remote_auth_chain, rejects, missing
+ ):
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -1967,28 +1908,23 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ [auth_id for auth_id in event.auth_event_ids()], include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
- ret = yield self.construct_auth_difference(
- local_auth_chain, remote_auth_chain
- )
+ ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@defer.inlineCallbacks
- def on_get_missing_events(self, origin, room_id, earliest_events,
- latest_events, limit):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ def on_get_missing_events(
+ self, origin, room_id, earliest_events, latest_events, limit
+ ):
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2002,7 +1938,7 @@ class FederationHandler(BaseHandler):
)
missing_events = yield filter_events_for_server(
- self.store, origin, missing_events,
+ self.store, origin, missing_events
)
defer.returnValue(missing_events)
@@ -2090,25 +2026,17 @@ class FederationHandler(BaseHandler):
if missing_auth:
# TODO: can we use store.have_seen_events here instead?
- have_events = yield self.store.get_seen_events_with_rejections(
- missing_auth
- )
+ have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
logger.debug("Got events %s from store", have_events)
missing_auth.difference_update(have_events.keys())
else:
have_events = {}
- have_events.update({
- e.event_id: ""
- for e in auth_events.values()
- })
+ have_events.update({e.event_id: "" for e in auth_events.values()})
if missing_auth:
# If we don't have all the auth events, we need to get them.
- logger.info(
- "auth_events contains unknown events: %s",
- missing_auth,
- )
+ logger.info("auth_events contains unknown events: %s", missing_auth)
try:
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
@@ -2134,18 +2062,16 @@ class FederationHandler(BaseHandler):
try:
auth_ids = e.auth_event_ids()
auth = {
- (e.type, e.state_key): e for e in remote_auth_chain
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
logger.debug(
- "do_auth %s missing_auth: %s",
- event.event_id, e.event_id
- )
- yield self._handle_new_event(
- origin, e, auth_events=auth
+ "do_auth %s missing_auth: %s", event.event_id, e.event_id
)
+ yield self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
@@ -2181,35 +2107,36 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(
+ self.store.get_event, d, allow_none=True, allow_rejected=False
+ )
+ for d in different_auth
+ if d in have_events and not have_events[d]
+ ],
+ consumeErrors=True,
+ )
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
+ remote_view.update(
+ {(d.type, d.state_key): d for d in different_events if d}
+ )
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
- event
+ event,
)
logger.info(
"After state res: updating auth_events with new state %s",
{
- (d.type, d.state_key): d.event_id for d in new_state.values()
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
if auth_events.get((d.type, d.state_key)) != d
},
)
@@ -2221,7 +2148,7 @@ class FederationHandler(BaseHandler):
)
yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
+ event, context, auth_events, event_key
)
if not different_auth:
@@ -2255,21 +2182,14 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
- )
+ auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True)
try:
# 2. Get remote difference.
try:
result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
+ origin, event.room_id, event.event_id, local_auth_chain
)
except RequestSendFailed as e:
# The other side isn't around or doesn't implement the
@@ -2294,19 +2214,15 @@ class FederationHandler(BaseHandler):
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
+ if e.event_id in auth_ids or e.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
+ "do_auth %s different_auth: %s", event.event_id, e.event_id
)
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
+ yield self._handle_new_event(origin, ev, auth_events=auth)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
@@ -2321,12 +2237,11 @@ class FederationHandler(BaseHandler):
# TODO.
yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
+ event, context, auth_events, event_key
)
@defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events,
- event_key):
+ def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
@@ -2343,8 +2258,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context.
"""
state_updates = {
- k: a.event_id for k, a in iteritems(auth_events)
- if k != event_key
+ k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)
@@ -2354,9 +2268,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = dict(prev_state_ids)
- prev_state_ids.update({
- k: a.event_id for k, a in iteritems(auth_events)
- })
+ prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
@@ -2505,30 +2417,23 @@ class FederationHandler(BaseHandler):
logger.debug("construct_auth_difference returning")
- defer.returnValue({
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {
- "reason": reason_map[e.event_id],
- "proof": None,
- }
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- })
+ defer.returnValue(
+ {
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {"reason": reason_map[e.event_id], "proof": None}
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ }
+ )
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
- self,
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ self, sender_user_id, target_user_id, room_id, signed
):
- third_party_invite = {
- "signed": signed,
- }
+ third_party_invite = {"signed": signed}
event_dict = {
"type": EventTypes.Member,
@@ -2550,6 +2455,18 @@ class FederationHandler(BaseHandler):
builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info(
+ "Creation of threepid invite %s forbidden by third-party rules",
+ event,
+ )
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
event, context = yield self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2572,9 +2489,7 @@ class FederationHandler(BaseHandler):
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.federation_client.forward_third_party_invite(
- destinations,
- room_id,
- event_dict,
+ destinations, room_id, event_dict
)
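The check_event_allowed calls added above delegate to a pluggable third-party event rules module. A minimal sketch of such a module follows; the class name, config key, and exact callback signature are assumptions for illustration, and a falsy return value is what makes the handler raise 403 M_FORBIDDEN:

class ExampleThirdPartyRules(object):
    """Toy rules module that vetoes third-party (3pid) invites when configured."""

    def __init__(self, config):
        self.block_threepid_invites = config.get("block_threepid_invites", False)

    def check_event_allowed(self, event, context):
        # Falsy return value => the calling handler raises SynapseError(403, ...).
        if (
            event.type == "m.room.member"
            and "third_party_invite" in event.content
            and self.block_threepid_invites
        ):
            return False
        return True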
@defer.inlineCallbacks
@@ -2595,9 +2510,20 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning(
+ "Exchange of threepid invite %s forbidden by third-party rules", event
+ )
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
event, context = yield self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2613,21 +2539,16 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
- # XXX we send the invite here, but send_membership_event also sends it,
- # so we end up making two requests. I think this is redundant.
- returned_invite = yield self.send_invite(origin, event)
- # TODO: Make sure the signatures actually are correct.
- event.signatures.update(returned_invite.signatures)
-
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
- def add_display_name_to_third_party_invite(self, room_version, event_dict,
- event, context):
+ def add_display_name_to_third_party_invite(
+ self, room_version, event_dict, event, context
+ ):
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -2641,8 +2562,7 @@ class FederationHandler(BaseHandler):
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
- "Could not find invite event for third_party_invite: %r",
- event_dict
+ "Could not find invite event for third_party_invite: %r", event_dict
)
# We don't discard here as this is not the appropriate place to do
# auth checks. If we need the invite and don't have it then the
@@ -2651,7 +2571,7 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
EventValidator().validate_new(event)
defer.returnValue((event, context))
@@ -2675,9 +2595,7 @@ class FederationHandler(BaseHandler):
token = signed["token"]
prev_state_ids = yield context.get_prev_state_ids(self.store)
- invite_event_id = prev_state_ids.get(
- (EventTypes.ThirdPartyInvite, token,)
- )
+ invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
@@ -2686,25 +2604,59 @@ class FederationHandler(BaseHandler):
if not invite_event:
raise AuthError(403, "Could not find invite")
+ logger.debug("Checking auth on event %r", event.content)
+
last_exception = None
+ # for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
+ # for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
- public_key = public_key_object["public_key"]
- verify_key = decode_verify_key_bytes(
+ logger.debug(
+ "Attempting to verify sig with key %s from %r "
+ "against pubkey %r",
key_name,
- decode_base64(public_key)
+ server,
+ public_key_object,
)
- verify_signed_json(signed, server, verify_key)
- if "key_validity_url" in public_key_object:
- yield self._check_key_revocation(
- public_key,
- public_key_object["key_validity_url"]
+
+ try:
+ public_key = public_key_object["public_key"]
+ verify_key = decode_verify_key_bytes(
+ key_name, decode_base64(public_key)
+ )
+ verify_signed_json(signed, server, verify_key)
+ logger.debug(
+ "Successfully verified sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ except Exception:
+ logger.info(
+ "Failed to verify sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ raise
+ try:
+ if "key_validity_url" in public_key_object:
+ yield self._check_key_revocation(
+ public_key, public_key_object["key_validity_url"]
+ )
+ except Exception:
+ logger.info(
+ "Failed to query key_validity_url %s",
+ public_key_object["key_validity_url"],
)
+ raise
return
except Exception as e:
last_exception = e
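The inner verification above relies on the signedjson and unpaddedbase64 libraries that Synapse already depends on. A stand-alone sketch of the core check (the helper name and argument names are invented):

from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64

def signature_is_valid(signed, server_name, key_name, public_key_base64):
    # Rebuild the verify key from its unpadded-base64 form, then check the
    # signature that server_name made over the canonical JSON.
    verify_key = decode_verify_key_bytes(key_name, decode_base64(public_key_base64))
    try:
        verify_signed_json(signed, server_name, verify_key)
        return True
    except SignatureVerifyException:
        return False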
@@ -2725,15 +2677,9 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.http_client.get_json(
- url,
- {"public_key": public_key}
- )
+ response = yield self.http_client.get_json(url, {"public_key": public_key})
except Exception:
- raise SynapseError(
- 502,
- "Third party certificate could not be checked"
- )
+ raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
@@ -2754,12 +2700,11 @@ class FederationHandler(BaseHandler):
yield self._send_events_to_master(
store=self.store,
event_and_contexts=event_and_contexts,
- backfilled=backfilled
+ backfilled=backfilled,
)
else:
max_stream_id = yield self.store.persist_events(
- event_and_contexts,
- backfilled=backfilled,
+ event_and_contexts, backfilled=backfilled
)
if not backfilled: # Never notify for backfilled events
@@ -2793,13 +2738,10 @@ class FederationHandler(BaseHandler):
event_stream_id = event.internal_metadata.stream_ordering
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
- return self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _clean_room_for_join(self, room_id):
"""Called to clean up any data in DB for a given room, ready for the
@@ -2818,9 +2760,7 @@ class FederationHandler(BaseHandler):
"""
if self.config.worker_app:
return self._notify_user_membership_change(
- room_id=room_id,
- user_id=user.to_string(),
- change="joined",
+ room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
return user_joined_room(self.distributor, user, room_id)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 02c508acec..7da63bb643 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -30,6 +30,7 @@ def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
+
def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
@@ -49,9 +50,7 @@ def _create_rerouter(func_name):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
- if e.code == 403:
- raise e.to_synapse_error()
- return failure
+ raise e.to_synapse_error()
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
@@ -60,6 +59,7 @@ def _create_rerouter(func_name):
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
+
return f
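The rerouter is then used to stamp out one method per group API call on the handler class. A self-contained demo of the pattern; every class and attribute here is a toy stand-in for GroupsLocalHandler's real collaborators:

def _reroute(func_name):
    # Dispatch to the local implementation when the group id is ours,
    # otherwise to the remote (federation) client.
    def f(self, group_id, *args, **kwargs):
        target = self.local if self.is_mine_id(group_id) else self.remote
        return getattr(target, func_name)(group_id, *args, **kwargs)

    return f

class LocalGroups(object):
    def get_group_profile(self, group_id):
        return {"name": "local %s" % group_id}

class RemoteGroups(object):
    def get_group_profile(self, group_id):
        return {"name": "federated %s" % group_id}

class Handler(object):
    local = LocalGroups()
    remote = RemoteGroups()

    def is_mine_id(self, group_id):
        return group_id.endswith(":example.org")

    get_group_profile = _reroute("get_group_profile")

print(Handler().get_group_profile("+team:example.org"))     # routed locally
print(Handler().get_group_profile("+team:elsewhere.test"))  # routed remotely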
@@ -127,7 +127,7 @@ class GroupsLocalHandler(object):
)
else:
res = yield self.transport_client.get_group_summary(
- get_domain_from_id(group_id), group_id, requester_user_id,
+ get_domain_from_id(group_id), group_id, requester_user_id
)
group_server_name = get_domain_from_id(group_id)
@@ -184,7 +184,7 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -197,16 +197,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue(res)
@@ -223,7 +222,7 @@ class GroupsLocalHandler(object):
group_server_name = get_domain_from_id(group_id)
res = yield self.transport_client.get_users_in_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
+ get_domain_from_id(group_id), group_id, requester_user_id
)
chunk = res["chunk"]
@@ -252,9 +251,7 @@ class GroupsLocalHandler(object):
"""Request to join a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.join_group(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -262,7 +259,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.join_group(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -278,16 +275,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@@ -296,9 +292,7 @@ class GroupsLocalHandler(object):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.accept_invite(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -306,7 +300,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.accept_group_invite(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -322,16 +316,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@@ -339,17 +332,17 @@ class GroupsLocalHandler(object):
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
- content = {
- "requester_user_id": requester_user_id,
- "config": config,
- }
+ content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
res = yield self.transport_client.invite_to_group(
- get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ get_domain_from_id(group_id),
+ group_id,
+ user_id,
+ requester_user_id,
content,
)
@@ -372,13 +365,12 @@ class GroupsLocalHandler(object):
local_profile["avatar_url"] = content["profile"]["avatar_url"]
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
@@ -393,25 +385,25 @@ class GroupsLocalHandler(object):
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
- )
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
+ group_id, user_id, membership="leave"
)
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
- user_id, content,
+ get_domain_from_id(group_id),
+ group_id,
+ requester_user_id,
+ user_id,
+ content,
)
defer.returnValue(res)
@@ -422,12 +414,9 @@ class GroupsLocalHandler(object):
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
- )
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
+ group_id, user_id, membership="leave"
)
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
@@ -447,7 +436,7 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
- get_domain_from_id(user_id), [user_id],
+ get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
@@ -462,9 +451,7 @@ class GroupsLocalHandler(object):
if self.hs.is_mine_id(user_id):
local_users.add(user_id)
else:
- destinations.setdefault(
- get_domain_from_id(user_id), set()
- ).add(user_id)
+ destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
@@ -474,16 +461,14 @@ class GroupsLocalHandler(object):
for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
- destination, list(dest_user_ids),
+ destination, list(dest_user_ids)
)
results.update(r["users"])
except Exception:
failed_results.extend(dest_user_ids)
for uid in local_users:
- results[uid] = yield self.store.get_publicised_groups_for_user(
- uid
- )
+ results[uid] = yield self.store.get_publicised_groups_for_user(uid)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
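The setdefault call in bulk_get_publicised_groups above buckets user ids by server name so that each remote destination receives a single bulk query. In miniature, with toy data:

destinations = {}
for user_id in ["@a:one.example", "@b:two.example", "@c:one.example"]:
    server = user_id.split(":", 1)[1]
    destinations.setdefault(server, set()).add(user_id)

print(destinations)
# {'one.example': {'@a:one.example', '@c:one.example'},
#  'two.example': {'@b:two.example'}}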
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 04caf65793..c82b1933f2 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
-
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
+ if "id_server" in creds:
+ id_server = creds["id_server"]
+ elif "idServer" in creds:
+ id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
+ if "client_secret" in creds:
+ client_secret = creds["client_secret"]
+ elif "clientSecret" in creds:
+ client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
if not self._should_trust_id_server(id_server):
logger.warn(
- '%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server
+ "%s is not a trusted ID server: rejecting 3pid " + "credentials",
+ id_server,
)
defer.returnValue(None)
try:
data = yield self.http_client.get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/3pid/getValidated3pid"
- ),
- {'sid': creds['sid'], 'client_secret': client_secret}
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"),
+ {"sid": creds["sid"], "client_secret": client_secret},
)
except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e)
raise e.to_synapse_error()
- if 'medium' in data:
+ if "medium" in data:
defer.returnValue(data)
defer.returnValue(None)
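The snake_case/camelCase fallbacks above exist to tolerate older identity-server clients. The same lookup can be written once; this helper is invented for illustration, not part of the handler:

def _get_cred(creds, snake, camel):
    # Prefer the spec'd snake_case key, fall back to the legacy camelCase one.
    if snake in creds:
        return creds[snake]
    if camel in creds:
        return creds[camel]
    raise ValueError("No %s in creds" % snake)

id_server = _get_cred({"idServer": "matrix.org"}, "id_server", "idServer")
client_secret = _get_cred({"client_secret": "abc"}, "client_secret", "clientSecret")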
@@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler):
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
+ if "id_server" in creds:
+ id_server = creds["id_server"]
+ elif "idServer" in creds:
+ id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
+ if "client_secret" in creds:
+ client_secret = creds["client_secret"]
+ elif "clientSecret" in creds:
+ client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
try:
data = yield self.http_client.post_urlencoded_get_json(
- "https://%s%s" % (
- id_server, "/_matrix/identity/api/v1/3pid/bind"
- ),
- {
- 'sid': creds['sid'],
- 'client_secret': client_secret,
- 'mxid': mxid,
- }
+ "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
+ {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)
@@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
+ user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
# We don't know where to unbind, so we don't have a choice but to return
@@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server,
+ mxid, threepid, id_server
)
defer.returnValue(changed)
@@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler):
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
- "threepid": {
- "medium": threepid["medium"],
- "address": threepid["address"],
- },
+ "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
- method='POST',
- url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ method="POST",
+ url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"),
content=content,
destination_is=id_server,
)
- headers = {
- b"Authorization": auth_headers,
- }
+ headers = {b"Authorization": auth_headers}
try:
- yield self.http_client.post_json_get_json(
- url,
- content,
- headers,
- )
+ yield self.http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
- if e.code in (400, 404, 501,):
+ if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
else:
@@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestEmailToken(
- self,
- id_server,
- email,
- client_secret,
- send_attempt,
- next_link=None,
+ self, id_server, email, client_secret, send_attempt, next_link=None
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
- 'email': email,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "email": email,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
if next_link:
- params.update({'next_link': next_link})
+ params.update({"next_link": next_link})
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/email/requestToken"
- ),
- params
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
+ params,
)
defer.returnValue(data)
except HttpResponseException as e:
@@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestMsisdnToken(
- self, id_server, country, phone_number,
- client_secret, send_attempt, **kwargs
+ self, id_server, country, phone_number, client_secret, send_attempt, **kwargs
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
- 'country': country,
- 'phone_number': phone_number,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "country": country,
+ "phone_number": phone_number,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
params.update(kwargs)
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/msisdn/requestToken"
- ),
- params
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
+ params,
)
defer.returnValue(data)
except HttpResponseException as e:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index aaee5db0b7..a1fe9d116f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler):
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
- def snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ def snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler):
if result is not None:
return result
- return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ))
+ return self.snapshot_cache.set(
+ now_ms,
+ key,
+ self._snapshot_all_rooms(
+ user_id, pagin_config, as_client_event, include_archived
+ ),
+ )
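The snapshot cache memoizes the pending result keyed by the request parameters, so concurrent /initialSync calls share one computation. The shape of the idea, with a toy dict in place of SnapshotCache (which additionally expires entries by timestamp):

cache = {}

def get_snapshot(key, compute):
    # Return the cached (possibly still-pending) result if present,
    # otherwise start the computation and remember it.
    if key not in cache:
        cache[key] = compute()
    return cache[key]

print(get_snapshot(("@alice:hs", True), lambda: {"rooms": ["!a:hs"]}))
print(get_snapshot(("@alice:hs", True), lambda: {"rooms": ["recomputed?"]}))  # cache hit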
@defer.inlineCallbacks
- def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ def _snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
@@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
- "public" if event.room_id in public_room_ids
- else "private"
+ "public" if event.room_id in public_room_ids else "private"
),
}
@@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler):
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = yield self._event_serializer.serialize_event(
- invite_event, time_now, as_client_event,
+ invite_event, time_now, as_client_event
)
rooms_ret.append(d)
@@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler):
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
- self.state_handler.get_current_state,
- event.room_id,
+ self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
- self.store.get_state_for_events,
- [event.event_id],
+ self.store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
- self.store, user_id, messages
- )
+ messages = yield filter_events_for_client(self.store, user_id, messages)
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
yield self._event_serializer.serialize_events(
- messages, time_now=time_now,
- as_client_event=as_client_event,
+ messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": start_token.to_string(),
@@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler):
d["state"] = yield self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
- as_client_event=as_client_event
+ as_client_event=as_client_event,
)
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append(
+ {"type": "m.tag", "content": {"tags": tags}}
+ )
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append(
+ {"type": account_data_type, "content": content}
+ )
d["account_data"] = account_data_events
except Exception:
@@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
now = self.clock.time_msec()
@@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler):
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id,
+ room_id, user_id
)
is_peeking = member_event_id is None
@@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
- def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
- membership, member_event_id, is_peeking):
- room_state = yield self.store.get_state_for_events(
- [member_event_id],
- )
+ def _room_initial_sync_parted(
+ self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+ ):
+ room_state = yield self.store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(
- member_event_id
- )
+ stream_token = yield self.store.get_stream_token_for_event(member_event_id)
messages, token = yield self.store.get_recent_events_for_room(
- room_id,
- limit=limit,
- end_token=stream_token
+ room_id, limit=limit, end_token=stream_token
)
messages = yield filter_events_for_client(
@@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
- defer.returnValue({
- "membership": membership,
- "room_id": room_id,
- "messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": (yield self._event_serializer.serialize_events(
- room_state.values(), time_now,
- )),
- "presence": [],
- "receipts": [],
- })
+ defer.returnValue(
+ {
+ "membership": membership,
+ "room_id": room_id,
+ "messages": {
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ messages, time_now
+ )
+ ),
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": (
+ yield self._event_serializer.serialize_events(
+ room_state.values(), time_now
+ )
+ ),
+ "presence": [],
+ "receipts": [],
+ }
+ )
@defer.inlineCallbacks
- def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
- membership, is_peeking):
- current_state = yield self.state.get_current_state(
- room_id=room_id,
- )
+ def _room_initial_sync_joined(
+ self, user_id, room_id, pagin_config, membership, is_peeking
+ ):
+ current_state = yield self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = yield self._event_serializer.serialize_events(
- current_state.values(), time_now,
+ current_state.values(), time_now
)
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler):
limit = 10
room_members = [
- m for m in current_state.values()
+ m
+ for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
@@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler):
defer.returnValue([])
states = yield presence_handler.get_states(
- [m.user_id for m in room_members],
- as_event=True,
+ [m.user_id for m in room_members], as_event=True
)
defer.returnValue(states)
@@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_receipts():
receipts = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=now_token.receipt_key,
+ room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
@@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler):
room_id,
limit=limit,
end_token=now_token.room_key,
- )
+ ),
],
consumeErrors=True,
- ).addErrback(unwrapFirstError),
+ ).addErrback(unwrapFirstError)
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking,
+ self.store, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
@@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
+ "chunk": (
+ yield self._event_serializer.serialize_events(messages, time_now)
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
@@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
+ visibility
+ and visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0b02469ceb..683da6bf32 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 - 2018 New Vector Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,9 +34,10 @@ from synapse.api.errors import (
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.validator import EventValidator
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID
+from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background
@@ -59,8 +61,9 @@ class MessageHandler(object):
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
+ def get_room_data(
+ self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
+ ):
""" Get data from a room.
Args:
@@ -75,9 +78,7 @@ class MessageHandler(object):
)
if membership == Membership.JOIN:
- data = yield self.state.get_current_state(
- room_id, event_type, state_key
- )
+ data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
@@ -89,8 +90,12 @@ class MessageHandler(object):
@defer.inlineCallbacks
def get_state_events(
- self, user_id, room_id, state_filter=StateFilter.all(),
- at_token=None, is_guest=False,
+ self,
+ user_id,
+ room_id,
+ state_filter=StateFilter.all(),
+ at_token=None,
+ is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@@ -122,50 +127,48 @@ class MessageHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=at_token.room_key, limit=1,
+ room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
- raise NotFoundError("Can't find event for token %s" % (at_token, ))
+ raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events,
+ self.store, user_id, last_events
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
- [event.event_id], state_filter=state_filter,
+ [event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
- "User %s not allowed to view events in room %s at token %s" % (
- user_id, room_id, at_token,
- )
+ "User %s not allowed to view events in room %s at token %s"
+ % (user_id, room_id, at_token),
)
else:
membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(
- room_id, user_id,
- )
+ yield self.auth.check_in_room_or_world_readable(room_id, user_id)
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
- room_id, state_filter=state_filter,
+ room_id, state_filter=state_filter
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
- [membership_event_id], state_filter=state_filter,
+ [membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
- room_state.values(), now,
+ room_state.values(),
+ now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
@@ -209,13 +212,15 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
- defer.returnValue({
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
+ defer.returnValue(
+ {
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in iteritems(users_with_profile)
}
- for user_id, profile in iteritems(users_with_profile)
- })
+ )
class EventCreationHandler(object):
@@ -248,6 +253,7 @@ class EventCreationHandler(object):
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -259,9 +265,28 @@ class EventCreationHandler(object):
if self._block_events_without_consent_error:
self._consent_uri_builder = ConsentURIBuilder(self.config)
+ if (
+ not self.config.worker_app
+ and self.config.cleanup_extremities_with_dummy_events
+ ):
+ self.clock.looping_call(
+ lambda: run_as_background_process(
+ "send_dummy_events_to_fill_extremities",
+ self._send_dummy_events_to_fill_extremities,
+ ),
+ 5 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
- def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_events_and_hashes=None, require_consent=True):
+ def create_event(
+ self,
+ requester,
+ event_dict,
+ token_id=None,
+ txn_id=None,
+ prev_events_and_hashes=None,
+ require_consent=True,
+ ):
"""
Given a dict from a client, create a new event.
@@ -321,8 +346,7 @@ class EventCreationHandler(object):
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
- "Failed to get profile information for %r: %s",
- target, e
+ "Failed to get profile information for %r: %s", target, e
)
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
@@ -358,16 +382,17 @@ class EventCreationHandler(object):
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
- ("Attempt to send `m.room.aliases` in room %s by user %s but"
- " membership is %s"),
+ (
+ "Attempt to send `m.room.aliases` in room %s by user %s but"
+ " membership is %s"
+ ),
event.room_id,
event.sender,
prev_event.membership if prev_event else None,
)
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
self.validator.validate_new(event)
@@ -434,8 +459,8 @@ class EventCreationHandler(object):
# exempt the system notices user
if (
- self.config.server_notices_mxid is not None and
- user_id == self.config.server_notices_mxid
+ self.config.server_notices_mxid is not None
+ and user_id == self.config.server_notices_mxid
):
return
@@ -448,15 +473,10 @@ class EventCreationHandler(object):
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart,
- )
- msg = self._block_events_without_consent_error % {
- 'consent_uri': consent_uri,
- }
- raise ConsentNotGivenError(
- msg=msg,
- consent_uri=consent_uri,
+ requester.user.localpart
)
+ msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
+ raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
@@ -471,8 +491,7 @@ class EventCreationHandler(object):
"""
if event.type == EventTypes.Member:
raise SynapseError(
- 500,
- "Tried to send member event through non-member codepath"
+ 500, "Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
@@ -484,15 +503,13 @@ class EventCreationHandler(object):
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
- event.event_id, prev_state.event_id,
+ event.event_id,
+ prev_state.event_id,
)
defer.returnValue(prev_state)
yield self.handle_new_client_event(
- requester=requester,
- event=event,
- context=context,
- ratelimit=ratelimit,
+ requester=requester, event=event, context=context, ratelimit=ratelimit
)
@defer.inlineCallbacks
@@ -518,11 +535,7 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def create_and_send_nonmember_event(
- self,
- requester,
- event_dict,
- ratelimit=True,
- txn_id=None
+ self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
Creates an event, then sends it.
@@ -537,32 +550,25 @@ class EventCreationHandler(object):
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
+ requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
- raise SynapseError(
- 403, spam_error, Codes.FORBIDDEN
- )
+ raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
+ requester, event, context, ratelimit=ratelimit
)
defer.returnValue(event)
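The spam-check contract used above: check_event_for_spam returns a falsy value to allow the event, or a string to reject it with that message (any other truthy value falls back to the generic "Spam is not permitted here"). A toy checker and stub event, both invented for illustration:

class NoShoutingSpamChecker(object):
    def check_event_for_spam(self, event):
        body = event.content.get("body", "")
        if body.isupper() and len(body) > 20:
            return "Please don't shout"  # string => becomes the 403 message
        return False  # falsy => event is allowed

class _Event(object):
    # Minimal stand-in for a Synapse event.
    def __init__(self, content):
        self.content = content

checker = NoShoutingSpamChecker()
print(checker.check_event_for_spam(_Event({"body": "HELLO EVERYONE THIS IS LOUD"})))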
@measure_func("create_new_client_event")
@defer.inlineCallbacks
- def create_new_client_event(self, builder, requester=None,
- prev_events_and_hashes=None):
+ def create_new_client_event(
+ self, builder, requester=None, prev_events_and_hashes=None
+ ):
"""Create a new event for a local client
Args:
@@ -582,22 +588,21 @@ class EventCreationHandler(object):
"""
if prev_events_and_hashes is not None:
- assert len(prev_events_and_hashes) <= 10, \
- "Attempting to create an event with %i prev_events" % (
- len(prev_events_and_hashes),
+ assert len(prev_events_and_hashes) <= 10, (
+ "Attempting to create an event with %i prev_events"
+ % (len(prev_events_and_hashes),)
)
else:
- prev_events_and_hashes = \
- yield self.store.get_prev_events_for_room(builder.room_id)
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ builder.room_id
+ )
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
- event = yield builder.build(
- prev_event_ids=[p for p, _ in prev_events],
- )
+ event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@@ -613,29 +618,19 @@ class EventCreationHandler(object):
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
- relates_to, event.type, aggregation_key, event.sender,
+ relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
- logger.debug(
- "Created event %s",
- event.event_id,
- )
+ logger.debug("Created event %s", event.event_id)
- defer.returnValue(
- (event, context,)
- )
+ defer.returnValue((event, context))
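The has_user_annotated_event guard above enforces that a sender may annotate a given event with a given key at most once. The dedup rule in miniature (in Synapse the seen-set lives in the database, not in memory):

seen = set()

def annotate(relates_to, event_type, key, sender):
    dedup_key = (relates_to, event_type, key, sender)
    if dedup_key in seen:
        raise ValueError("Can't send same reaction twice")
    seen.add(dedup_key)

annotate("$evt", "m.reaction", "+1", "@alice:example.org")
annotate("$evt", "m.reaction", "+1", "@bob:example.org")      # ok: different sender
# annotate("$evt", "m.reaction", "+1", "@alice:example.org")  # would raise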
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@@ -651,13 +646,22 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
- if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
- room_version = event.content.get(
- "room_version", RoomVersions.V1.identifier
- )
+ if event.is_state() and (event.type, event.state_key) == (
+ EventTypes.Create,
+ "",
+ ):
+ room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version(event.room_id)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
@@ -672,9 +676,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ yield self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@@ -695,11 +697,7 @@ class EventCreationHandler(object):
return
yield self.persist_and_notify_client_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- extra_users=extra_users,
+ requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
@@ -708,18 +706,12 @@ class EventCreationHandler(object):
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@@ -744,20 +736,16 @@ class EventCreationHandler(object):
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
- "Room alias %s does not point to the room" % (
- room_alias_str,
- )
+ "Room alias %s does not point to the room" % (room_alias_str,),
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
+
def is_inviter_member_event(e):
- return (
- e.type == EventTypes.Member and
- e.sender == event.sender
- )
+ return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids(self.store)
@@ -787,26 +775,21 @@ class EventCreationHandler(object):
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
- invitee.domain,
- event,
+ invitee.domain, event
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
- event.signatures.update(
- returned_invite.signatures
- )
+ event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version(event.room_id)
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
@@ -814,13 +797,10 @@ class EventCreationHandler(object):
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
- allow_none=False
+ allow_none=False,
)
if event.user_id != original_event.user_id:
- raise AuthError(
- 403,
- "You don't have permission to redact events"
- )
+ raise AuthError(403, "You don't have permission to redact events")
# We've already checked.
event.internal_metadata.recheck_redaction = False
@@ -828,24 +808,18 @@ class EventCreationHandler(object):
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
- raise AuthError(
- 403,
- "Changing the room create event is forbidden",
- )
+ raise AuthError(403, "Changing the room create event is forbidden")
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
- yield self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -864,3 +838,54 @@ class EventCreationHandler(object):
yield presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")
+
+ @defer.inlineCallbacks
+ def _send_dummy_events_to_fill_extremities(self):
+ """Background task to send dummy events into rooms that have a large
+ number of extremities
+ """
+
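+ # Pick out a handful of the worst-affected rooms each pass: at least
+ # ten extremities, at most five rooms.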
+ room_ids = yield self.store.get_rooms_with_many_extremities(
+ min_count=10, limit=5
+ )
+
+ for room_id in room_ids:
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
+
+ latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+
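+ # Work out the current membership as evaluated at the extremities the
+ # dummy event is going to reference.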
+ members = yield self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+
+ user_id = None
+ for member in members:
+ if self.hs.is_mine_id(member):
+ user_id = member
+ break
+
+ if not user_id:
+ # We don't have a joined user.
+ # TODO: We should do something here to stop the room from
+ # appearing next time.
+ continue
+
+ requester = create_requester(user_id)
+
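+ # The dummy event references all of the current extremities as its
+ # prev events, so persisting it collapses them into one.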
+ event, context = yield self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
+
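+ # Don't let the federation sender proactively push this event out;
+ # remote servers can fetch it if they ever need it.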
+ event.internal_metadata.proactively_send = False
+
+ yield self.send_nonmember_event(requester, event, context, ratelimit=False)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8f811e24fe..76ee97ddd3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -55,9 +55,7 @@ class PurgeStatus(object):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
- return {
- "status": PurgeStatus.STATUS_TEXT[self.status]
- }
+ return {"status": PurgeStatus.STATUS_TEXT[self.status]}
class PaginationHandler(object):
@@ -79,8 +77,7 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
- def start_purge_history(self, room_id, token,
- delete_local_events=False):
+ def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
Args:
@@ -95,8 +92,7 @@ class PaginationHandler(object):
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
- 400,
- "History purge already in progress for %s" % (room_id, ),
+ 400, "History purge already in progress for %s" % (room_id,)
)
purge_id = random_string(16)
@@ -107,14 +103,12 @@ class PaginationHandler(object):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
- self._purge_history,
- purge_id, room_id, token, delete_local_events,
+ self._purge_history, purge_id, room_id, token, delete_local_events
)
return purge_id
@defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, token,
- delete_local_events):
+ def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@@ -130,16 +124,13 @@ class PaginationHandler(object):
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
- yield self.store.purge_history(
- room_id, token, delete_local_events,
- )
+ yield self.store.purge_history(room_id, token, delete_local_events)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
f = Failure()
logger.error(
- "[purge] failed",
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
@@ -148,6 +139,7 @@ class PaginationHandler(object):
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
+
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
@@ -162,8 +154,14 @@ class PaginationHandler(object):
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks
- def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True, event_filter=None):
+ def get_messages(
+ self,
+ requester,
+ room_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ event_filter=None,
+ ):
"""Get messages in a room.
Args:
@@ -182,9 +180,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_room(
- room_id=room_id
- )
+ yield self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -201,7 +197,7 @@ class PaginationHandler(object):
room_id, user_id
)
- if source_config.direction == 'b':
+ if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
@@ -235,27 +231,24 @@ class PaginationHandler(object):
event_filter=event_filter,
)
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
+ self.store, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
- defer.returnValue({
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- })
+ defer.returnValue(
+ {
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
+ )
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@@ -263,12 +256,11 @@ class PaginationHandler(object):
# FIXME: we also care about invite targets etc.
state_filter = StateFilter.from_types(
- (EventTypes.Member, event.sender)
- for event in events
+ (EventTypes.Member, event.sender) for event in events
)
state_ids = yield self.store.get_state_ids_for_event(
- events[0].event_id, state_filter=state_filter,
+ events[0].event_id, state_filter=state_filter
)
if state_ids:
@@ -280,8 +272,7 @@ class PaginationHandler(object):
chunk = {
"chunk": (
yield self._event_serializer.serialize_events(
- events, time_now,
- as_client_event=as_client_event,
+ events, time_now, as_client_event=as_client_event
)
),
"start": pagin_config.from_token.to_string(),
@@ -291,8 +282,7 @@ class PaginationHandler(object):
if state:
chunk["state"] = (
yield self._event_serializer.serialize_events(
- state, time_now,
- as_client_event=as_client_event,
+ state, time_now, as_client_event=as_client_event
)
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 557fb5f83d..5204073a38 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -50,16 +50,20 @@ logger = logging.getLogger(__name__)
notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
federation_presence_out_counter = Counter(
- "synapse_handler_presence_federation_presence_out", "")
+ "synapse_handler_presence_federation_presence_out", ""
+)
presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
-federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+federation_presence_counter = Counter(
+ "synapse_handler_presence_federation_presence", ""
+)
bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
notify_reason_counter = Counter(
- "synapse_handler_presence_notify_reason", "", ["reason"])
+ "synapse_handler_presence_notify_reason", "", ["reason"]
+)
state_transition_counter = Counter(
"synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -90,7 +94,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
-
def __init__(self, hs):
"""
@@ -110,31 +113,26 @@ class PresenceHandler(object):
federation_registry = hs.get_federation_registry()
- federation_registry.register_edu_handler(
- "m.presence", self.incoming_presence
- )
+ federation_registry.register_edu_handler("m.presence", self.incoming_presence)
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
# non-offline presence from the DB. We should fetch from the DB if
# we can't find a user's presence in here.
- self.user_to_current_state = {
- state.user_id: state
- for state in active_presence
- }
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
LaterGauge(
- "synapse_handlers_presence_user_to_current_state_size", "", [],
- lambda: len(self.user_to_current_state)
+ "synapse_handlers_presence_user_to_current_state_size",
+ "",
+ [],
+ lambda: len(self.user_to_current_state),
)
now = self.clock.time_msec()
for state in active_presence:
self.wheel_timer.insert(
- now=now,
- obj=state.user_id,
- then=state.last_active_ts + IDLE_TIMER,
+ now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
)
self.wheel_timer.insert(
now=now,
@@ -193,27 +191,21 @@ class PresenceHandler(object):
"handle_presence_timeouts", self._handle_timeouts
)
- self.clock.call_later(
- 30,
- self.clock.looping_call,
- run_timeout_handler,
- 5000,
- )
+ self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
- self.clock.call_later(
- 60,
- self.clock.looping_call,
- run_persister,
- 60 * 1000,
- )
+ self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
- LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
- lambda: len(self.wheel_timer))
+ LaterGauge(
+ "synapse_handlers_presence_wheel_timer_size",
+ "",
+ [],
+ lambda: len(self.wheel_timer),
+ )
# Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence:
@@ -241,15 +233,17 @@ class PresenceHandler(object):
logger.info(
"Performing _on_shutdown. Persisting %d unpersisted changes",
- len(self.user_to_current_state)
+ len(self.user_to_current_state),
)
if self.unpersisted_users_changes:
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in self.unpersisted_users_changes
- ])
+ yield self.store.update_presence(
+ [
+ self.user_to_current_state[user_id]
+ for user_id in self.unpersisted_users_changes
+ ]
+ )
logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
@@ -261,13 +255,10 @@ class PresenceHandler(object):
self.unpersisted_users_changes = set()
if unpersisted:
- logger.info(
- "Persisting %d upersisted presence updates", len(unpersisted)
+ logger.info("Persisting %d upersisted presence updates", len(unpersisted))
+ yield self.store.update_presence(
+ [self.user_to_current_state[user_id] for user_id in unpersisted]
)
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in unpersisted
- ])
@defer.inlineCallbacks
def _update_states(self, new_states):
@@ -303,10 +294,11 @@ class PresenceHandler(object):
)
new_state, should_notify, should_ping = handle_update(
- prev_state, new_state,
+ prev_state,
+ new_state,
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
- now=now
+ now=now,
)
self.user_to_current_state[user_id] = new_state
@@ -328,7 +320,8 @@ class PresenceHandler(object):
self.unpersisted_users_changes -= set(to_notify.keys())
to_federation_ping = {
- user_id: state for user_id, state in to_federation_ping.items()
+ user_id: state
+ for user_id, state in to_federation_ping.items()
if user_id not in to_notify
}
if to_federation_ping:
@@ -351,8 +344,8 @@ class PresenceHandler(object):
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
- process_id for process_id, last_update
- in self.external_process_last_updated_ms.items()
+ process_id
+ for process_id, last_update in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
@@ -362,9 +355,7 @@ class PresenceHandler(object):
self.external_process_last_updated_ms.pop(process_id)
states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
+ self.user_to_current_state.get(user_id, UserPresenceState.default(user_id))
for user_id in users_to_check
]
@@ -394,9 +385,7 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
- new_fields = {
- "last_active_ts": self.clock.time_msec(),
- }
+ new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE
@@ -430,15 +419,23 @@ class PresenceHandler(object):
if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise
# just update the last sync times.
- yield self._update_states([prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=self.clock.time_msec(),
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=self.clock.time_msec(),
+ last_user_sync_ts=self.clock.time_msec(),
+ )
+ ]
+ )
else:
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
@defer.inlineCallbacks
def _end():
@@ -446,9 +443,13 @@ class PresenceHandler(object):
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
except Exception:
logger.exception("Error updating presence after sync")
@@ -469,7 +470,8 @@ class PresenceHandler(object):
"""
if self.hs.config.use_presence:
syncing_user_ids = {
- user_id for user_id, count in self.user_to_num_current_syncs.items()
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
@@ -479,7 +481,9 @@ class PresenceHandler(object):
return set()
@defer.inlineCallbacks
- def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
+ def update_external_syncs_row(
+ self, process_id, user_id, is_syncing, sync_time_msec
+ ):
"""Update the syncing users for an external process as a delta.
Args:
@@ -500,20 +504,22 @@ class PresenceHandler(object):
updates = []
if is_syncing and user_id not in process_presence:
if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=sync_time_msec,
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=sync_time_msec,
+ last_user_sync_ts=sync_time_msec,
+ )
+ )
else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
process_presence.add(user_id)
elif user_id in process_presence:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
if not is_syncing:
process_presence.discard(user_id)
@@ -537,12 +543,12 @@ class PresenceHandler(object):
prev_states = yield self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec()
- yield self._update_states([
- prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- )
- for prev_state in itervalues(prev_states)
- ])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
+ for prev_state in itervalues(prev_states)
+ ]
+ )
self.external_process_last_updated_ms.pop(process_id, None)
@defer.inlineCallbacks
@@ -574,8 +580,7 @@ class PresenceHandler(object):
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
- user_id: UserPresenceState.default(user_id)
- for user_id in missing
+ user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
@@ -593,8 +598,10 @@ class PresenceHandler(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
self._push_to_remotes(states)
@@ -605,8 +612,10 @@ class PresenceHandler(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
def _push_to_remotes(self, states):
@@ -631,15 +640,15 @@ class PresenceHandler(object):
user_id = push.get("user_id", None)
if not user_id:
logger.info(
- "Got presence update from %r with no 'user_id': %r",
- origin, push,
+ "Got presence update from %r with no 'user_id': %r", origin, push
)
continue
if get_domain_from_id(user_id) != origin:
logger.info(
"Got presence update from %r with bad 'user_id': %r",
- origin, user_id,
+ origin,
+ user_id,
)
continue
@@ -647,14 +656,12 @@ class PresenceHandler(object):
if not presence_state:
logger.info(
"Got presence update from %r with no 'presence_state': %r",
- origin, push,
+ origin,
+ push,
)
continue
- new_fields = {
- "state": presence_state,
- "last_federation_update_ts": now,
- }
+ new_fields = {"state": presence_state, "last_federation_update_ts": now}
last_active_ago = push.get("last_active_ago", None)
if last_active_ago is not None:
@@ -672,10 +679,7 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def get_state(self, target_user, as_event=False):
- results = yield self.get_states(
- [target_user.to_string()],
- as_event=as_event,
- )
+ results = yield self.get_states([target_user.to_string()], as_event=as_event)
defer.returnValue(results[0])
@@ -699,13 +703,15 @@ class PresenceHandler(object):
now = self.clock.time_msec()
if as_event:
- defer.returnValue([
- {
- "type": "m.presence",
- "content": format_user_presence_state(state, now),
- }
- for state in updates
- ])
+ defer.returnValue(
+ [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(state, now),
+ }
+ for state in updates
+ ]
+ )
else:
defer.returnValue(updates)
@@ -717,7 +723,9 @@ class PresenceHandler(object):
presence = state["presence"]
valid_presence = (
- PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
)
if presence not in valid_presence:
raise SynapseError(400, "Invalid presence state")
@@ -726,9 +734,7 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
- new_fields = {
- "state": presence
- }
+ new_fields = {"state": presence}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
@@ -877,8 +883,7 @@ class PresenceHandler(object):
hosts = set(host for host in hosts if host != self.server_name)
self.federation.send_presence_to_destinations(
- states=[state],
- destinations=hosts,
+ states=[state], destinations=hosts
)
else:
# A remote user has joined the room, so we need to:
@@ -904,7 +909,8 @@ class PresenceHandler(object):
# default state.
now = self.clock.time_msec()
states = [
- state for state in states.values()
+ state
+ for state in states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
@@ -912,8 +918,7 @@ class PresenceHandler(object):
if states:
self.federation.send_presence_to_destinations(
- states=states,
- destinations=[get_domain_from_id(user_id)],
+ states=states, destinations=[get_domain_from_id(user_id)]
)
@@ -937,7 +942,10 @@ def should_notify(old_state, new_state):
notify_reason_counter.labels("current_active_change").inc()
return True
- if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+ if (
+ new_state.last_active_ts - old_state.last_active_ts
+ > LAST_ACTIVE_GRANULARITY
+ ):
# Only notify about last active bumps if we're not currently active
if not new_state.currently_active:
notify_reason_counter.labels("last_active_change_online").inc()
@@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True):
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
"""
- content = {
- "presence": state.state,
- }
+ content = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:
@@ -986,8 +992,15 @@ class PresenceEventSource(object):
@defer.inlineCallbacks
@log_function
- def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
- explicit_room_id=None, **kwargs):
+ def get_new_events(
+ self,
+ user,
+ from_key,
+ room_ids=None,
+ include_offline=True,
+ explicit_room_id=None,
+ **kwargs
+ ):
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1030,7 +1043,7 @@ class PresenceEventSource(object):
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
- users_interested_in, from_key,
+ users_interested_in, from_key
)
else:
user_ids_changed = users_interested_in
@@ -1040,10 +1053,16 @@ class PresenceEventSource(object):
if include_offline:
defer.returnValue((list(updates.values()), max_token))
else:
- defer.returnValue(([
- s for s in itervalues(updates)
- if s.state != PresenceState.OFFLINE
- ], max_token))
+ defer.returnValue(
+ (
+ [
+ s
+ for s in itervalues(updates)
+ if s.state != PresenceState.OFFLINE
+ ],
+ max_token,
+ )
+ )
def get_current_key(self):
return self.store.get_current_presence_token()
@@ -1061,13 +1080,13 @@ class PresenceEventSource(object):
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id, on_invalidate=cache_context.invalidate,
+ user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(
- explicit_room_id, on_invalidate=cache_context.invalidate,
+ explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
@@ -1123,9 +1142,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
if now - state.last_active_ts > IDLE_TIMER:
# Currently online, but last activity was ages ago, so
# automatically mark them as idle
- state = state.copy_and_replace(
- state=PresenceState.UNAVAILABLE,
- )
+ state = state.copy_and_replace(state=PresenceState.UNAVAILABLE)
changed = True
elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# So that we send down a notification that we've
@@ -1145,8 +1162,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
+ state=PresenceState.OFFLINE, status_msg=None
)
changed = True
else:
@@ -1155,10 +1171,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
# The other side seems to have disappeared.
- state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
- )
+ state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None)
changed = True
return state if changed else None
@@ -1193,21 +1206,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
if new_state.state == PresenceState.ONLINE:
# Idle timer
wheel_timer.insert(
- now=now,
- obj=user_id,
- then=new_state.last_active_ts + IDLE_TIMER
+ now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
)
active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY
- new_state = new_state.copy_and_replace(
- currently_active=active,
- )
+ new_state = new_state.copy_and_replace(currently_active=active)
if active:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+ then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
)
if new_state.state != PresenceState.OFFLINE:
@@ -1215,29 +1224,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
last_federate = new_state.last_federation_update_ts
if now - last_federate > FEDERATION_PING_INTERVAL:
# Been a while since we've poked remote servers
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
federation_ping = True
else:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT
+ then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
)
# Check whether the change was something worth notifying about
if should_notify(prev_state, new_state):
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
persist_and_notify = True
return new_state, persist_and_notify, federation_ping
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a5fc6c5dbf..d8462b75ec 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,12 +15,15 @@
import logging
+from six import raise_from
+
from twisted.internet import defer
from synapse.api.errors import (
AuthError,
- CodeMessageException,
Codes,
+ HttpResponseException,
+ RequestSendFailed,
StoreError,
SynapseError,
)
@@ -70,25 +73,20 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
defer.returnValue(result)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ except RequestSendFailed as e:
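+ # We couldn't reach the remote server at all, so report a 502
+ # rather than a generic 500.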
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
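+ # The remote server replied with an error; pass its error code
+ # through to the client.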
+ raise e.to_synapse_error()
@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
@@ -110,10 +108,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
defer.returnValue(profile or {})
@@ -136,16 +131,13 @@ class BaseProfileHandler(BaseHandler):
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "displayname",
- },
+ args={"user_id": target_user.to_string(), "field": "displayname"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
defer.returnValue(result["displayname"])
@@ -167,15 +159,13 @@ class BaseProfileHandler(BaseHandler):
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
- 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
+ 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
- if new_displayname == '':
+ if new_displayname == "":
new_displayname = None
- yield self.store.set_profile_displayname(
- target_user.localpart, new_displayname
- )
+ yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -202,16 +192,13 @@ class BaseProfileHandler(BaseHandler):
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "avatar_url",
- },
+ args={"user_id": target_user.to_string(), "field": "avatar_url"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get avatar_url")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
defer.returnValue(result["avatar_url"])
@@ -227,12 +214,10 @@ class BaseProfileHandler(BaseHandler):
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
- 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
+ 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- yield self.store.set_profile_avatar_url(
- target_user.localpart, new_avatar_url
- )
+ yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -275,9 +260,7 @@ class BaseProfileHandler(BaseHandler):
yield self.ratelimit(requester)
- room_ids = yield self.store.get_rooms_for_user(
- target_user.to_string(),
- )
+ room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
@@ -293,8 +276,7 @@ class BaseProfileHandler(BaseHandler):
)
except Exception as e:
logger.warn(
- "Failed to update join event for room %s - %s",
- room_id, str(e)
+ "Failed to update join event for room %s - %s", room_id, str(e)
)
@defer.inlineCallbacks
@@ -322,11 +304,9 @@ class BaseProfileHandler(BaseHandler):
return
try:
- requester_rooms = yield self.store.get_rooms_for_user(
- requester.to_string()
- )
+ requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
- target_user.to_string(),
+ target_user.to_string()
)
# Check if the room lists have no elements in common.
@@ -350,12 +330,12 @@ class MasterProfileHandler(BaseProfileHandler):
assert hs.config.worker_app is None
self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
def _start_update_remote_profile_cache(self):
return run_as_background_process(
- "Update remote profile", self._update_remote_profile_cache,
+ "Update remote profile", self._update_remote_profile_cache
)
@defer.inlineCallbacks
@@ -369,7 +349,7 @@ class MasterProfileHandler(BaseProfileHandler):
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
- user_id,
+ user_id
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
@@ -379,9 +359,7 @@ class MasterProfileHandler(BaseProfileHandler):
profile = yield self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
except Exception:
@@ -396,6 +374,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
- yield self.store.update_remote_profile_cache(
- user_id, new_name, new_avatar
- )
+ yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 32108568c6..3e4d8c93a4 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -43,7 +43,7 @@ class ReadMarkerHandler(BaseHandler):
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
existing_read_marker = yield self.store.get_account_data_for_room_and_type(
- user_id, room_id, "m.fully_read",
+ user_id, room_id, "m.fully_read"
)
should_update = True
@@ -51,14 +51,11 @@ class ReadMarkerHandler(BaseHandler):
if existing_read_marker:
# Only update if the new marker is ahead in the stream
should_update = yield self.store.is_event_after(
- event_id,
- existing_read_marker['event_id']
+ event_id, existing_read_marker["event_id"]
)
if should_update:
- content = {
- "event_id": event_id
- }
+ content = {"event_id": event_id}
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 274d2946ad..a85dd8cdee 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -88,19 +88,16 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r.room_id for r in receipts]))
- self.notifier.on_new_event(
- "receipt_key", max_batch_id, rooms=affected_room_ids
- )
+ self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids,
+ min_batch_id, max_batch_id, affected_room_ids
)
defer.returnValue(True)
@defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id,
- event_id):
+ def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -109,9 +106,7 @@ class ReceiptsHandler(BaseHandler):
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
- data={
- "ts": int(self.clock.time_msec()),
- },
+ data={"ts": int(self.clock.time_msec())},
)
is_new = yield self._handle_new_receipts([receipt])
@@ -125,8 +120,7 @@ class ReceiptsHandler(BaseHandler):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=to_key,
+ room_id, to_key=to_key
)
if not result:
@@ -148,14 +142,12 @@ class ReceiptEventSource(object):
defer.returnValue(([], to_key))
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
defer.returnValue((events, to_key))
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
@@ -169,9 +161,7 @@ class ReceiptEventSource(object):
room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
defer.returnValue((events, to_key))
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 9a388ea013..e487b90c08 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -47,7 +47,6 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
-
def __init__(self, hs):
"""
@@ -69,44 +68,37 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
- name="_generate_user_id_linearizer",
+ name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
- self._register_device_client = (
- RegisterDeviceReplicationServlet.make_client(hs)
+ self._register_device_client = RegisterDeviceReplicationServlet.make_client(
+ hs
)
- self._post_registration_client = (
- ReplicationPostRegisterActionsServlet.make_client(hs)
+ self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
+ hs
)
else:
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks
- def check_username(self, localpart, guest_access_token=None,
- assigned_user_id=None):
+ def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
- Codes.INVALID_USERNAME
+ Codes.INVALID_USERNAME,
)
if not localpart:
- raise SynapseError(
- 400,
- "User ID cannot be empty",
- Codes.INVALID_USERNAME
- )
+ raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME)
- if localpart[0] == '_':
+ if localpart[0] == "_":
raise SynapseError(
- 400,
- "User ID may not begin with _",
- Codes.INVALID_USERNAME
+ 400, "User ID may not begin with _", Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)
@@ -126,19 +118,15 @@ class RegistrationHandler(BaseHandler):
if len(user_id) > MAX_USERID_LENGTH:
raise SynapseError(
400,
- "User ID may not be longer than %s characters" % (
- MAX_USERID_LENGTH,
- ),
- Codes.INVALID_USERNAME
+ "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,),
+ Codes.INVALID_USERNAME,
)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
- 400,
- "User ID already taken.",
- errcode=Codes.USER_IN_USE,
+ 400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
@@ -203,8 +191,7 @@ class RegistrationHandler(BaseHandler):
try:
int(localpart)
raise RegistrationError(
- 400,
- "Numeric user IDs are reserved for guest users."
+ 400, "Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
@@ -283,9 +270,7 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self._register_email_threepid(
- user_id, threepid_dict, None, False,
- )
+ yield self._register_email_threepid(user_id, threepid_dict, None, False)
defer.returnValue((user_id, token))
@@ -318,8 +303,8 @@ class RegistrationHandler(BaseHandler):
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
logger.warning(
- 'Cannot create room alias %s, '
- 'it does not match server domain',
+ "Cannot create room alias %s, "
+ "it does not match server domain",
r,
)
else:
@@ -332,7 +317,7 @@ class RegistrationHandler(BaseHandler):
fake_requester,
config={
"preset": "public_chat",
- "room_alias_name": room_alias_localpart
+ "room_alias_name": room_alias_localpart,
},
ratelimit=False,
)
@@ -364,8 +349,9 @@ class RegistrationHandler(BaseHandler):
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
raise SynapseError(
- 400, "Invalid user localpart for this application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "Invalid user localpart for this application service.",
+ errcode=Codes.EXCLUSIVE,
)
service_id = service.id if service.is_exclusive_user(user_id) else None
@@ -391,17 +377,15 @@ class RegistrationHandler(BaseHandler):
"""
captcha_response = yield self._validate_captcha(
- ip,
- private_key,
- challenge,
- response
+ ip, private_key, challenge, response
)
if not captcha_response["valid"]:
- logger.info("Invalid captcha entered from %s. Error: %s",
- ip, captcha_response["error_url"])
- raise InvalidCaptchaError(
- error_url=captcha_response["error_url"]
+ logger.info(
+ "Invalid captcha entered from %s. Error: %s",
+ ip,
+ captcha_response["error_url"],
)
+ raise InvalidCaptchaError(error_url=captcha_response["error_url"])
else:
logger.info("Valid captcha entered from %s", ip)
@@ -414,8 +398,11 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating threepidcred sid %s on id server %s",
- c['sid'], c['idServer'])
+ logger.info(
+ "validating threepidcred sid %s on id server %s",
+ c["sid"],
+ c["idServer"],
+ )
try:
threepid = yield self.identity_handler.threepid_from_creds(c)
except Exception:
@@ -424,13 +411,14 @@ class RegistrationHandler(BaseHandler):
if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid")
- logger.info("got threepid with medium '%s' and address '%s'",
- threepid['medium'], threepid['address'])
+ logger.info(
+ "got threepid with medium '%s' and address '%s'",
+ threepid["medium"],
+ threepid["address"],
+ )
- if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
- raise RegistrationError(
- 403, "Third party identifier is not allowed"
- )
+ if not check_3pid_allowed(self.hs, threepid["medium"], threepid["address"]):
+ raise RegistrationError(403, "Third party identifier is not allowed")
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
@@ -449,23 +437,23 @@ class RegistrationHandler(BaseHandler):
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
raise SynapseError(
- 400, "This user ID is reserved.",
- errcode=Codes.EXCLUSIVE
+ 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE
)
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
interested_services = [
- s for s in services
- if s.is_interested_in_user(user_id)
- and s != allowed_appservice
+ s
+ for s in services
+ if s.is_interested_in_user(user_id) and s != allowed_appservice
]
for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError(
- 400, "This user ID is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This user ID is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
@defer.inlineCallbacks
@@ -491,14 +479,13 @@ class RegistrationHandler(BaseHandler):
dict: containing 'valid' (bool) and 'error_url' (str) if invalid.
"""
- response = yield self._submit_captcha(ip_addr, private_key, challenge,
- response)
+ response = yield self._submit_captcha(ip_addr, private_key, challenge, response)
# parse Google's response. Lovely format..
- lines = response.split('\n')
+ lines = response.split("\n")
json = {
- "valid": lines[0] == 'true',
- "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" +
- "error=%s" % lines[1]
+ "valid": lines[0] == "true",
+ "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?"
+ + "error=%s" % lines[1],
}
defer.returnValue(json)
@@ -510,17 +497,16 @@ class RegistrationHandler(BaseHandler):
data = yield self.captcha_client.post_urlencoded_get_raw(
"http://www.recaptcha.net:80/recaptcha/api/verify",
args={
- 'privatekey': private_key,
- 'remoteip': ip_addr,
- 'challenge': challenge,
- 'response': response
- }
+ "privatekey": private_key,
+ "remoteip": ip_addr,
+ "challenge": challenge,
+ "response": response,
+ },
)
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname,
- password_hash=None):
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -565,7 +551,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.profile_handler.set_displayname(
- user, requester, displayname, by_admin=True,
+ user, requester, displayname, by_admin=True
)
defer.returnValue((user_id, token))
@@ -587,15 +573,12 @@ class RegistrationHandler(BaseHandler):
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
- user_info = yield self.auth.get_user_by_access_token(
- access_token
- )
+ user_info = yield self.auth.get_user_by_access_token(access_token)
defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
- generate_token=True,
- make_guest=True
+ generate_token=True, make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
@@ -616,9 +599,9 @@ class RegistrationHandler(BaseHandler):
)
room_id = room_id.to_string()
else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
yield room_member_handler.update_membership(
requester=requester,
@@ -629,10 +612,19 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
- def register_with_store(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False,
- user_type=None, address=None):
+ def register_with_store(
+ self,
+ user_id,
+ token=None,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ address=None,
+ ):
"""Register user in the datastore.
Args:
@@ -661,14 +653,15 @@ class RegistrationHandler(BaseHandler):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
- address, time_now_s=time_now,
+ address,
+ time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
if self.hs.config.worker_app:
@@ -698,8 +691,7 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def register_device(self, user_id, device_id, initial_display_name,
- is_guest=False):
+ def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
"""Register a device for a user and generate an access token.
Args:
@@ -732,14 +724,15 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
- def post_registration_actions(self, user_id, auth_result, access_token,
- bind_email, bind_msisdn):
+ def post_registration_actions(
+ self, user_id, auth_result, access_token, bind_email, bind_msisdn
+ ):
"""A user has completed registration
Args:
@@ -773,20 +766,15 @@ class RegistrationHandler(BaseHandler):
yield self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(
- user_id, threepid, access_token,
- bind_email,
+ user_id, threepid, access_token, bind_email
)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(
- user_id, threepid, bind_msisdn,
- )
+ yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn)
if auth_result and LoginType.TERMS in auth_result:
- yield self._on_user_consented(
- user_id, self.hs.config.user_consent_version,
- )
+ yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
@@ -798,9 +786,7 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
- yield self.store.user_set_consent_version(
- user_id, consent_version,
- )
+ yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
@defer.inlineCallbacks
@@ -824,33 +810,30 @@ class RegistrationHandler(BaseHandler):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
+ reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
- if (self.hs.config.email_enable_notifs and
- self.hs.config.email_notif_for_new_users
- and token):
+ if (
+ self.hs.config.email_enable_notifs
+ and self.hs.config.email_notif_for_new_users
+ and token
+ ):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(
- token
- )
+ user_tuple = yield self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
yield self.pusher_pool.add_pusher(
@@ -867,11 +850,9 @@ class RegistrationHandler(BaseHandler):
if bind_email:
logger.info("bind_email specified: binding")
- logger.debug("Binding emails %s to %s" % (
- threepid, user_id
- ))
+ logger.debug("Binding emails %s to %s" % (threepid, user_id))
yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
+ threepid["threepid_creds"], user_id
)
else:
logger.info("bind_email not specified: not binding email")
@@ -894,7 +875,7 @@ class RegistrationHandler(BaseHandler):
defer.Deferred:
"""
try:
- assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
@@ -903,17 +884,14 @@ class RegistrationHandler(BaseHandler):
raise
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
if bind_msisdn:
logger.info("bind_msisdn specified: binding")
logger.debug("Binding msisdn %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
+ threepid["threepid_creds"], user_id
)
else:
logger.info("bind_msisdn not specified: not binding msisdn")
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4a17911a87..db3f8cb76b 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,6 +32,7 @@ from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -40,6 +41,8 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
+FIVE_MINUTES_IN_MS = 5 * 60 * 1000
+
class RoomCreationHandler(BaseHandler):
@@ -75,6 +78,16 @@ class RoomCreationHandler(BaseHandler):
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+ # If a user tries to update the same room multiple times in quick
+ # succession, only process the first attempt and return its result to
+ # subsequent requests
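+ # (a completed upgrade's response is cached for five minutes)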
+ self._upgrade_response_cache = ResponseCache(
+ hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ )
+ self._server_notices_mxid = hs.config.server_notices_mxid
+
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
@defer.inlineCallbacks
def upgrade_room(self, requester, old_room_id, new_version):
"""Replace a room with a new room with a different version
@@ -91,70 +104,100 @@ class RoomCreationHandler(BaseHandler):
user_id = requester.user.to_string()
- with (yield self._upgrade_linearizer.queue(old_room_id)):
- # start by allocating a new room id
- r = yield self.store.get_room(old_room_id)
- if r is None:
- raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = yield self._generate_room_id(
- creator_id=user_id, is_public=r["is_public"],
- )
+ # Check if this room is already being upgraded by another person
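+ # (keys in this cache are (old_room_id, user_id) tuples)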
+ for key in self._upgrade_response_cache.pending_result_cache:
+ if key[0] == old_room_id and key[1] != user_id:
+ # Two different people are trying to upgrade the same room.
+ # Send the second an error.
+ #
+ # Note that this of course only gets caught if both users are
+ # on the same homeserver.
+ raise SynapseError(
+ 400, "An upgrade for this room is currently in progress"
+ )
- logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+ # Upgrade the room
+ #
+ # If this user has sent multiple upgrade requests for the same room
+ # and one of them is not complete yet, cache the response and
+ # return it to all subsequent requests
+ ret = yield self._upgrade_response_cache.wrap(
+ (old_room_id, user_id),
+ self._upgrade_room,
+ requester,
+ old_room_id,
+ new_version, # args for _upgrade_room
+ )
+ defer.returnValue(ret)
- # we create and auth the tombstone event before properly creating the new
- # room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester, {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- }
- },
- token_id=requester.access_token_id,
- )
- )
- old_room_version = yield self.store.get_room_version(old_room_id)
- yield self.auth.check_from_context(
- old_room_version, tombstone_event, tombstone_context,
- )
+ @defer.inlineCallbacks
+ def _upgrade_room(self, requester, old_room_id, new_version):
+ user_id = requester.user.to_string()
- yield self.clone_existing_room(
+ # start by allocating a new room id
+ r = yield self.store.get_room(old_room_id)
+ if r is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+ new_room_id = yield self._generate_room_id(
+ creator_id=user_id, is_public=r["is_public"]
+ )
+
+ logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+
+ # we create and auth the tombstone event before properly creating the new
+ # room, to check our user has perms in the old room.
+ tombstone_event, tombstone_context = (
+ yield self.event_creation_handler.create_event(
requester,
- old_room_id=old_room_id,
- new_room_id=new_room_id,
- new_room_version=new_version,
- tombstone_event_id=tombstone_event.event_id,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ token_id=requester.access_token_id,
)
+ )
+ old_room_version = yield self.store.get_room_version(old_room_id)
+ yield self.auth.check_from_context(
+ old_room_version, tombstone_event, tombstone_context
+ )
- # now send the tombstone
- yield self.event_creation_handler.send_nonmember_event(
- requester, tombstone_event, tombstone_context,
- )
+ yield self.clone_existing_room(
+ requester,
+ old_room_id=old_room_id,
+ new_room_id=new_room_id,
+ new_room_version=new_version,
+ tombstone_event_id=tombstone_event.event_id,
+ )
- old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+ # now send the tombstone
+ yield self.event_creation_handler.send_nonmember_event(
+ requester, tombstone_event, tombstone_context
+ )
- # update any aliases
- yield self._move_aliases_to_new_room(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ old_room_state = yield tombstone_context.get_current_state_ids(self.store)
- # and finally, shut down the PLs in the old room, and update them in the new
- # room.
- yield self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ # update any aliases
+ yield self._move_aliases_to_new_room(
+ requester, old_room_id, new_room_id, old_room_state
+ )
- defer.returnValue(new_room_id)
+ # and finally, shut down the PLs in the old room, and update them in the new
+ # room.
+ yield self._update_upgraded_room_pls(
+ requester, old_room_id, new_room_id, old_room_state
+ )
+
+ defer.returnValue(new_room_id)
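The `ResponseCache.wrap` call above is what collapses duplicate upgrade requests onto a single computation for five minutes. A minimal sketch of the pattern, using a simplified stand-in rather than Synapse's real `ResponseCache` (which additionally wraps the result in an `ObservableDeferred` and keeps the entry for `timeout_ms` after completion):

```python
from twisted.internet import defer


class SimpleResponseCache(object):
    """Illustrative stand-in only: concurrent requests for the same key
    share one in-flight computation instead of redoing the work."""

    def __init__(self):
        # map from key to the Deferred of the in-flight computation
        self.pending_result_cache = {}

    def wrap(self, key, callback, *args, **kwargs):
        if key in self.pending_result_cache:
            # someone already started this work; observe their result
            return self.pending_result_cache[key]

        d = defer.maybeDeferred(callback, *args, **kwargs)
        self.pending_result_cache[key] = d

        def _remove(result):
            # the real cache keeps the entry around for timeout_ms and
            # hands each waiter an observer Deferred; this sketch just
            # drops the entry as soon as the work completes
            self.pending_result_cache.pop(key, None)
            return result

        d.addBoth(_remove)
        return d
```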
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self, requester, old_room_id, new_room_id, old_room_state
):
"""Send updated power levels in both rooms after an upgrade
@@ -172,7 +215,7 @@ class RoomCreationHandler(BaseHandler):
if old_room_pl_event_id is None:
logger.warning(
"Not supported: upgrading a room with no PL event. Not setting PLs "
- "in old room.",
+ "in old room."
)
return
@@ -193,45 +236,48 @@ class RoomCreationHandler(BaseHandler):
if current < restricted_level:
logger.info(
"Setting level for %s in %s to %i (was %i)",
- v, old_room_id, restricted_level, current,
+ v,
+ old_room_id,
+ restricted_level,
+ current,
)
pl_content[v] = restricted_level
updated = True
else:
- logger.info(
- "Not setting level for %s (already %i)",
- v, current,
- )
+ logger.info("Not setting level for %s (already %i)", v, current)
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": old_room_id,
"sender": requester.user.to_string(),
"content": pl_content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
logger.info("Setting correct PLs in new room")
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": old_room_pl_state.content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
- self, requester, old_room_id, new_room_id, new_room_version,
- tombstone_event_id,
+ self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id
):
"""Populate a new room based on an old room
@@ -253,10 +299,7 @@ class RoomCreationHandler(BaseHandler):
creation_content = {
"room_version": new_room_version,
- "predecessor": {
- "room_id": old_room_id,
- "event_id": tombstone_event_id,
- }
+ "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
}
# Check if old room was non-federatable
@@ -285,7 +328,7 @@ class RoomCreationHandler(BaseHandler):
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types(types_to_copy),
+ old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
@@ -298,11 +341,9 @@ class RoomCreationHandler(BaseHandler):
yield self._send_events_for_new_room(
requester,
new_room_id,
-
# we expect to override all the presets with initial_state, so this is
# somewhat arbitrary.
preset_config=RoomCreationPreset.PRIVATE_CHAT,
-
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
@@ -310,20 +351,22 @@ class RoomCreationHandler(BaseHandler):
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
+ old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
- old_room_member_state_ids.values(),
+ old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
# Only transfer ban events
- if ("membership" in old_event.content and
- old_event.content["membership"] == "ban"):
+ if (
+ "membership" in old_event.content
+ and old_event.content["membership"] == "ban"
+ ):
yield self.room_member_handler.update_membership(
requester,
- UserID.from_string(old_event['state_key']),
+ UserID.from_string(old_event["state_key"]),
new_room_id,
"ban",
ratelimit=False,
@@ -335,7 +378,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _move_aliases_to_new_room(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self, requester, old_room_id, new_room_id, old_room_state
):
directory_handler = self.hs.get_handlers().directory_handler
@@ -366,14 +409,11 @@ class RoomCreationHandler(BaseHandler):
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(
- requester, alias, send_event=False,
+ requester, alias, send_event=False
)
removed_aliases.append(alias_str)
except SynapseError as e:
- logger.warning(
- "Unable to remove alias %s from old room: %s",
- alias, e,
- )
+ logger.warning("Unable to remove alias %s from old room: %s", alias, e)
# if we didn't find any aliases, or couldn't remove any anyway, we can skip the
# rest of this.
@@ -389,30 +429,26 @@ class RoomCreationHandler(BaseHandler):
# as when you remove an alias from the directory normally - it just means that
# the aliases event gets out of sync with the directory
# (cf https://github.com/vector-im/riot-web/issues/2369)
- yield directory_handler.send_room_alias_update_event(
- requester, old_room_id,
- )
+ yield directory_handler.send_room_alias_update_event(requester, old_room_id)
except AuthError as e:
- logger.warning(
- "Failed to send updated alias event on old room: %s", e,
- )
+ logger.warning("Failed to send updated alias event on old room: %s", e)
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
- requester, RoomAlias.from_string(alias),
- new_room_id, servers=(self.hs.hostname, ),
- send_event=False, check_membership=False,
+ requester,
+ RoomAlias.from_string(alias),
+ new_room_id,
+ servers=(self.hs.hostname,),
+ send_event=False,
+ check_membership=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
# I'm not really expecting this to happen, but it could if the spam
# checking module decides it shouldn't, or similar.
- logger.error(
- "Error adding alias %s to new room: %s",
- alias, e,
- )
+ logger.error("Error adding alias %s to new room: %s", alias, e)
try:
if canonical_alias and (canonical_alias in removed_aliases):
@@ -423,24 +459,19 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
- "content": {"alias": canonical_alias, },
+ "content": {"alias": canonical_alias},
},
- ratelimit=False
+ ratelimit=False,
)
- yield directory_handler.send_room_alias_update_event(
- requester, new_room_id,
- )
+ yield directory_handler.send_room_alias_update_event(requester, new_room_id)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
- logger.error(
- "Unable to send updated alias events in new room: %s", e,
- )
+ logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True,
- creator_join_profile=None):
+ def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
""" Creates a new room.
Args:
@@ -470,23 +501,35 @@ class RoomCreationHandler(BaseHandler):
yield self.auth.check_auth_blocking(user_id)
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ # Check whether the third party rules allows/changes the room create
+ # request.
+ yield self.third_party_event_rules.on_create_room(
+ requester, config, is_requester_admin=is_requester_admin
+ )
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
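`on_create_room` gives deployment-specific code a veto over room creation before the spam checker runs, and `do_3pid_invite` below gains a matching `check_threepid_can_be_invited` hook. A hypothetical rules module showing both; only the two method signatures come from the calls in this patch, while the policy logic and config key are invented for illustration:

```python
from twisted.internet import defer


class ExampleThirdPartyRules(object):
    """Hypothetical third-party event rules module."""

    def __init__(self, config):
        # "banned_3pid_domains" is an invented config key
        self.banned_domains = config.get("banned_3pid_domains", [])

    @defer.inlineCallbacks
    def on_create_room(self, requester, config, is_requester_admin):
        yield defer.succeed(None)  # placeholder for a real async check
        # e.g. only admins may create public rooms on this deployment;
        # raising here aborts create_room for the requester
        if not is_requester_admin and config.get("visibility") == "public":
            raise Exception("Only admins may create public rooms")

    @defer.inlineCallbacks
    def check_threepid_can_be_invited(self, medium, address, room_id):
        yield defer.succeed(None)  # placeholder for a real async lookup
        # returning False makes do_3pid_invite respond with a 403
        domain = address.split("@")[-1]
        defer.returnValue(
            not (medium == "email" and domain in self.banned_domains)
        )
```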
if ratelimit:
yield self.ratelimit(requester)
room_version = config.get(
- "room_version",
- self.config.default_room_version.identifier,
+ "room_version", self.config.default_room_version.identifier
)
if not isinstance(room_version, string_types):
- raise SynapseError(
- 400,
- "room_version must be a string",
- Codes.BAD_JSON,
- )
+ raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON)
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
@@ -500,20 +543,11 @@ class RoomCreationHandler(BaseHandler):
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias(
- config["room_alias_name"],
- self.hs.hostname,
- )
- mapping = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
+ mapping = yield self.store.get_association_from_room_alias(room_alias)
if mapping:
- raise SynapseError(
- 400,
- "Room alias already taken",
- Codes.ROOM_IN_USE
- )
+ raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
@@ -524,9 +558,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
- yield self.event_creation_handler.assert_accepted_privacy_policy(
- requester,
- )
+ yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
invite_3pid_list = config.get("invite_3pid", [])
@@ -550,7 +582,7 @@ class RoomCreationHandler(BaseHandler):
"preset",
RoomCreationPreset.PRIVATE_CHAT
if visibility == "private"
- else RoomCreationPreset.PUBLIC_CHAT
+ else RoomCreationPreset.PUBLIC_CHAT,
)
raw_initial_state = config.get("initial_state", [])
@@ -587,7 +619,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"name": name},
},
- ratelimit=False)
+ ratelimit=False,
+ )
if "topic" in config:
topic = config["topic"]
@@ -600,7 +633,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"topic": topic},
},
- ratelimit=False)
+ ratelimit=False,
+ )
for invitee in invite_list:
content = {}
@@ -635,30 +669,25 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- yield directory_handler.send_room_alias_update_event(
- requester, room_id
- )
+ yield directory_handler.send_room_alias_update_event(requester, room_id)
defer.returnValue(result)
@defer.inlineCallbacks
def _send_events_for_new_room(
- self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None,
- creator_join_profile=None,
+ self,
+ creator, # A Requester object.
+ room_id,
+ preset_config,
+ invite_list,
+ initial_state,
+ creation_content,
+ room_alias=None,
+ power_level_content_override=None,
+ creator_join_profile=None,
):
def create(etype, content, **kwargs):
- e = {
- "type": etype,
- "content": content,
- }
+ e = {"type": etype, "content": content}
e.update(event_keys)
e.update(kwargs)
@@ -670,26 +699,17 @@ class RoomCreationHandler(BaseHandler):
event = create(etype, content, **kwargs)
logger.info("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
- creator,
- event,
- ratelimit=False
+ creator, event, ratelimit=False
)
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.user.to_string()
- event_keys = {
- "room_id": room_id,
- "sender": creator_id,
- "state_key": "",
- }
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
- yield send(
- etype=EventTypes.Create,
- content=creation_content,
- )
+ yield send(etype=EventTypes.Create, content=creation_content)
logger.info("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
@@ -703,17 +723,12 @@ class RoomCreationHandler(BaseHandler):
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
- pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+ pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- yield send(
- etype=EventTypes.PowerLevels,
- content=pl_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
- "users": {
- creator_id: 100,
- },
+ "users": {creator_id: 100},
"users_default": 0,
"events": {
EventTypes.Name: 50,
@@ -737,42 +752,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
- yield send(
- etype=EventTypes.PowerLevels,
- content=power_level_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=power_level_content)
- if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
+ if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
- if (EventTypes.JoinRules, '') not in initial_state:
+ if (EventTypes.JoinRules, "") not in initial_state:
yield send(
- etype=EventTypes.JoinRules,
- content={"join_rule": config["join_rules"]},
+ etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
- if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
+ if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
etype=EventTypes.RoomHistoryVisibility,
- content={"history_visibility": config["history_visibility"]}
+ content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
- if (EventTypes.GuestAccess, '') not in initial_state:
+ if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
- etype=EventTypes.GuestAccess,
- content={"guest_access": "can_join"}
+ etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- yield send(
- etype=etype,
- state_key=state_key,
- content=content,
- )
+ yield send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(self, creator_id, is_public):
@@ -782,12 +788,9 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID(
- random_string,
- self.hs.hostname,
- ).to_string()
+ gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode('utf-8')
+ gen_room_id = gen_room_id.decode("utf-8")
yield self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
@@ -821,7 +824,7 @@ class RoomContextHandler(object):
Returns:
dict, or None if the event isn't found
"""
- before_limit = math.floor(limit / 2.)
+ before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = yield self.store.get_users_in_room(room_id)
@@ -829,24 +832,19 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
- self.store,
- user.to_string(),
- events,
- is_peeking=is_peeking
+ self.store, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(event_id, get_prev_content=True,
- allow_none=True)
+ event = yield self.store.get_event(
+ event_id, get_prev_content=True, allow_none=True
+ )
if not event:
defer.returnValue(None)
return
- filtered = yield(filter_evts([event]))
+ filtered = yield (filter_evts([event]))
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
@@ -878,7 +876,7 @@ class RoomContextHandler(object):
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
- [last_event_id], state_filter=state_filter,
+ [last_event_id], state_filter=state_filter
)
results["state"] = list(state[last_event_id].values())
@@ -890,9 +888,7 @@ class RoomContextHandler(object):
"room_key", results["start"]
).to_string()
- results["end"] = token.copy_and_replace(
- "room_key", results["end"]
- ).to_string()
+ results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
defer.returnValue(results)
@@ -903,13 +899,7 @@ class RoomEventSource(object):
@defer.inlineCallbacks
def get_new_events(
- self,
- user,
- from_key,
- limit,
- room_ids,
- is_guest,
- explicit_room_id=None,
+ self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
@@ -920,9 +910,7 @@ class RoomEventSource(object):
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
- app_service = self.store.get_app_service_by_user_id(
- user.to_string()
- )
+ app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
@@ -937,7 +925,7 @@ class RoomEventSource(object):
from_key=from_key,
to_key=to_key,
limit=limit or 10,
- order='ASC',
+ order="ASC",
)
events = list(room_events)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 617d1c9ef8..aae696a7e8 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -46,13 +46,18 @@ class RoomListHandler(BaseHandler):
super(RoomListHandler, self).__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(hs, "room_list")
- self.remote_response_cache = ResponseCache(hs, "remote_room_list",
- timeout_ms=30 * 1000)
+ self.remote_response_cache = ResponseCache(
+ hs, "remote_room_list", timeout_ms=30 * 1000
+ )
- def get_local_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False):
+ def get_local_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ ):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -68,14 +73,14 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
"""
if not self.enable_room_list_search:
- return defer.succeed({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
- limit, since_token, bool(search_filter), network_tuple,
+ limit,
+ since_token,
+ bool(search_filter),
+ network_tuple,
)
if search_filter:
@@ -88,24 +93,33 @@ class RoomListHandler(BaseHandler):
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list(
- limit, since_token, search_filter,
- network_tuple=network_tuple, timeout=timeout,
+ limit,
+ since_token,
+ search_filter,
+ network_tuple=network_tuple,
+ timeout=timeout,
)
key = (limit, since_token, network_tuple)
return self.response_cache.wrap(
key,
self._get_public_room_list,
- limit, since_token,
- network_tuple=network_tuple, from_federation=from_federation,
+ limit,
+ since_token,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
@defer.inlineCallbacks
- def _get_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- timeout=None,):
+ def _get_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ timeout=None,
+ ):
"""Generate a public room list.
Args:
limit (int|None): Maximum amount of rooms to return.
@@ -135,15 +149,14 @@ class RoomListHandler(BaseHandler):
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
- public_room_stream_id, current_public_id,
- network_tuple=network_tuple,
+ public_room_stream_id, current_public_id, network_tuple=network_tuple
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
- public_room_stream_id, network_tuple=network_tuple,
+ public_room_stream_id, network_tuple=network_tuple
)
# We want to return rooms in a particular order: the number of joined
@@ -168,7 +181,7 @@ class RoomListHandler(BaseHandler):
return
joined_users = yield self.state_handler.get_current_users_in_room(
- room_id, latest_event_ids,
+ room_id, latest_event_ids
)
num_joined_users = len(joined_users)
@@ -180,8 +193,9 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
- logger.info("Getting ordering for %i rooms since %s",
- len(room_ids), stream_token)
+ logger.info(
+ "Getting ordering for %i rooms since %s", len(room_ids), stream_token
+ )
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -193,7 +207,8 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return
rooms_to_scan = [
- r for r in sorted_rooms
+ r
+ for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
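The `(-num_joined_users, room_id)` sort key used to build `sorted_rooms` puts larger rooms first without a reverse sort, with the room ID as a deterministic tie-breaker. A tiny worked example:

```python
# Negating the join count sorts big rooms first; room_id breaks ties.
rooms_to_num_joined = {"!a:hs": 10, "!b:hs": 25, "!c:hs": 10}

rooms_to_order_value = {
    room_id: (-num_joined, room_id)
    for room_id, num_joined in rooms_to_num_joined.items()
}

sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]

assert sorted_rooms == ["!b:hs", "!a:hs", "!c:hs"]
```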
@@ -204,13 +219,12 @@ class RoomListHandler(BaseHandler):
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
- rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
+ rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
else:
- rooms_to_scan = rooms_to_scan[:since_token.current_limit]
+ rooms_to_scan = rooms_to_scan[: since_token.current_limit]
rooms_to_scan.reverse()
- logger.info("After sorting and filtering, %i rooms remain",
- len(rooms_to_scan))
+ logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan))
# _append_room_entry_to_chunk will append to chunk but will stop if
# len(chunk) > limit
@@ -237,15 +251,19 @@ class RoomListHandler(BaseHandler):
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
- batch = rooms_to_scan[i:i + step]
+ batch = rooms_to_scan[i : i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter,
+ r,
+ rooms_to_num_joined[r],
+ chunk,
+ limit,
+ search_filter,
from_federation=from_federation,
),
- batch, 5,
+ batch,
+ 5,
)
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
@@ -273,10 +291,7 @@ class RoomListHandler(BaseHandler):
new_limit = sorted_rooms.index(last_room_id)
- results = {
- "chunk": chunk,
- "total_room_count_estimate": total_room_count,
- }
+ results = {"chunk": chunk, "total_room_count_estimate": total_room_count}
if since_token:
results["new_rooms"] = bool(newly_visible)
@@ -313,8 +328,15 @@ class RoomListHandler(BaseHandler):
defer.returnValue(results)
@defer.inlineCallbacks
- def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
- search_filter, from_federation=False):
+ def _append_room_entry_to_chunk(
+ self,
+ room_id,
+ num_joined_users,
+ chunk,
+ limit,
+ search_filter,
+ from_federation=False,
+ ):
"""Generate the entry for a room in the public room list and append it
to the `chunk` if it matches the search filter
@@ -345,8 +367,14 @@ class RoomListHandler(BaseHandler):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def generate_room_entry(self, room_id, num_joined_users, cache_context,
- with_alias=True, allow_private=False):
+ def generate_room_entry(
+ self,
+ room_id,
+ num_joined_users,
+ cache_context,
+ with_alias=True,
+ allow_private=False,
+ ):
"""Returns the entry for a room
Args:
@@ -360,33 +388,31 @@ class RoomListHandler(BaseHandler):
Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
- result = {
- "room_id": room_id,
- "num_joined_members": num_joined_users,
- }
+ result = {"room_id": room_id, "num_joined_members": num_joined_users}
current_state_ids = yield self.store.get_current_state_ids(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
- event_map = yield self.store.get_events([
- event_id for key, event_id in iteritems(current_state_ids)
- if key[0] in (
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.CanonicalAlias,
- EventTypes.RoomHistoryVisibility,
- EventTypes.GuestAccess,
- "m.room.avatar",
- )
- ])
+ event_map = yield self.store.get_events(
+ [
+ event_id
+ for key, event_id in iteritems(current_state_ids)
+ if key[0]
+ in (
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.GuestAccess,
+ "m.room.avatar",
+ )
+ ]
+ )
- current_state = {
- (ev.type, ev.state_key): ev
- for ev in event_map.values()
- }
+ current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()}
# Double check that this is actually a public room.
@@ -446,14 +472,17 @@ class RoomListHandler(BaseHandler):
defer.returnValue(result)
@defer.inlineCallbacks
- def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ def get_remote_public_room_list(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
if not self.enable_room_list_search:
- defer.returnValue({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ defer.returnValue({"chunk": [], "total_room_count_estimate": 0})
if search_filter:
# We currently don't support searching across federation, so we have
@@ -462,52 +491,75 @@ class RoomListHandler(BaseHandler):
since_token = None
res = yield self._get_remote_list_cached(
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
if search_filter:
- res = {"chunk": [
- entry
- for entry in list(res.get("chunk", []))
- if _matches_room_entry(entry, search_filter)
- ]}
+ res = {
+ "chunk": [
+ entry
+ for entry in list(res.get("chunk", []))
+ if _matches_room_entry(entry, search_filter)
+ ]
+ }
defer.returnValue(res)
- def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ def _get_remote_list_cached(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter, include_all_networks=include_all_networks,
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
key = (
- server_name, limit, since_token, include_all_networks,
+ server_name,
+ limit,
+ since_token,
+ include_all_networks,
third_party_instance_id,
)
return self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
-class RoomListNextBatch(namedtuple("RoomListNextBatch", (
- "stream_ordering", # stream_ordering of the first public room list
- "public_room_stream_id", # public room stream id for first public room list
- "current_limit", # The number of previous rooms returned
- "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
-))):
+class RoomListNextBatch(
+ namedtuple(
+ "RoomListNextBatch",
+ (
+ "stream_ordering", # stream_ordering of the first public room list
+ "public_room_stream_id", # public room stream id for first public room list
+ "current_limit", # The number of previous rooms returned
+ "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
+ ),
+ )
+):
KEY_DICT = {
"stream_ordering": "s",
@@ -527,21 +579,19 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", (
decoded = msgpack.loads(decode_base64(token), raw=False)
else:
decoded = msgpack.loads(decode_base64(token))
- return RoomListNextBatch(**{
- cls.REVERSE_KEY_DICT[key]: val
- for key, val in decoded.items()
- })
+ return RoomListNextBatch(
+ **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
+ )
def to_token(self):
- return encode_base64(msgpack.dumps({
- self.KEY_DICT[key]: val
- for key, val in self._asdict().items()
- }))
+ return encode_base64(
+ msgpack.dumps(
+ {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
+ )
+ )
def copy_and_replace(self, **kwds):
- return self._replace(
- **kwds
- )
+ return self._replace(**kwds)
def _matches_room_entry(room_entry, search_filter):
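The `next_batch` token is nothing more than the namedtuple's fields with their names shortened via `KEY_DICT`, packed with msgpack and unpadded-base64 encoded. A round-trip sketch of the scheme; the short keys other than `"s"` are cut off in this hunk and assumed here:

```python
from collections import namedtuple

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

Batch = namedtuple(
    "Batch",
    ["stream_ordering", "public_room_stream_id", "current_limit", "direction_is_forward"],
)

# "s" is shown above; "p"/"n"/"d" are assumed short keys
KEY_DICT = {
    "stream_ordering": "s",
    "public_room_stream_id": "p",
    "current_limit": "n",
    "direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}


def to_token(batch):
    return encode_base64(
        msgpack.dumps({KEY_DICT[k]: v for k, v in batch._asdict().items()})
    )


def from_token(token):
    decoded = msgpack.loads(decode_base64(token), raw=False)
    return Batch(**{REVERSE_KEY_DICT[k]: v for k, v in decoded.items()})


batch = Batch(123, 456, 10, True)
assert from_token(to_token(batch)) == batch
```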
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 93ac986c86..4d6e883802 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -72,6 +72,7 @@ class RoomMemberHandler(object):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
@@ -165,7 +166,11 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _local_membership_update(
- self, requester, target, room_id, membership,
+ self,
+ requester,
+ target,
+ room_id,
+ membership,
prev_events_and_hashes,
txn_id=None,
ratelimit=True,
@@ -189,7 +194,6 @@ class RoomMemberHandler(object):
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": user_id,
-
# For backwards compatibility:
"membership": membership,
},
@@ -201,26 +205,19 @@ class RoomMemberHandler(object):
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target], ratelimit=ratelimit
)
prev_state_ids = yield context.get_prev_state_ids(self.store)
- prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, user_id),
- None
- )
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
@@ -242,11 +239,11 @@ class RoomMemberHandler(object):
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], room_id, user_id,
+ predecessor["room_id"], room_id, user_id
)
# Move over old push rules
self.store.move_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], room_id, user_id,
+ predecessor["room_id"], room_id, user_id
)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -257,12 +254,7 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
- def copy_room_tags_and_direct_to_room(
- self,
- old_room_id,
- new_room_id,
- user_id,
- ):
+ def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
"""Copies the tags and direct room state from one room to another.
Args:
@@ -274,9 +266,7 @@ class RoomMemberHandler(object):
Deferred[None]
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = yield self.store.get_account_data_for_user(
- user_id,
- )
+ user_account_data, _ = yield self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
@@ -290,34 +280,30 @@ class RoomMemberHandler(object):
# Save back to user's m.direct account data
yield self.store.add_account_data_for_user(
- user_id, "m.direct", direct_rooms,
+ user_id, "m.direct", direct_rooms
)
break
# Copy room tags if applicable
- room_tags = yield self.store.get_tags_for_room(
- user_id, old_room_id,
- )
+ room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- yield self.store.add_tag_to_room(
- user_id, new_room_id, tag, tag_content
- )
+ yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ require_consent=True,
):
key = (room_id,)
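The `key = (room_id,)` above feeds a `Linearizer`, so membership updates for one room queue behind each other while different rooms proceed in parallel. A usage sketch of the pattern, assuming the `Linearizer` from `synapse.util.async_helpers` that the room-upgrade code earlier in this patch also uses:

```python
from twisted.internet import defer

from synapse.util.async_helpers import Linearizer

member_linearizer = Linearizer("member")


@defer.inlineCallbacks
def update_membership_serialized(room_id, do_update):
    key = (room_id,)
    # all membership updates for a given room queue here; updates for
    # other rooms are unaffected
    with (yield member_linearizer.queue(key)):
        result = yield do_update()
    defer.returnValue(result)
```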
@@ -339,17 +325,17 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ require_consent=True,
):
content_specified = bool(content)
if content is None:
@@ -383,7 +369,7 @@ class RoomMemberHandler(object):
if not remote_room_hosts:
remote_room_hosts = []
- if effective_membership_state not in ("leave", "ban",):
+ if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -391,22 +377,19 @@ class RoomMemberHandler(object):
if effective_membership_state == Membership.INVITE:
# block any attempts to invite the server notices mxid
if target.to_string() == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
block_invite = False
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to send invites
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@@ -417,25 +400,19 @@ class RoomMemberHandler(object):
block_invite = True
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id,
+ requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
block_invite = True
if block_invite:
- raise SynapseError(
- 403, "Invites have been disabled on this server",
- )
+ raise SynapseError(403, "Invites have been disabled on this server")
- prev_events_and_hashes = yield self.store.get_prev_events_for_room(
- room_id,
- )
- latest_event_ids = (
- event_id for (event_id, _, _) in prev_events_and_hashes
- )
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
current_state_ids = yield self.state_handler.get_current_state_ids(
- room_id, latest_event_ids=latest_event_ids,
+ room_id, latest_event_ids=latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
@@ -450,13 +427,13 @@ class RoomMemberHandler(object):
403,
"Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_state:
@@ -472,8 +449,8 @@ class RoomMemberHandler(object):
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
- old_membership == Membership.INVITE and
- effective_membership_state == Membership.LEAVE
+ old_membership == Membership.INVITE
+ and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
if is_blocked:
@@ -534,7 +511,7 @@ class RoomMemberHandler(object):
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target,
+ requester, remote_room_hosts, room_id, target
)
defer.returnValue(res)
@@ -553,12 +530,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def send_membership_event(
- self,
- requester,
- event,
- context,
- remote_room_hosts=None,
- ratelimit=True,
+ self, requester, event, context, remote_room_hosts=None, ratelimit=True
):
"""
Change the membership status of a user in a room.
@@ -584,16 +556,15 @@ class RoomMemberHandler(object):
if requester is not None:
sender = UserID.from_string(event.sender)
- assert sender == requester.user, (
- "Sender (%s) must be same as requester (%s)" %
- (sender, requester.user)
- )
+ assert (
+ sender == requester.user
+ ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = synapse.types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if prev_event is not None:
return
@@ -613,16 +584,11 @@ class RoomMemberHandler(object):
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target_user],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, event.state_key),
- None
+ (EventTypes.Member, event.state_key), None
)
if event.membership == Membership.JOIN:
@@ -692,58 +658,45 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
- user_id=user_id,
- room_id=room_id,
+ user_id=user_id, room_id=room_id
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def do_3pid_invite(
- self,
- room_id,
- inviter,
- medium,
- address,
- id_server,
- requester,
- txn_id
+ self, room_id, inviter, medium, address, id_server, requester, txn_id
):
if self.config.block_non_admin_invites:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
- 403, "Invites have been disabled on this server",
- Codes.FORBIDDEN,
+ 403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester)
- invitee = yield self._lookup_3pid(
- id_server, medium, address
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
)
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier cannot be invited to this room",
+ Codes.FORBIDDEN,
+ )
+
+ invitee = yield self._lookup_3pid(id_server, medium, address)
if invitee:
yield self.update_membership(
- requester,
- UserID.from_string(invitee),
- room_id,
- "invite",
- txn_id=txn_id,
+ requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter,
- txn_id=txn_id
+ requester, id_server, medium, address, room_id, inviter, txn_id=txn_id
)
@defer.inlineCallbacks
@@ -761,15 +714,12 @@ class RoomMemberHandler(object):
"""
if not self._enable_lookup:
raise SynapseError(
- 403, "Looking up third-party identifiers is denied from this server",
+ 403, "Looking up third-party identifiers is denied from this server"
)
try:
data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
- {
- "medium": medium,
- "address": address,
- }
+ "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ {"medium": medium, "address": address},
)
if "mxid" in data:
@@ -788,29 +738,25 @@ class RoomMemberHandler(object):
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s" %
- (id_server_scheme, server_hostname, key_name,),
+ "%s%s/_matrix/identity/api/v1/pubkey/%s"
+ % (id_server_scheme, server_hostname, key_name)
)
if "public_key" not in key_data:
- raise AuthError(401, "No public key named %s from %s" %
- (key_name, server_hostname,))
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, server_hostname)
+ )
verify_signed_json(
data,
server_hostname,
- decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
)
return
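The block above fetches each of the identity server's named public keys and checks the response signature with `signedjson`. A self-contained sketch of just the verification step, with placeholder arguments:

```python
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64


def check_id_server_signature(data, server_hostname, key_name, public_key_b64):
    """Raise if `data` is not correctly signed by `server_hostname`.

    key_name is an identifier such as "ed25519:0"; public_key_b64 is
    the unpadded-base64 key fetched from the .../pubkey/ endpoint.
    """
    verify_key = decode_verify_key_bytes(key_name, decode_base64(public_key_b64))
    # checks data["signatures"][server_hostname][key_name] against the
    # canonical-JSON form of the object, raising on mismatch
    verify_signed_json(data, server_hostname, verify_key)
```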
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- user,
- txn_id
+ self, requester, id_server, medium, address, room_id, user, txn_id
):
room_state = yield self.state_handler.get_current_state(room_id)
@@ -858,7 +804,7 @@ class RoomMemberHandler(object):
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url
+ inviter_avatar_url=inviter_avatar_url,
)
)
@@ -869,7 +815,6 @@ class RoomMemberHandler(object):
"content": {
"display_name": display_name,
"public_keys": public_keys,
-
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
@@ -883,19 +828,19 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ inviter_user_id,
+ room_alias,
+ room_avatar_url,
+ room_join_rules,
+ room_name,
+ inviter_display_name,
+ inviter_avatar_url,
):
"""
Asks an identity server for a third party invite.
@@ -927,7 +872,8 @@ class RoomMemberHandler(object):
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
- id_server_scheme, id_server,
+ id_server_scheme,
+ id_server,
)
invite_config = {
@@ -951,14 +897,15 @@ class RoomMemberHandler(object):
inviter_user_id=inviter_user_id,
)
- invite_config.update({
- "guest_access_token": guest_access_token,
- "guest_user_id": guest_user_id,
- })
+ invite_config.update(
+ {
+ "guest_access_token": guest_access_token,
+ "guest_user_id": guest_user_id,
+ }
+ )
data = yield self.simple_http_client.post_urlencoded_get_json(
- is_url,
- invite_config
+ is_url, invite_config
)
# TODO: Check for success
token = data["token"]
@@ -966,9 +913,8 @@ class RoomMemberHandler(object):
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
- "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme, id_server,
- ),
+ "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid"
+ % (id_server_scheme, id_server),
}
else:
fallback_public_key = public_keys[0]
@@ -1037,10 +983,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
+ remote_room_hosts, room_id, user.to_string(), content
)
yield self._user_joined_room(user, room_id)
@@ -1051,9 +994,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- target.to_string(),
+ remote_room_hosts, room_id, target.to_string()
)
defer.returnValue(ret)
except Exception as e:
@@ -1065,9 +1006,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warn("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
+ yield self.store.locally_reject_invite(target.to_string(), room_id)
defer.returnValue({})
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
@@ -1091,18 +1030,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
- Membership.LEAVE, Membership.BAN
+ Membership.LEAVE,
+ Membership.BAN,
]:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
+ raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
if membership:
yield self.store.forget(user_id, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index acc6eb8099..da501f38c0 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -71,18 +71,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
"""Implements RoomMemberHandler._user_joined_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="joined",
+ user_id=target.to_string(), room_id=room_id, change="joined"
)
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="left",
+ user_id=target.to_string(), room_id=room_id, change="left"
)
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bba74d6c9..ddc4430d03 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,6 @@ logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
-
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
@@ -93,7 +92,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
- b = decode_base64(batch).decode('ascii')
+ b = decode_base64(batch).decode("ascii")
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@@ -104,7 +103,9 @@ class SearchHandler(BaseHandler):
logger.info(
"Search batch properties: %r, %r, %r",
- batch_group, batch_group_key, batch_token,
+ batch_group,
+ batch_group_key,
+ batch_token,
)
logger.info("Search content: %s", content)
@@ -116,9 +117,9 @@ class SearchHandler(BaseHandler):
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
- keys = room_cat.get("keys", [
- "content.body", "content.name", "content.topic",
- ])
+ keys = room_cat.get(
+ "keys", ["content.body", "content.name", "content.topic"]
+ )
# Filter to apply to results
filter_dict = room_cat.get("filter", {})
@@ -130,9 +131,7 @@ class SearchHandler(BaseHandler):
include_state = room_cat.get("include_state", False)
# Include context around each event?
- event_context = room_cat.get(
- "event_context", None
- )
+ event_context = room_cat.get("event_context", None)
# Group results together? May allow clients to paginate within a
# group
@@ -140,12 +139,8 @@ class SearchHandler(BaseHandler):
group_keys = [g["key"] for g in group_by]
if event_context is not None:
- before_limit = int(event_context.get(
- "before_limit", 5
- ))
- after_limit = int(event_context.get(
- "after_limit", 5
- ))
+ before_limit = int(event_context.get("before_limit", 5))
+ after_limit = int(event_context.get("after_limit", 5))
# Return the historic display name and avatar for the senders
# of the events?
@@ -159,7 +154,8 @@ class SearchHandler(BaseHandler):
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
- "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+ "Invalid group by keys: %r"
+ % (set(group_keys) - {"room_id", "sender"},),
)
search_filter = Filter(filter_dict)
@@ -190,15 +186,13 @@ class SearchHandler(BaseHandler):
room_ids.intersection_update({batch_group_key})
if not room_ids:
- defer.returnValue({
- "search_categories": {
- "room_events": {
- "results": [],
- "count": 0,
- "highlights": [],
+ defer.returnValue(
+ {
+ "search_categories": {
+ "room_events": {"results": [], "count": 0, "highlights": []}
}
}
- })
+ )
rank_map = {} # event_id -> rank of event
allowed_events = []
@@ -213,9 +207,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(
- room_ids, search_term, keys
- )
+ search_result = yield self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -235,19 +227,17 @@ class SearchHandler(BaseHandler):
)
events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[:search_filter.limit()]
+ allowed_events = events[: search_filter.limit()]
for e in allowed_events:
- rm = room_groups.setdefault(e.room_id, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ rm = room_groups.setdefault(
+ e.room_id, {"results": [], "order": rank_map[e.event_id]}
+ )
rm["results"].append(e.event_id)
- s = sender_group.setdefault(e.sender, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ s = sender_group.setdefault(
+ e.sender, {"results": [], "order": rank_map[e.event_id]}
+ )
s["results"].append(e.event_id)
elif order_by == "recent":
@@ -262,7 +252,10 @@ class SearchHandler(BaseHandler):
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
- room_ids, search_term, keys, search_filter.limit() * 2,
+ room_ids,
+ search_term,
+ keys,
+ search_filter.limit() * 2,
pagination_token=pagination_token,
)
@@ -277,16 +270,14 @@ class SearchHandler(BaseHandler):
rank_map.update({r["event"].event_id: r["rank"] for r in results})
- filtered_events = search_filter.filter([
- r["event"] for r in results
- ])
+ filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
)
room_events.extend(events)
- room_events = room_events[:search_filter.limit()]
+ room_events = room_events[: search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
@@ -295,9 +286,7 @@ class SearchHandler(BaseHandler):
pagination_token = results[-1]["pagination_token"]
for event in room_events:
- group = room_groups.setdefault(event.room_id, {
- "results": [],
- })
+ group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit():
@@ -309,18 +298,23 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- batch_group, batch_group_key, pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ (
+ "%s\n%s\n%s"
+ % (batch_group, batch_group_key, pagination_token)
+ ).encode("ascii")
+ )
else:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- "all", "", pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+ )
for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
- "room_id", room_id, pagination_token
- )).encode('ascii'))
+ group["next_batch"] = encode_base64(
+ ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+ "ascii"
+ )
+ )
allowed_events.extend(room_events)
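The search pagination token built above is three newline-separated fields, ASCII-encoded and unpadded-base64 wrapped; the decode near the top of `search()` reverses it with `split("\n")`. A round-trip sketch:

```python
from unpaddedbase64 import decode_base64, encode_base64


def make_batch_token(batch_group, batch_group_key, pagination_token):
    return encode_base64(
        ("%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)).encode("ascii")
    )


def parse_batch_token(batch):
    # mirrors the decode at the start of search(): three "\n"-separated fields
    return decode_base64(batch).decode("ascii").split("\n")


token = make_batch_token("room_id", "!abc:example.com", "t123-456")
assert parse_batch_token(token) == ["room_id", "!abc:example.com", "t123-456"]
```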
@@ -338,12 +332,13 @@ class SearchHandler(BaseHandler):
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
- event.room_id, event.event_id, before_limit, after_limit,
+ event.room_id, event.event_id, before_limit, after_limit
)
logger.info(
"Context for search returned %d and %d events",
- len(res["events_before"]), len(res["events_after"]),
+ len(res["events_before"]),
+ len(res["events_after"]),
)
res["events_before"] = yield filter_events_for_client(
@@ -403,12 +398,12 @@ class SearchHandler(BaseHandler):
for context in contexts.values():
context["events_before"] = (
yield self._event_serializer.serialize_events(
- context["events_before"], time_now,
+ context["events_before"], time_now
)
)
context["events_after"] = (
yield self._event_serializer.serialize_events(
- context["events_after"], time_now,
+ context["events_after"], time_now
)
)
@@ -426,11 +421,15 @@ class SearchHandler(BaseHandler):
results = []
for e in allowed_events:
- results.append({
- "rank": rank_map[e.event_id],
- "result": (yield self._event_serializer.serialize_event(e, time_now)),
- "context": contexts.get(e.event_id, {}),
- })
+ results.append(
+ {
+ "rank": rank_map[e.event_id],
+ "result": (
+ yield self._event_serializer.serialize_event(e, time_now)
+ ),
+ "context": contexts.get(e.event_id, {}),
+ }
+ )
rooms_cat_res = {
"results": results,
@@ -442,7 +441,7 @@ class SearchHandler(BaseHandler):
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
- state, time_now,
+ state, time_now
)
rooms_cat_res["state"] = s
@@ -456,8 +455,4 @@ class SearchHandler(BaseHandler):
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
- defer.returnValue({
- "search_categories": {
- "room_events": rooms_cat_res
- }
- })
+ defer.returnValue({"search_categories": {"room_events": rooms_cat_res}})
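The search.py hunks above reflow, without changing, the `next_batch` token scheme: three newline-separated fields (batch group, group key, pagination token) encoded as base64. A minimal sketch of the round trip, using the standard library's `base64` in place of the `encode_base64` helper the handler imports; the function names here are illustrative, not Synapse's:

import base64

def encode_next_batch(batch_group, batch_group_key, pagination_token):
    # same "%s\n%s\n%s" layout as in the hunks above; stdlib base64 stands
    # in for the encode_base64 helper used by SearchHandler
    raw = "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
    return base64.b64encode(raw.encode("ascii")).decode("ascii")

def decode_next_batch(next_batch):
    # invert the encoding to recover the three fields
    return base64.b64decode(next_batch).decode("ascii").split("\n")

assert decode_next_batch(encode_next_batch("room_id", "!r:hs", "tok")) == [
    "room_id",
    "!r:hs",
    "tok",
]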
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 7ecdede4dc..5a0995d4fe 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
+
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -47,11 +48,11 @@ class SetPasswordHandler(BaseHandler):
# we want to log out all of the user's other sessions. First delete
# all their other devices.
yield self._device_handler.delete_all_devices_for_user(
- user_id, except_device_id=except_device_id,
+ user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, except_token_id=except_access_token_id,
+ user_id, except_token_id=except_access_token_id
)
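The two calls reflowed above implement the logout-everything-else flow described in the comments: deleting the other devices first also removes their access tokens, and the second call then sweeps up tokens that never had a device (or belong to the current one). A free-standing restatement under those assumptions; `logout_other_sessions` is an invented name, and the two handler methods are the ones shown in the hunk:

from twisted.internet import defer

@defer.inlineCallbacks
def logout_other_sessions(
    device_handler, auth_handler, user_id,
    except_device_id=None, except_access_token_id=None,
):
    # order matters: removing the devices first takes their access
    # tokens with them...
    yield device_handler.delete_all_devices_for_user(
        user_id, except_device_id=except_device_id
    )
    # ...then remove any tokens that were not tied to a device
    yield auth_handler.delete_access_tokens_for_user(
        user_id, except_token_id=except_access_token_id
    )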
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index b268bbcb2c..6b364befd5 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
class StateDeltasHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 7ad16c8566..a0ee8db988 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -156,7 +156,7 @@ class StatsHandler(StateDeltasHandler):
prev_event_content = {}
if prev_event_id is not None:
prev_event = yield self.store.get_event(
- prev_event_id, allow_none=True,
+ prev_event_id, allow_none=True
)
if prev_event:
prev_event_content = prev_event.content
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 62fda0c664..c5188a1f8e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000
LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
-SyncConfig = collections.namedtuple("SyncConfig", [
- "user",
- "filter_collection",
- "is_guest",
- "request_key",
- "device_id",
-])
-
-
-class TimelineBatch(collections.namedtuple("TimelineBatch", [
- "prev_batch",
- "events",
- "limited",
-])):
+SyncConfig = collections.namedtuple(
+ "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"]
+)
+
+
+class TimelineBatch(
+ collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"])
+):
__slots__ = []
def __nonzero__(self):
@@ -85,18 +79,24 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if the room needs to be part of the sync result.
"""
return bool(self.events)
+
__bool__ = __nonzero__ # python3
-class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "ephemeral",
- "account_data",
- "unread_notifications",
- "summary",
-])):
+class JoinedSyncResult(
+ collections.namedtuple(
+ "JoinedSyncResult",
+ [
+ "room_id", # str
+ "timeline", # TimelineBatch
+ "state", # dict[(str, str), FrozenEvent]
+ "ephemeral",
+ "account_data",
+ "unread_notifications",
+ "summary",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
@@ -111,77 +111,93 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+
__bool__ = __nonzero__ # python3
-class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "account_data",
-])):
+class ArchivedSyncResult(
+ collections.namedtuple(
+ "ArchivedSyncResult",
+ [
+ "room_id", # str
+ "timeline", # TimelineBatch
+ "state", # dict[(str, str), FrozenEvent]
+ "account_data",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if the room needs to be part of the sync result.
"""
- return bool(
- self.timeline
- or self.state
- or self.account_data
- )
+ return bool(self.timeline or self.state or self.account_data)
+
__bool__ = __nonzero__ # python3
-class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
- "room_id", # str
- "invite", # FrozenEvent: the invite event
-])):
+class InvitedSyncResult(
+ collections.namedtuple(
+ "InvitedSyncResult",
+ ["room_id", "invite"], # str # FrozenEvent: the invite event
+ )
+):
__slots__ = []
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+
__bool__ = __nonzero__ # python3
-class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
- "join",
- "invite",
- "leave",
-])):
+class GroupsSyncResult(
+ collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"])
+):
__slots__ = []
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
+
__bool__ = __nonzero__ # python3
-class DeviceLists(collections.namedtuple("DeviceLists", [
- "changed", # list of user_ids whose devices may have changed
- "left", # list of user_ids whose devices we no longer track
-])):
+class DeviceLists(
+ collections.namedtuple(
+ "DeviceLists",
+ [
+ "changed", # list of user_ids whose devices may have changed
+ "left", # list of user_ids whose devices we no longer track
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
return bool(self.changed or self.left)
+
__bool__ = __nonzero__ # python3
-class SyncResult(collections.namedtuple("SyncResult", [
- "next_batch", # Token for the next sync
- "presence", # List of presence events for the user.
- "account_data", # List of account_data events for the user.
- "joined", # JoinedSyncResult for each joined room.
- "invited", # InvitedSyncResult for each invited room.
- "archived", # ArchivedSyncResult for each archived room.
- "to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have changed
- "device_one_time_keys_count", # Dict of algorithm to count for one time keys
- # for this device
- "groups",
-])):
+class SyncResult(
+ collections.namedtuple(
+ "SyncResult",
+ [
+ "next_batch", # Token for the next sync
+ "presence", # List of presence events for the user.
+ "account_data", # List of account_data events for the user.
+ "joined", # JoinedSyncResult for each joined room.
+ "invited", # InvitedSyncResult for each invited room.
+ "archived", # ArchivedSyncResult for each archived room.
+ "to_device", # List of direct messages for the device.
+ "device_lists", # List of user_ids whose devices have changed
+ "device_one_time_keys_count", # Dict of algorithm to count for one time keys
+ # for this device
+ "groups",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
@@ -190,20 +206,20 @@ class SyncResult(collections.namedtuple("SyncResult", [
events.
"""
return bool(
- self.presence or
- self.joined or
- self.invited or
- self.archived or
- self.account_data or
- self.to_device or
- self.device_lists or
- self.groups
+ self.presence
+ or self.joined
+ or self.invited
+ or self.archived
+ or self.account_data
+ or self.to_device
+ or self.device_lists
+ or self.groups
)
+
__bool__ = __nonzero__ # python3
class SyncHandler(object):
-
def __init__(self, hs):
self.hs_config = hs.config
self.store = hs.get_datastore()
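Every result class reformatted above follows one pattern: subclass a namedtuple, set `__slots__ = []` so instances stay as small as the bare tuple, and define `__nonzero__` (aliased to `__bool__` for Python 3) so an empty result is falsy and can be dropped from the sync response. A self-contained illustration; `ExampleResult` is invented:

import collections

class ExampleResult(collections.namedtuple("ExampleResult", ["events"])):
    # empty __slots__: no per-instance __dict__ on top of the tuple fields
    __slots__ = []

    def __nonzero__(self):
        # falsy when there is nothing worth sending down /sync
        return bool(self.events)

    __bool__ = __nonzero__  # python3

assert not ExampleResult(events=[])
assert ExampleResult(events=["$event:hs"])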
@@ -217,13 +233,16 @@ class SyncHandler(object):
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
- "lazy_loaded_members_cache", self.clock,
- max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
+ "lazy_loaded_members_cache",
+ self.clock,
+ max_len=0,
+ expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
@defer.inlineCallbacks
- def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
- full_state=False):
+ def wait_for_sync_for_user(
+ self, sync_config, since_token=None, timeout=0, full_state=False
+ ):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -239,13 +258,15 @@ class SyncHandler(object):
res = yield self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
- sync_config, since_token, timeout, full_state,
+ sync_config,
+ since_token,
+ timeout,
+ full_state,
)
defer.returnValue(res)
@defer.inlineCallbacks
- def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
- full_state):
+ def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -261,14 +282,17 @@ class SyncHandler(object):
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result = yield self.current_sync_for_user(
- sync_config, since_token, full_state=full_state,
+ sync_config, since_token, full_state=full_state
)
else:
+
def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events(
- sync_config.user.to_string(), timeout, current_sync_callback,
+ sync_config.user.to_string(),
+ timeout,
+ current_sync_callback,
from_token=since_token,
)
@@ -281,8 +305,7 @@ class SyncHandler(object):
defer.returnValue(result)
- def current_sync_for_user(self, sync_config, since_token=None,
- full_state=False):
+ def current_sync_for_user(self, sync_config, since_token=None, full_state=False):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
@@ -334,8 +357,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
@@ -353,22 +375,30 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
- def _load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None, recents=None, newly_joined_room=False):
+ def _load_filtered_recents(
+ self,
+ room_id,
+ sync_config,
+ now_token,
+ since_token=None,
+ recents=None,
+ newly_joined_room=False,
+ ):
"""
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
- block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
+ block_all_timeline = (
+ sync_config.filter_collection.blocks_all_room_timeline()
+ )
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
@@ -396,11 +426,9 @@ class SyncHandler(object):
recents = []
if not limited or block_all_timeline:
- defer.returnValue(TimelineBatch(
- events=recents,
- prev_batch=now_token,
- limited=False
- ))
+ defer.returnValue(
+ TimelineBatch(events=recents, prev_batch=now_token, limited=False)
+ )
filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10)
@@ -427,9 +455,7 @@ class SyncHandler(object):
)
else:
events, end_key = yield self.store.get_recent_events_for_room(
- room_id,
- limit=load_limit + 1,
- end_token=end_key,
+ room_id, limit=load_limit + 1, end_token=end_key
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
@@ -462,15 +488,15 @@ class SyncHandler(object):
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace(
- "room_key", room_key
- )
+ prev_batch_token = now_token.copy_and_replace("room_key", room_key)
- defer.returnValue(TimelineBatch(
- events=recents,
- prev_batch=prev_batch_token,
- limited=limited or newly_joined_room
- ))
+ defer.returnValue(
+ TimelineBatch(
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room,
+ )
+ )
@defer.inlineCallbacks
def get_state_after_event(self, event, state_filter=StateFilter.all()):
@@ -486,7 +512,7 @@ class SyncHandler(object):
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter,
+ event.event_id, state_filter=state_filter
)
if event.is_state():
state_ids = state_ids.copy()
@@ -511,13 +537,13 @@ class SyncHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=stream_position.room_key, limit=1,
+ room_id, end_token=stream_position.room_key, limit=1
)
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(
- last_event, state_filter=state_filter,
+ last_event, state_filter=state_filter
)
else:
@@ -549,7 +575,7 @@ class SyncHandler(object):
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
- room_id, end_token=now_token.room_key, limit=1,
+ room_id, end_token=now_token.room_key, limit=1
)
if not last_events:
@@ -559,28 +585,25 @@ class SyncHandler(object):
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id,
- state_filter=StateFilter.from_types([
- (EventTypes.Name, ''),
- (EventTypes.CanonicalAlias, ''),
- ]),
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
+ ),
)
# this is heavily cached, thus: fast.
details = yield self.store.get_room_summary(room_id)
- name_id = state_ids.get((EventTypes.Name, ''))
- canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
+ name_id = state_ids.get((EventTypes.Name, ""))
+ canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
summary = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
- summary["m.joined_member_count"] = (
- details.get(Membership.JOIN, empty_ms).count
- )
- summary["m.invited_member_count"] = (
- details.get(Membership.INVITE, empty_ms).count
- )
+ summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count
+ summary["m.invited_member_count"] = details.get(
+ Membership.INVITE, empty_ms
+ ).count
# if the room has a name or canonical_alias set, we can skip
# calculating heroes. Empty strings are falsey, so we check
@@ -592,7 +615,7 @@ class SyncHandler(object):
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
- canonical_alias_id, allow_none=True,
+ canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
defer.returnValue(summary)
@@ -600,26 +623,14 @@ class SyncHandler(object):
me = sync_config.user.to_string()
joined_user_ids = [
- r[0]
- for r in details.get(Membership.JOIN, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
]
invited_user_ids = [
- r[0]
- for r in details.get(Membership.INVITE, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
]
- gone_user_ids = (
- [
- r[0]
- for r in details.get(Membership.LEAVE, empty_ms).members
- if r[0] != me
- ] + [
- r[0]
- for r in details.get(Membership.BAN, empty_ms).members
- if r[0] != me
- ]
- )
+ gone_user_ids = [
+ r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
+ ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
# FIXME: only build up a member_ids list for our heroes
member_ids = {}
@@ -627,20 +638,18 @@ class SyncHandler(object):
Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
- Membership.BAN
+ Membership.BAN,
):
for user_id, event_id in details.get(membership, empty_ms).members:
member_ids[user_id] = event_id
# FIXME: order by stream ordering rather than as returned by SQL
- if (joined_user_ids or invited_user_ids):
- summary['m.heroes'] = sorted(
+ if joined_user_ids or invited_user_ids:
+ summary["m.heroes"] = sorted(
[user_id for user_id in (joined_user_ids + invited_user_ids)]
)[0:5]
else:
- summary['m.heroes'] = sorted(
- [user_id for user_id in gone_user_ids]
- )[0:5]
+ summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
if not sync_config.filter_collection.lazy_load_members():
defer.returnValue(summary)
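The heroes logic reformatted just above selects at most five other users for the room summary, preferring joined and invited members and falling back to those who left or were banned (the requesting user is already filtered out of these lists). A worked instance with invented user IDs:

me = "@me:example.org"
joined_user_ids = ["@carol:example.org", "@alice:example.org"]
invited_user_ids = ["@bob:example.org"]
gone_user_ids = ["@dave:example.org"]

if joined_user_ids or invited_user_ids:
    heroes = sorted(joined_user_ids + invited_user_ids)[0:5]
else:
    heroes = sorted(gone_user_ids)[0:5]

assert heroes == [
    "@alice:example.org",
    "@bob:example.org",
    "@carol:example.org",
]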
@@ -652,8 +661,7 @@ class SyncHandler(object):
# track which members the client should already know about via LL:
# Ones which are already in state...
existing_members = set(
- user_id for (typ, user_id) in state.keys()
- if typ == EventTypes.Member
+ user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
)
# ...or ones which are in the timeline...
@@ -664,10 +672,10 @@ class SyncHandler(object):
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
member_ids[hero_id]
- for hero_id in summary['m.heroes']
+ for hero_id in summary["m.heroes"]
if (
- cache.get(hero_id) != member_ids[hero_id] and
- hero_id not in existing_members
+ cache.get(hero_id) != member_ids[hero_id]
+ and hero_id not in existing_members
)
]
@@ -691,8 +699,9 @@ class SyncHandler(object):
return cache
@defer.inlineCallbacks
- def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
- full_state):
+ def compute_state_delta(
+ self, room_id, batch, sync_config, since_token, now_token, full_state
+ ):
""" Works out the difference in state between the start of the timeline
and the previous sync.
@@ -745,23 +754,23 @@ class SyncHandler(object):
timeline_state = {
(event.type, event.state_key): event.event_id
- for event in batch.events if event.is_state()
+ for event in batch.events
+ if event.is_state()
}
if full_state:
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
+ batch.events[-1].event_id, state_filter=state_filter
)
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
+ batch.events[0].event_id, state_filter=state_filter
)
else:
current_state_ids = yield self.get_state_at(
- room_id, stream_position=now_token,
- state_filter=state_filter,
+ room_id, stream_position=now_token, state_filter=state_filter
)
state_ids = current_state_ids
@@ -775,7 +784,7 @@ class SyncHandler(object):
)
elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
+ batch.events[0].event_id, state_filter=state_filter
)
# for now, we disable LL for gappy syncs - see
@@ -793,12 +802,11 @@ class SyncHandler(object):
state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token,
- state_filter=state_filter,
+ room_id, stream_position=since_token, state_filter=state_filter
)
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
+ batch.events[-1].event_id, state_filter=state_filter
)
state_ids = _calculate_state(
@@ -854,8 +862,7 @@ class SyncHandler(object):
# add any member IDs we are about to send into our LruCache
for t, event_id in itertools.chain(
- state_ids.items(),
- timeline_state.items(),
+ state_ids.items(), timeline_state.items()
):
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
@@ -864,10 +871,14 @@ class SyncHandler(object):
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
- defer.returnValue({
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(list(state.values()))
- })
+ defer.returnValue(
+ {
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(
+ list(state.values())
+ )
+ }
+ )
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@@ -875,7 +886,7 @@ class SyncHandler(object):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
- receipt_type="m.read"
+ receipt_type="m.read",
)
notifs = []
@@ -909,7 +920,9 @@ class SyncHandler(object):
logger.info(
"Calculating sync response for %r between %s and %s",
- sync_config.user, since_token, now_token,
+ sync_config.user,
+ since_token,
+ now_token,
)
user_id = sync_config.user.to_string()
@@ -920,11 +933,12 @@ class SyncHandler(object):
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id,
+ user_id, now_token.room_stream_id
)
sync_result_builder = SyncResultBuilder(
- sync_config, full_state,
+ sync_config,
+ full_state,
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
@@ -941,8 +955,7 @@ class SyncHandler(object):
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
- since_token is None and
- sync_config.filter_collection.blocks_all_presence()
+ since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
@@ -973,22 +986,23 @@ class SyncHandler(object):
room_id = joined_room.room_id
if room_id in newly_joined_rooms:
issue4422_logger.debug(
- "Sync result for newly joined room %s: %r",
- room_id, joined_room,
+ "Sync result for newly joined room %s: %r", room_id, joined_room
)
- defer.returnValue(SyncResult(
- presence=sync_result_builder.presence,
- account_data=sync_result_builder.account_data,
- joined=sync_result_builder.joined,
- invited=sync_result_builder.invited,
- archived=sync_result_builder.archived,
- to_device=sync_result_builder.to_device,
- device_lists=device_lists,
- groups=sync_result_builder.groups,
- device_one_time_keys_count=one_time_key_counts,
- next_batch=sync_result_builder.now_token,
- ))
+ defer.returnValue(
+ SyncResult(
+ presence=sync_result_builder.presence,
+ account_data=sync_result_builder.account_data,
+ joined=sync_result_builder.joined,
+ invited=sync_result_builder.invited,
+ archived=sync_result_builder.archived,
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ groups=sync_result_builder.groups,
+ device_one_time_keys_count=one_time_key_counts,
+ next_batch=sync_result_builder.now_token,
+ )
+ )
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
@@ -999,11 +1013,11 @@ class SyncHandler(object):
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
- user_id, since_token.groups_key, now_token.groups_key,
+ user_id, since_token.groups_key, now_token.groups_key
)
else:
results = yield self.store.get_all_groups_for_user(
- user_id, now_token.groups_key,
+ user_id, now_token.groups_key
)
invited = {}
@@ -1031,17 +1045,19 @@ class SyncHandler(object):
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
- join=joined,
- invite=invited,
- leave=left,
+ join=joined, invite=invited, leave=left
)
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms, newly_left_users):
+ def _generate_sync_entry_for_device_list(
+ self,
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
+ ):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1065,24 +1081,20 @@ class SyncHandler(object):
changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
- defer.returnValue(DeviceLists(
- changed=[],
- left=newly_left_users,
- ))
+ defer.returnValue(DeviceLists(changed=[], left=newly_left_users))
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
- defer.returnValue(DeviceLists(
- changed=users_who_share_room & changed,
- left=set(newly_left_users) - users_who_share_room,
- ))
+ defer.returnValue(
+ DeviceLists(
+ changed=users_who_share_room & changed,
+ left=set(newly_left_users) - users_who_share_room,
+ )
+ )
else:
- defer.returnValue(DeviceLists(
- changed=[],
- left=[],
- ))
+ defer.returnValue(DeviceLists(changed=[], left=[]))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -1109,8 +1121,9 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
- logger.debug("Deleted %d to-device messages up to %d",
- deleted, since_stream_id)
+ logger.debug(
+ "Deleted %d to-device messages up to %d", deleted, since_stream_id
+ )
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
@@ -1118,7 +1131,10 @@ class SyncHandler(object):
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
- len(messages), since_stream_id, stream_id, now_token.to_device_key
+ len(messages),
+ since_stream_id,
+ stream_id,
+ now_token.to_device_key,
)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
@@ -1145,8 +1161,7 @@ class SyncHandler(object):
if since_token and not sync_result_builder.full_state:
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
- user_id,
- since_token.account_data_key,
+ user_id, since_token.account_data_key
)
)
@@ -1160,27 +1175,28 @@ class SyncHandler(object):
)
else:
account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(
- sync_config.user.to_string()
- )
+ yield self.store.get_account_data_for_user(sync_config.user.to_string())
)
- account_data['m.push_rules'] = yield self.push_rules_for_user(
+ account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = sync_config.filter_collection.filter_account_data([
- {"type": account_data_type, "content": content}
- for account_data_type, content in account_data.items()
- ])
+ account_data_for_user = sync_config.filter_collection.filter_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in account_data.items()
+ ]
+ )
sync_result_builder.account_data = account_data_for_user
defer.returnValue(account_data_by_room)
@defer.inlineCallbacks
- def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_or_invited_users):
+ def _generate_sync_entry_for_presence(
+ self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ ):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1223,17 +1239,13 @@ class SyncHandler(object):
extra_users_ids.discard(user.to_string())
if extra_users_ids:
- states = yield self.presence_handler.get_states(
- extra_users_ids,
- )
+ states = yield self.presence_handler.get_states(extra_users_ids)
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
presence = list({p.user_id: p for p in presence}.values())
- presence = sync_config.filter_collection.filter_presence(
- presence
- )
+ presence = sync_config.filter_collection.filter_presence(presence)
sync_result_builder.presence = presence
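The dedup line kept intact above relies on a dict-comprehension idiom: keying on `user_id` means a later presence entry overwrites an earlier one for the same user, and `.values()` preserves first-insertion order on Python 3.7+. An illustration with a stand-in class (names invented):

class PresenceState(object):
    def __init__(self, user_id, state):
        self.user_id = user_id
        self.state = state

presence = [
    PresenceState("@a:hs", "online"),
    PresenceState("@a:hs", "unavailable"),  # newer entry wins
    PresenceState("@b:hs", "online"),
]
presence = list({p.user_id: p for p in presence}.values())
assert [(p.user_id, p.state) for p in presence] == [
    ("@a:hs", "unavailable"),
    ("@b:hs", "online"),
]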
@@ -1253,8 +1265,8 @@ class SyncHandler(object):
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- sync_result_builder.since_token is None and
- sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
+ sync_result_builder.since_token is None
+ and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
@@ -1275,15 +1287,14 @@ class SyncHandler(object):
have_changed = yield self._have_rooms_changed(sync_result_builder)
if not have_changed:
tags_by_room = yield self.store.get_updated_tags(
- user_id,
- since_token.account_data_key,
+ user_id, since_token.account_data_key
)
if not tags_by_room:
logger.debug("no-oping sync")
defer.returnValue(([], [], [], []))
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id,
+ "m.ignored_user_list", user_id=user_id
)
if ignored_account_data:
@@ -1296,7 +1307,7 @@ class SyncHandler(object):
room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
- user_id, since_token.account_data_key,
+ user_id, since_token.account_data_key
)
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
@@ -1331,8 +1342,8 @@ class SyncHandler(object):
for event in it:
if event.type == EventTypes.Member:
if (
- event.membership == Membership.JOIN or
- event.membership == Membership.INVITE
+ event.membership == Membership.JOIN
+ or event.membership == Membership.INVITE
):
newly_joined_or_invited_users.add(event.state_key)
else:
@@ -1343,12 +1354,14 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
- defer.returnValue((
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms,
- newly_left_users,
- ))
+ defer.returnValue(
+ (
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
+ )
+ )
@defer.inlineCallbacks
def _have_rooms_changed(self, sync_result_builder):
@@ -1454,7 +1467,9 @@ class SyncHandler(object):
prev_membership = old_mem_ev.membership
issue4422_logger.debug(
"Previous membership for room %s with join: %s (event %s)",
- room_id, prev_membership, old_mem_ev_id,
+ room_id,
+ prev_membership,
+ old_mem_ev_id,
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
@@ -1476,8 +1491,7 @@ class SyncHandler(object):
if not old_state_ids:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get(
- (EventTypes.Member, user_id),
- None,
+ (EventTypes.Member, user_id), None
)
old_mem_ev = None
if old_mem_ev_id:
@@ -1498,7 +1512,8 @@ class SyncHandler(object):
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
leave_events = [
- e for e in non_joins
+ e
+ for e in non_joins
if e.membership in (Membership.LEAVE, Membership.BAN)
]
@@ -1526,15 +1541,17 @@ class SyncHandler(object):
else:
batch_events = None
- room_entries.append(RoomSyncResultBuilder(
- room_id=room_id,
- rtype="archived",
- events=batch_events,
- newly_joined=room_id in newly_joined_rooms,
- full_state=False,
- since_token=since_token,
- upto_token=leave_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="archived",
+ events=batch_events,
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
timeline_limit = sync_config.filter_collection.timeline_limit()
@@ -1581,7 +1598,8 @@ class SyncHandler(object):
# debugging for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"RoomSyncResultBuilder events for newly joined room %s: %r",
- room_id, entry.events,
+ room_id,
+ entry.events,
)
room_entries.append(entry)
@@ -1606,12 +1624,14 @@ class SyncHandler(object):
sync_config = sync_result_builder.sync_config
membership_list = (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
+ Membership.INVITE,
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user_id,
- membership_list=membership_list
+ user_id=user_id, membership_list=membership_list
)
room_entries = []
@@ -1619,23 +1639,22 @@ class SyncHandler(object):
for event in room_list:
if event.membership == Membership.JOIN:
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="joined",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=now_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="joined",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=now_token,
+ )
+ )
elif event.membership == Membership.INVITE:
if event.sender in ignored_users:
continue
invite = yield self.store.get_event(event.event_id)
- invited.append(InvitedSyncResult(
- room_id=event.room_id,
- invite=invite,
- ))
+ invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
@@ -1646,22 +1665,31 @@ class SyncHandler(object):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="archived",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=leave_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="archived",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
defer.returnValue((room_entries, invited, []))
@defer.inlineCallbacks
- def _generate_room_entry(self, sync_result_builder, ignored_users,
- room_builder, ephemeral, tags, account_data,
- always_include=False):
+ def _generate_room_entry(
+ self,
+ sync_result_builder,
+ ignored_users,
+ room_builder,
+ ephemeral,
+ tags,
+ account_data,
+ always_include=False,
+ ):
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
@@ -1678,9 +1706,7 @@ class SyncHandler(object):
"""
newly_joined = room_builder.newly_joined
full_state = (
- room_builder.full_state
- or newly_joined
- or sync_result_builder.full_state
+ room_builder.full_state or newly_joined or sync_result_builder.full_state
)
events = room_builder.events
@@ -1697,7 +1723,8 @@ class SyncHandler(object):
upto_token = room_builder.upto_token
batch = yield self._load_filtered_recents(
- room_id, sync_config,
+ room_id,
+ sync_config,
now_token=upto_token,
since_token=since_token,
recents=events,
@@ -1708,7 +1735,8 @@ class SyncHandler(object):
# debug for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"Timeline events after filtering in newly-joined room %s: %r",
- room_id, batch,
+ room_id,
+ batch,
)
# When we join the room (or the client requests full_state), we should
@@ -1726,16 +1754,10 @@ class SyncHandler(object):
account_data_events = []
if tags is not None:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
account_data_events = sync_config.filter_collection.filter_room_account_data(
account_data_events
@@ -1743,16 +1765,13 @@ class SyncHandler(object):
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
- if not (always_include
- or batch
- or account_data_events
- or ephemeral
- or full_state):
+ if not (
+ always_include or batch or account_data_events or ephemeral or full_state
+ ):
return
state = yield self.compute_state_delta(
- room_id, batch, sync_config, since_token, now_token,
- full_state=full_state
+ room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
summary = {}
@@ -1760,22 +1779,19 @@ class SyncHandler(object):
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
- if (
- sync_config.filter_collection.lazy_load_members() and
- (
- # we recalculate the summary:
- # if there are membership changes in the timeline, or
- # if membership has changed during a gappy sync, or
- # if this is an initial sync.
- any(ev.type == EventTypes.Member for ev in batch.events) or
- (
- # XXX: this may include false positives in the form of LL
- # members which have snuck into state
- batch.limited and
- any(t == EventTypes.Member for (t, k) in state)
- ) or
- since_token is None
+ if sync_config.filter_collection.lazy_load_members() and (
+ # we recalculate the summary:
+ # if there are membership changes in the timeline, or
+ # if membership has changed during a gappy sync, or
+ # if this is an initial sync.
+ any(ev.type == EventTypes.Member for ev in batch.events)
+ or (
+ # XXX: this may include false positives in the form of LL
+ # members which have snuck into state
+ batch.limited
+ and any(t == EventTypes.Member for (t, k) in state)
)
+ or since_token is None
):
summary = yield self.compute_summary(
room_id, sync_config, batch, state, now_token
@@ -1794,9 +1810,7 @@ class SyncHandler(object):
)
if room_sync or always_include:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config
- )
+ notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1807,11 +1821,8 @@ class SyncHandler(object):
if batch.limited and since_token:
user_id = sync_result_builder.sync_config.user.to_string()
logger.info(
- "Incremental gappy sync of %s for user %s with %d state events" % (
- room_id,
- user_id,
- len(state),
- )
+ "Incremental gappy sync of %s for user %s with %d state events"
+ % (room_id, user_id, len(state))
)
elif room_builder.rtype == "archived":
room_sync = ArchivedSyncResult(
@@ -1841,9 +1852,7 @@ class SyncHandler(object):
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
"""
- joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
- user_id,
- )
+ joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_room_ids = set()
@@ -1862,11 +1871,9 @@ class SyncHandler(object):
logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering,
- )
- users_in_room = yield self.state.get_current_users_in_room(
- room_id, extrems,
+ room_id, stream_ordering
)
+ users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
joined_room_ids.add(room_id)
@@ -1886,7 +1893,7 @@ def _action_has_highlight(actions):
def _calculate_state(
- timeline_contains, timeline_start, previous, current, lazy_load_members,
+ timeline_contains, timeline_start, previous, current, lazy_load_members
):
"""Works out what state to include in a sync response.
@@ -1930,15 +1937,12 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in iteritems(timeline_start)
- if t[0] == EventTypes.Member
+ e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
- return {
- event_id_to_key[e]: e for e in state_ids
- }
+ return {event_id_to_key[e]: e for e in state_ids}
class SyncResultBuilder(object):
@@ -1961,8 +1965,10 @@ class SyncResultBuilder(object):
groups (GroupsSyncResult|None)
to_device (list)
"""
- def __init__(self, sync_config, full_state, since_token, now_token,
- joined_room_ids):
+
+ def __init__(
+ self, sync_config, full_state, since_token, now_token, joined_room_ids
+ ):
"""
Args:
sync_config (SyncConfig)
@@ -1991,8 +1997,10 @@ class RoomSyncResultBuilder(object):
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
"""
- def __init__(self, room_id, rtype, events, newly_joined, full_state,
- since_token, upto_token):
+
+ def __init__(
+ self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token
+ ):
"""
Args:
room_id(str)
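`_calculate_state`, reformatted above, boils down to one line of set algebra over event IDs: include state that is current or present at the timeline start, drop what the client already had from the previous sync, and drop events the timeline itself contains. A worked instance with invented event IDs:

c_ids = {"$power", "$name", "$topic"}   # current state
ts_ids = {"$name", "$join_rules"}       # state at the timeline start
p_ids = {"$power"}                      # state from the previous sync
tc_ids = {"$topic"}                     # events already in the timeline

state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
assert state_ids == {"$name", "$join_rules"}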
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 972662eb48..f8062c8671 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,13 +68,10 @@ class TypingHandler(object):
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
- "TypingStreamChangeCache", self._latest_room_serial,
+ "TypingStreamChangeCache", self._latest_room_serial
)
- self.clock.looping_call(
- self._handle_timeouts,
- 5000,
- )
+ self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
"""
@@ -108,19 +105,11 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- run_in_background(
- self._push_remote,
- member=member,
- typing=True
- )
+ run_in_background(self._push_remote, member=member, typing=True)
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + 60 * 1000,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@@ -138,9 +127,7 @@ class TypingHandler(object):
yield self.auth.check_joined_room(room_id, target_user_id)
- logger.debug(
- "%s has started typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has started typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -149,20 +136,13 @@ class TypingHandler(object):
now = self.clock.time_msec()
self._member_typing_until[member] = now + timeout
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + timeout,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + timeout)
if was_present:
# No point sending another notification
defer.returnValue(None)
- self._push_update(
- member=member,
- typing=True,
- )
+ self._push_update(member=member, typing=True)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
@@ -177,9 +157,7 @@ class TypingHandler(object):
yield self.auth.check_joined_room(room_id, target_user_id)
- logger.debug(
- "%s has stopped typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has stopped typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -200,20 +178,14 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
- self._push_update(
- member=member,
- typing=False,
- )
+ self._push_update(member=member, typing=False)
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_in_background(self._push_remote, member, typing)
- self._push_update_local(
- member=member,
- typing=typing
- )
+ self._push_update_local(member=member, typing=typing)
@defer.inlineCallbacks
def _push_remote(self, member, typing):
@@ -223,9 +195,7 @@ class TypingHandler(object):
now = self.clock.time_msec()
self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
+ now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
for domain in set(get_domain_from_id(u) for u in users):
@@ -256,8 +226,7 @@ class TypingHandler(object):
if user.domain != origin:
logger.info(
- "Got typing update from %r with bad 'user_id': %r",
- origin, user_id,
+ "Got typing update from %r with bad 'user_id': %r", origin, user_id
)
return
@@ -268,15 +237,8 @@ class TypingHandler(object):
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_TIMEOUT,
- )
- self._push_update_local(
- member=member,
- typing=content["typing"]
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
+ self._push_update_local(member=member, typing=content["typing"])
def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(member.room_id, set())
@@ -288,7 +250,7 @@ class TypingHandler(object):
self._latest_room_serial += 1
self._room_serials[member.room_id] = self._latest_room_serial
self._typing_stream_change_cache.entity_has_changed(
- member.room_id, self._latest_room_serial,
+ member.room_id, self._latest_room_serial
)
self.notifier.on_new_event(
@@ -300,7 +262,7 @@ class TypingHandler(object):
return []
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
- last_id,
+ last_id
)
if changed_rooms is None:
@@ -334,9 +296,7 @@ class TypingNotificationEventSource(object):
return {
"type": "m.typing",
"room_id": room_id,
- "content": {
- "user_ids": list(typing),
- },
+ "content": {"user_ids": list(typing)},
}
def get_new_events(self, from_key, room_ids, **kwargs):
|